ランダム行動では報酬を見つけにくい環境に対応するために模倣学習 を試してみます。
Atari環境の1つであるボーリングゲーム(Bowling) を実行環境とします。
(Windowsではうまく動作しなかったので、Ubuntu 19.10で動作確認しています。)
インストール
下記のコマンドを実行し、実行環境をインストールします。
1 2 3 4 5 6 pip3 install gym apt install cmake libopenmpi-dev python3-dev zlib1g-dev pip3 install stable_baselines[mpi] pip3 install tensorflow==1.14.0 pip3 install imageio pip3 install baselines
人間のデモ収集
人間のデモ収集を行うコードは下記になります。
[コード]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 import random import pyglet import gym import time from pyglet.window import key from stable_baselines.gail import generate_expert_traj from baselines.common.atari_wrappers import * # 環境を作成 env = gym.make('BowlingNoFrameskip-v0') env = MaxAndSkipEnv(env, skip=4) # 4フレームごとに行動を選択 env = WarpFrame(env) # 画面イメージを84x84のグレースケールに変換 env.render() # キーイベント用のウィンドウ作成 win = pyglet.window.Window(width=300, height=100, vsync=False) key_handler = pyglet.window.key.KeyStateHandler() win.push_handlers(key_handler) pyglet.app.platform_event_loop.start() # キー状態の取得 def get_key_state(): key_state = set() win.dispatch_events() for key_code, pressed in key_handler.items(): if pressed: key_state.add(key_code) return key_state # キー入力待ち while len(get_key_state()) == 0: time.sleep(1.0/30.0) # 人間のデモを収集するコールバック def human_expert(_state): key_state = get_key_state() # キー状態の取得 action = 0 # 行動の選択 if key.SPACE in key_state: action = 1 elif key.UP in key_state: action = 2 elif key.DOWN in key_state: action = 3 time.sleep(1.0/30.0) # スリープ env.render() # 環境の描画 return action # 行動の選択 # 人間のデモの収集 generate_expert_traj(human_expert, 'bowling_demo', env, n_episodes=1)
デモ収集にはgenerate_expert_traj を使います。引数の意味は下記の通りです。
model(モデルまたはコールバック型) モデルまたはコールバック
save_path(str型) 保存先のデモファイルのパス(拡張子なし)
env(gym.Env型) 環境
n_timesteps(int型) モデルの学習ステップ数
n_episodes(int型) 記録するエピソード数
image_folder(str型) 画像を使用する場合の保存フォルダ
返値はデモ demo(dict型) となります。
実行
実行すると、次のような画面が表示されます。右側のウィンドウにフォーカスをあてるとゲームを操作することができます。
up 、down で位置を選択し、fire でボールを投げます。 ボールを投げた後にup 、down でボールの起動を曲げることができます。
10ゲーム(1エピソード)の人間の操作が収集され、bowling_demo.npzファイル とrecorded_imagesフォルダ が出力されます。
bowling_demo.npzファイル Pythonの辞書形式で保存されます。 キーとしてactions 、episode_returns 、rewards 、obs 、episode_starts があり、obs には画像への相対パスが格納されます。
recorded_imagesフォルダ 各状態の画像が保存されます。
次回は、今回収集した人間のデモデータを使って事前学習を行います。