LSTM/GRUを使ったエージェント、agt_lstm, agt_gru の解説です。

LSTMを取り入れたネットワーク構造

それではqnet にどのようにLSTMを組み込むのが良いのでしょうか。

先行研究として、DQNのネットワークの出力層の手前の全結合層をLSTMに置き換えただけというモデル、DRQN (Deep Recurrent Q-Network)が提案されています(Hausknecht 2017)。

この論文では、Pong というボールを打ち合うゲームをDRQNに学習させ、画面を時々見えなくしてもそこそこプレイできることを示しています。

DRQNの論文での例題は、筆者が考える短期記憶を意識した実装ではないように思いますが、ここでのLSTMの使い方に習って、出力層の手前にLSTMを入れたニューラルネット agt_lstm.py と、代わりに GRU を入れたagt_gru.py を作りました。

両者のプログラムは、ほとんど同じですので、以下、agt_lstm.py を説明していきます。

リプレイメモリー

agt_lstm.py では、タスクを実行するフェイズと学習するフェイズを別々に交互に行います。

タスクを実行するフェイズでは、毎ステップ、前の観察x, 前の行動a, 報酬 r, 今の観察x を経験 e というセットにして、経験記憶クラス(ReplayMemory)に保存します。これは、前章で述べてきた短期記憶とは別物です。

学習するフェイズでは、経験記憶クラスのデータを使って、ネットワークのトレーニングをします。

トレーニングでは、この二つのフェイズを繰り返し行っていきます。なぜ、このようにフェイズを分けたかというと、LSTMの内部状態が、実行時の処理でも学習時の処理でも変化してしまい、同時に行うと互いに干渉してしまうからです。

DQNでもこのようなリプレイメモリーが使われていますが、若干の違いがあります。DQNではメモリーのサンプリングはランダムに行います。一方、ここでのリプレイメモリーは、時系列学習をさせることが目的なので、保存されているエピソードをランダムに選んだら、そのままの順番でまとまった経験をサンプリングするようにしています。

class ReplayMemory の具体的なコードです。

add(experience)で経験を追加し、sample(data_length)で、data_length分の経験を取り出します。memory_size は、蓄えるエピソードの数です。

パラメータ設定、__init__()

それでは、agt_lstm.py の中身の説明に入ります。

__init_()では、パラメータのセッティングをします。実際に使用するパラメータの値は、sim_swamptour.py に記述されていますので、ここではどんなものがあるかを見ていきます。

qnet から新しく入ったパラメータは、まず、n_lstm です。これはLSTMユニットの数です。

memory_sizeはReplayMemoryで保存するエピソード数、data_length_for_learnは、取り出す経験の数です。

実行と学習は交互に行うという説明をしましたが、learn_interval は、何ステップの実行毎に学習を行うかを決めるパラメータです。

epochs_for_train, batch_sizse, data_length は、学習時のデータのパラメータですが、今のところ全て1で使います。

モデル生成、build_model()

モデルの生成は、build_model()で行います。実質的なモデルの生成は、_build_model()で行っています。

LSTMを使っているために、(A)のTimeDistributed のラッパーを使用しています。バッチは1、時系列データの長さも1として(1batch_size=1, data_length=1)、1つの観察情報を1回ずつ入れていく想定です。

(B) で LSTMを定義しています。LSTMの内部状態 state_h, state_c も (C) の output に含めて参照できるようにしています(gruの場合には、内部状態は一つになります)。

行動選択、select_action()

行動選択は、qnet の時と同じepsilon-greedy法です。(A)でQ値を得て、(B)で最もQ値が大きい行動を出力します。そして、(C) epsilon の確率でランダム出力をします。

ただし、(A)の get_Q()の内部について少しコメントがあります。

この関数の中の(D)で、model.predict()を実行してQを得ていますが、このmodel.predict()を1回実行すると内部状態が変化しますので、同じ入力に対しても2回目を実行すると出力は異なってきます。

