【what is this】引き続き、「スマホでニューラルネットワーク」にこだわります。以前の記事 [1]では、Q-Learning(強化学習)をニューラルネットワークで行うための基本手法を検討し、実装しました。今回は、さらにそれを深めた手法である「Experience Replay + Target Networks」があることを知り、それをTensorflow.js (JavaScript環境)で構築し、妥当なQ値が得られるまでPCで学習させました。さらに、その学習済みモデルをスマホへ格納し、スマホのアプリとして、Q-Learningを用いた2Dグリッドの経路探索課題を解きました。
■ ExperienceReplay+TargetNetworksによるQ-Learningスマホアプリ
この手法によるQ-Learningの結果を先に示します。課題は参考文献[1]で扱ってきた、2Dグリッドでの経路探索です。Fig.3にあるように、ロボットが壁や障害物を避けて宝石(緑色の球体)に最短距離で到達するように学習させます。その学習結果を、スマホアプリで示したのがFig.3です。この例では、ロボットが障害物(黒色の正方形)を避けてうまく宝石に到達するルートを学習したことが分かります。
■ ExperienceReplay+TargetNetworksの概要
強化学習は一般に、「教師付き学習ではない」と言われています。強化学習では、固定した正確なラベル(目標)を設定すること自体が困難です(というよりも、それ自体が求める解なのですから)。しかし、刻々変化するラベルを対象として、ニューラルネットワークで「教師付き学習」させる方式があります。それが、文献[1]で検討したものです。
しかし、この方法には、2つほど欠点があります。一つは、「ある状態に対してある行動を取った場合」の一組づつしか学習できない。つまり、ニューラルネットワークで本来の性能を発揮するためのミニバッチ処理(多数の入力の一括処理)ができないことです。第二は、もっと本質的な問題ですが、この方法では、「状態と行動」が強く結合した学習となってしまい、多様な入力(状態)に対する学習が収束しずらいか、振動してしまう可能性が高いことです。
これを解決すべく登場したのが、今回の「Experience Replay + Target Networks」なのです。小生の場合、参考文献[2]と[3]を読んでその概要を学びました。Fig.1に示すPseudo codingは文献[2]から引用したものです。これに従って、独自にJavaScript(Tensorflow.js使用)でそれを試作することができました。
大雑把に言うならば、上図において、
ExperienceReplayは、ミニバッチによる教師付き学習を可能にするために、replay memory Dに「状態と行動に関する観測結果」を蓄積していきます。そして、
TargetNetworksの方は、上に述べた「状態と行動の関連性」を解消すべく、学習用のニューラルネットワークとは別に用意された、予測(prediction)用のニューラルネットワークです。すなわち、2つの分離されたネットワークを使います。TargetNetworksは、定期的に、学習用のネットワークの重みで置き換えられて、新しくなって行きます。このため、学習と予測のネットワークに時間差が生じますが、これが実は求めるべきQ値の推定を安定させることに繋がる。そのように考えられます。
■ Tensorflow(Python)よりもTensorflow.js(JavaScript)を使った理由
この手法の実装における、「出力値とラベルの差(誤差)」の最小化を図るには、やはりTensorflowの学習関数(fit)を使うのが便利です。Pythonでやってももちろん良いのですが、今回は、JavaScript上のTensorflow.jsを使いました。その理由は、スマホアプリとの相性が良いことによります。
ただし、学習済みモデルは、スマホ(あるいは外部の)webサーバに配置する必要があります。スマホ用のwebサーバーはいくつも公開されていますので、手軽に使えます。Fig.2はその一例です。この例では、学習済みモデルは2つのファイル(ネットワークトポロジー等の.jsonと重みの.bin)が、スマホのwebサーバに配置されています。スマホのアプリ側では、このモデルを(JavaScriptプログラムで)ロードして、予測に使うことができます。
■ 留意点:ExperienceReplay+TargetNetworksをTensorflow.jsで行う場合
(1)Tensorflow.jsは非同期関数
Tensorflow.jsは非同期関数の仕様になっているので、その使用は、非同期関数(先頭にasync付き)の中で行う必要があります。そして、学習用のfit関数等を呼び出す場合は、awaitによって、fit関数の実行終了を待つ必要があります。そうしないと、思わぬところで別のコード部が実行されたりしますので、注意が必要です。
(2)モデルのsaveとloadの方法
上に述べたように、学習用のネットワークtrain_netの重みを、定期的にTatget Networksへloadする必要があります。その際に、Tensorflow.jsは、非常に使いやすいsave/loadの仕組みを提供しています。つまり、以下のように、ブラウザのメモリにsaveできる、localstorageスキームを使うことができます。
// target_netのモデルを更新するため、train_netのモデルをsave
await train_net.save('localstorage://2dgrid_model');
// train_netのmodelをtarget_netへロード
target_net = await tf.loadLayersModel('localstorage://2dgrid_model');
一方、これとは別に、train_netの学習済みモデルを外部へ取り出して利用したい場合には、downloadsスキームによって、save/loadを行うことができます。以下はその例です。
// 学習済みモデルのダウンロード
await train_net.save('downloads://2dgrid_model');
// 学習済みモデルをmodelフォルダに配置してそこからload
train_net = await tf.loadLayersModel('./model/2dgrid_model.json');
なお、loadした学習モデルを単に予測に使うだけなら、これOKですが、さらにそれを学習させる場合は、以下のように、再度、最適化関数を同じものに設定してコンパイルする必要があります。
train_net.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
(3)学習関数fitへ与える入力データの形式
PythonでのTensorflowでは、fit関数の引数となる訓練データとラベルデータは、pandas array形式ですが、Tensorflow.jsでは、通常配列をtensor2dで変換して与えます。
【注】上記の(1), (2)は、Tensorflow for JavaScriptの下位レベルAPIであるCore APIを使った場合です。そうではなく、上位レベルAPIであるLayers APIを使う場合には、非同期性を考慮しなくても使える関数が用意されています。
■ ニューラルネットワークの構成と学習性能
上に述べた2つのニューラルネットワーク(学習用と予測用)の構成は以下のようにしました。試行錯誤の結果、これに落ち着きました。隠れ層は1層よりも2層の方が良い。また、隠れ層1層目のノード数はやや多い方が良い(ここでは128とした)などが分かりました。活性化関数は、隠れ層ではreluに、出力層ではlinearに設定しました。
_______________________________________________________
Layer (type) Output shape Param #
=================================================
dense_Dense1 (Dense) relu [null, 128] 512
_______________________________________________________
dense_Dense2 (Dense) relu [null, 32] 4128
_______________________________________________________
dense_Dense3 (Dense) linear [null, 4] 132
================================================
Total params: 4772
Trainable params: 4772
Non-trainable params: 0
_______________________________________________________
(各隠れ層の後にdropout層を挿入した場合も試しましたが、この問題に関しては、特段の効果は確認できませんでした。)
Fig.3に示した4x4グリッドについて、学習したモデルを次回の学習の初期モデルにしてさらに訓練を続けるというやり方で(ε-Greedyのεの値はその度に一定比率で減少させて)学習させました。その結果を使って、状態(宝石の位置、障害物の位置、ロボットの位置)をランダムに100組み生成して、ロボットが宝石に最短距離で到達できるか否かをテストしました。その成功率は以下の通りでした。
・1回目:0.97
・2回目:0.95
・3回目:0.97
■ 感想
「不正確なラベル」を使って「近似的な解法」を実行して行くという、この(一見確信を持てないような)手法ですが、実際にやってみると段々にラベルが正確な値に近づき、正しい答えを出すようになるのは不思議な気もします..
この手法で学習させた結果は、上記のとおり、かなり満足のできるものでした。従来のベルマン方程式に基づくQ-tableを構築しながら学習させる方法では、ほぼ確実に厳密解に到達できました。これに対して、今回のニューラルネットワークを用いた上記のExperience Replay + Target Networksでも、95%程度の正解率を得ることができました。そして、問題規模がさらに増大した場合は、Q-tableによる学習は明らかに破綻するので、今回の手法の有効性が高まるものと感じます。
この実装は、ほとんどFig.1の情報だけから出発して(参考文献[2]や[3]に載っていたPytonコードを見ずに)、自分で考えながら具体化し、Tensorflow.jsを使って実現しました。どこか間違っているかも知れないという不安は残っていたのですが、上記のとおり95%の正解率を得ることができたので、恐らく、妥当な作りになっているだろうと思います。
参考文献
[1] Running Q-Learning on your palm (Part3)
[2] Jordi TORRES.AI, Deep Q-Network (DQN)-II
Experience Replay and Target Networks, Aug.16, 2020
[3] Ketan Doshi, Reinforcement Learning Explained Visually (Part 5): Deep Q Networks, step-by-step
A Gentle Guide to DQNs with Experience Replay, in Plain English, Dec 20, 2020
0 件のコメント:
コメントを投稿