Reading keras-rl/examples/dqn_cartpole.py
Motivation
I wanted to see whether DQN could solve a problem that plain Q-learning struggles with. As a first step, I read through keras-rl's dqn_cartpole, which is rumored to be an easy entry point. These are my notes for future reference; I don't claim a deep understanding.
First, build a working environment
Environment
Installed the following from Anaconda Navigator:
- tensorflow 1.10
- keras 2.2.2
Then installed openai-gym and keras-rl by following their respective instructions.
It ran without a hitch:
```
python examples/dqn_cartpole.py
```
Reading the source
Imports
```python
import numpy as np
import gym

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory
```
Variable definitions
```python
ENV_NAME = 'CartPole-v0'

# Get the environment and extract the number of actions.
env = gym.make(ENV_NAME)
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n
```
It uses the CartPole-v0 environment. `env.action_space.n` is the number of available actions; CartPole has two (push the cart left or push it right):
```
(Pdb) nb_actions
2
```
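For a quick sanity check, the spaces can also be inspected directly. A minimal sketch (the printed representations below are what the gym version of the time produces, so treat them as approximate):
```python
import gym

env = gym.make('CartPole-v0')
print(env.action_space)       # Discrete(2): 0 = push left, 1 = push right
print(env.observation_space)  # Box(4,): x, x_dot, theta, theta_dot
```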
Model definition
```python
# Next, we build a very simple model.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())
```
`model.summary()` prints the following:
```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_1 (Flatten)          (None, 4)                 0
_________________________________________________________________
dense_1 (Dense)              (None, 16)                80
_________________________________________________________________
activation_1 (Activation)    (None, 16)                0
_________________________________________________________________
dense_2 (Dense)              (None, 16)                272
_________________________________________________________________
activation_2 (Activation)    (None, 16)                0
_________________________________________________________________
dense_3 (Dense)              (None, 16)                272
_________________________________________________________________
activation_3 (Activation)    (None, 16)                0
_________________________________________________________________
dense_4 (Dense)              (None, 2)                 34
_________________________________________________________________
activation_4 (Activation)    (None, 2)                 0
=================================================================
Total params: 658
Trainable params: 658
Non-trainable params: 0
_________________________________________________________________
```
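The parameter counts in the summary follow from `inputs * units + units` (weights plus biases) for each Dense layer:
```python
# Dense layer parameters = inputs * units + units (biases)
print(4 * 16 + 16)   # 80  -> dense_1
print(16 * 16 + 16)  # 272 -> dense_2 and dense_3
print(16 * 2 + 2)    # 34  -> dense_4
```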
The network looks roughly like this:
```
Input layer: 4 (state: x, x_dot, theta, theta_dot)
    Flatten(input_shape=(1,) + env.observation_space.shape)
  ↓ fully connected, ReLU activation: Dense(16), Activation('relu')
Hidden layer: 16
  ↓ fully connected, ReLU activation: Dense(16), Activation('relu')
Hidden layer: 16
  ↓ fully connected, ReLU activation: Dense(16), Activation('relu')
Hidden layer: 16
  ↓ fully connected, linear activation: Dense(nb_actions), Activation('linear')
Output layer: 2
```
`Flatten(input_shape=(1,) + env.observation_space.shape)` is the input layer, which takes four inputs. I hadn't known that shapes (tuples) can be concatenated with `+`:
```
(Pdb) (1,) + env.observation_space.shape
(1, 4)
```
The four inputs receive the CartPole environment's observation (x, x_dot, theta, theta_dot):
- x: cart position
- x_dot: cart velocity
- theta: pole angle
- theta_dot: pole angular velocity
```python
def step(self, action):
    # (snip)
    self.state = (x, x_dot, theta, theta_dot)
    # (snip)
    return np.array(self.state), reward, done, {}
```
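To see that observation in practice, a minimal sketch using the gym API of the time (where `reset()` returns the observation directly and `step()` returns a 4-tuple):
```python
import gym

env = gym.make('CartPole-v0')
obs = env.reset()                 # array([x, x_dot, theta, theta_dot])
print(obs.shape)                  # (4,)

obs, reward, done, info = env.step(env.action_space.sample())
print(obs, reward, done)
```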
DQNAgent setup
```python
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory,
               nb_steps_warmup=10, target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])
```
`SequentialMemory` is the replay memory that stores each step's action, reward, terminal flag, and observation.
- `limit=50000` allocates an internal ring buffer that keeps the most recent 50,000 entries (see the sketch below).
- `window_length` is the smallest unit in which entries are stored.
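The ring-buffer idea behind `limit`, as my own minimal illustration (not keras-rl's actual code):
```python
from collections import deque

# A deque with maxlen behaves like a ring buffer: once full,
# appending a new entry silently drops the oldest one.
memory = deque(maxlen=3)
for step in range(5):
    memory.append(step)
print(list(memory))  # [2, 3, 4]
```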
`BoltzmannQPolicy` is a softmax policy based on the Boltzmann distribution. It supposedly makes actions with larger Q-values more likely to be chosen:
```python
q_values = q_values.astype('float64')
nb_actions = q_values.shape[0]

exp_values = np.exp(np.clip(q_values / self.tau, self.clip[0], self.clip[1]))
probs = exp_values / np.sum(exp_values)
action = np.random.choice(range(nb_actions), p=probs)
```
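To see how that behaves, here is a standalone sketch of the same softmax selection with hypothetical Q-values (the `tau` and `clip` defaults here are my own choices for illustration):
```python
import numpy as np

def boltzmann_action(q_values, tau=1.0, clip=(-500.0, 500.0)):
    """Sample an action with probability proportional to exp(Q / tau)."""
    q_values = np.asarray(q_values, dtype='float64')
    exp_values = np.exp(np.clip(q_values / tau, clip[0], clip[1]))
    probs = exp_values / np.sum(exp_values)
    return np.random.choice(len(q_values), p=probs)

# The larger Q-value wins most of the time: here action 1 dominates.
print([boltzmann_action([0.1, 1.5]) for _ in range(10)])
```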
`DQNAgent` represents the agent acting in the environment.
- `nb_steps_warmup` is the number of warm-up steps before training starts.
- `target_model_update` appears to control how the target network is updated; with a value below 1 such as `1e-2`, keras-rl seems to treat it as a soft-update coefficient rather than an update interval (sketched below).
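My reading of the soft update, as a sketch. This is an assumption based on the common DQN convention `target <- tau * model + (1 - tau) * target`, not code taken from keras-rl:
```python
import numpy as np

def soft_update(target_weights, model_weights, tau=1e-2):
    """Blend the online weights into the target weights by a factor tau."""
    return [tau * w + (1.0 - tau) * tw
            for tw, w in zip(target_weights, model_weights)]

# With tau = 1e-2 the target network drifts slowly toward the online one.
target = [np.zeros(3)]
online = [np.ones(3)]
print(soft_update(target, online))  # [array([0.01, 0.01, 0.01])]
```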
`dqn.compile()` compiles the underlying Keras model, setting the loss function and optimizer.
- `Adam(lr=1e-3)` is a Keras optimizer.
- `metrics` are Keras evaluation metrics; `mae` is the Mean Absolute Error (tiny sketch below).
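For reference, a sketch of what `mae` measures, with made-up numbers:
```python
import numpy as np

def mae(y_true, y_pred):
    """Mean Absolute Error: the average magnitude of the prediction error."""
    return np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred)))

print(mae([1.0, 2.0], [1.5, 1.0]))  # 0.75
```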
Training and testing
```python
dqn.fit(env, nb_steps=50000, visualize=True, verbose=2)

# After training is done, we save the final weights.
dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)

# Finally, evaluate our algorithm for 5 episodes.
dqn.test(env, nb_episodes=5, visualize=True)
```
`dqn.fit()` runs the training. Excerpt from the source docstring:
```
Trains the agent on the given environment.

# Arguments
    env: (`Env` instance): Environment that the agent interacts with. See [Env](#env) for details.
    nb_steps (integer): Number of training steps to be performed.
    action_repetition (integer): Number of times the agent repeats the same action without
        observing the environment again. Setting this to a value > 1 can be useful if a single
        action only has a very small effect on the environment.
    callbacks (list of `keras.callbacks.Callback` or `rl.callbacks.Callback` instances):
        List of callbacks to apply during training. See [callbacks](/callbacks) for details.
    verbose (integer): 0 for no logging, 1 for interval logging (compare `log_interval`),
        2 for episode logging
    visualize (boolean): If `True`, the environment is visualized during training. However,
        this is likely going to slow down training significantly and is thus intended to be
        a debugging instrument.
    nb_max_start_steps (integer): Number of maximum steps that the agent performs at the
        beginning of each episode using `start_step_policy`. Notice that this is an upper limit
        since the exact number of steps to be performed is sampled uniformly from
        [0, max_start_steps] at the beginning of each episode.
    start_step_policy (`lambda observation: action`): The policy to follow if
        `nb_max_start_steps` > 0. If set to `None`, a random action is performed.
    log_interval (integer): If `verbose` = 1, the number of steps that are considered to be
        an interval.
    nb_max_episode_steps (integer): Number of steps per episode that the agent performs before
        automatically resetting the environment. Set to `None` if each episode should run
        (potentially indefinitely) until the environment signals a terminal state.

# Returns
    A `keras.callbacks.History` instance that recorded the entire training process.
```
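The docstring says `fit()` returns a `keras.callbacks.History`. As a sketch (continuing the script above, and assuming keras-rl logs per-episode rewards under an `episode_reward` key), the learning curve could be plotted like this:
```python
import matplotlib.pyplot as plt

# visualize=False to avoid the rendering overhead during training.
history = dqn.fit(env, nb_steps=50000, visualize=False, verbose=2)

plt.plot(history.history['episode_reward'])  # assumed key; see lead-in
plt.xlabel('episode')
plt.ylabel('episode reward')
plt.show()
```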
`dqn.save_weights()` saves the model weights. Excerpt from the source docstring:
```
Saves the weights of an agent as an HDF5 file.

# Arguments
    filepath (str): The path to where the weights should be saved.
    overwrite (boolean): If `False` and `filepath` already exists, raises an error.
```
`dqn.test()` tests the trained model. Excerpt from the source docstring (the first line reads like a copy-paste leftover in the keras-rl source; the arguments are what matter):
```
Callback that is called before training begins.

# Arguments
    env: (`Env` instance): Environment that the agent interacts with. See [Env](#env) for details.
    nb_episodes (integer): Number of episodes to perform.
    action_repetition (integer): Number of times the agent repeats the same action without
        observing the environment again. Setting this to a value > 1 can be useful if a single
        action only has a very small effect on the environment.
    callbacks (list of `keras.callbacks.Callback` or `rl.callbacks.Callback` instances):
        List of callbacks to apply during training. See [callbacks](/callbacks) for details.
    verbose (integer): 0 for no logging, 1 for interval logging (compare `log_interval`),
        2 for episode logging
    visualize (boolean): If `True`, the environment is visualized during training. However,
        this is likely going to slow down training significantly and is thus intended to be
        a debugging instrument.
    nb_max_start_steps (integer): Number of maximum steps that the agent performs at the
        beginning of each episode using `start_step_policy`. Notice that this is an upper limit
        since the exact number of steps to be performed is sampled uniformly from
        [0, max_start_steps] at the beginning of each episode.
    start_step_policy (`lambda observation: action`): The policy to follow if
        `nb_max_start_steps` > 0. If set to `None`, a random action is performed.
    log_interval (integer): If `verbose` = 1, the number of steps that are considered to be
        an interval.
    nb_max_episode_steps (integer): Number of steps per episode that the agent performs before
        automatically resetting the environment. Set to `None` if each episode should run
        (potentially indefinitely) until the environment signals a terminal state.

# Returns
    A `keras.callbacks.History` instance that recorded the entire training process.
```
Impressions
With keras-rl, even a DQN can be written very simply. I came away feeling like I understand it, more or less.