また、(D)の model.predict()のoutput は、build_model()で指定したように、LSTMの内部状態 hstate0, hstate1 も出力されます。gru の場合は、内部状態が1つなのでhstate のみとなっています。

Q学習、learn()

学習の部分です。このlearn()は、Class Trainer の中で毎ステップ実行されますが、毎回行われるのは経験の登録のみで、実際に学習が行われるのは、learn_interval で指定したstep 毎になります(A)。

学習に入る場合には、気を付けないといけないことがあります。それは、学習をすると、今までの内部状態が壊れてしまうということです。そこで、学習に入る前には、いったん(B)で内部状態を保存し、学習が終わったら(D)でもとの内部状態に戻します。

実際の学習は(C)で実行している _fit() (E)で行っています。学習のプロセスでも内部状態の扱いに注意が必要です。

まず、ターゲットデータを作ります。

はじめに、内部状態をリセットします(F)。学習をさせていない実行時では、エピソードの開始時に内部状態をリセットしているので、それと同じ状態にするという意味があります。

(G)で、data_length_for_learnの長さの観測 observationを model.predict() に順番に入力し、Qss にQ値をまとめます。このとき、エピソードが終了していたら(done==True) 内部状態をリセットするようにします。次のエピソード開始時のためです。

(H)で、Qssを使ってTarget を作成し、Tssにまとめます。同時に、観測データXss もまとめます。

(I)でモデルの内部状態をリセットしたあと、(J)で実際に学習を行います。データーXssとTssから1ステップ分ずつ取り出してmodel.fit()で学習させます。これをデータ分繰り返します。これも、エピソードが終了していたらその都度内部状態をリセットするようにします。

通常の教師あり学習では、データをまとめてmodel.fit()にセットし、一気に学習させることができるのですが、強化学習の場合は、特に、時系列データとして扱わなくてはならないLSTMの場合は、かなり面倒になりますね。

内部状態のリセット・保存・ロード、reset(), save_state(), load_state()

内部状態のリセット・保存・ロードは、学習の前後や学習のプロセスで使われていました。その具体的なコードです。

保存と読み出し、save_weights(), load_weights()

保存と読み出しは、qnetと変わりません。

以上で、agt_lstm.py の内容は全てです。

agt_lstm, agt_gru で出来るタスクとできないタスク

agt_lstm, agt_gru は、qnetが解くことのできる silent_ruin, open_field, many_swamp を解くことができます。ruin_1swampもqnetと同じように大体できます。

そして、qnetができなかった、短記憶を必要とする Tmaze_both, Tmaze_either を解くことができます。これが、lstm/gru の成果です。

Tmaze_both です。

Tmaze_either です。

ただし、トレーニングは学習が達成していなくても 5000 step 終了になりますが、学習がいつできるかどうかは確率的です。追加学習のコマンドM で3 回ほど試し、それでもできなかったら、Lで初めからやり直してみてください。筆者は、LSTM, GRU のどちらでもTmaze_both, Tmaze_either ができることを確認しています。

Tmaze_both, Tmaze_either は短期記憶を必要としますがかなりシンプルなタスクです。それにも関わらず、学習は簡単ではありません

そして、Tmaze よりも複雑になったruin_2swamp は満足に解くことはできません。30000回の学習でも以下のようなパフォーマンスでした。1度行ったゴールにもまた近づいてしまっていることが分かります。2つのゴールを訪れることができたエピソードでも、1度行ったゴールを通り過ぎて2つめのゴールに向かっているので、1度行ったゴールを覚えているわけではなさそうです。

筆者が試した限り、ネットワークのユニット数や層を変えてもパフォーマンスは上がりませんでした。このような問題がなぜ、lstmで出来ないのか、どうすればできるのかを考えることが今後の課題ですね。

[top] [01] [02] [03] [04] [05] [06] [07] [08] [09] 次は[10]