atari - 模倣学習② 人間のデモを使って事前学習を行う

前回収集した人間のデモ操作データを使って事前学習を行います。

環境設定に関しては、前回の記事(模倣学習② 人間のデモを使って事前学習を行う)を参照して下さい。

(Ubuntu 19.10で動作確認しています。)

模倣学習

人間のデモ操作データであるbowling_demo.npzファイルrecorded_imagesフォルダを使って模倣学習を行うコードは以下の通りです。

[コード]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gym
import time
from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.gail import ExpertDataset, generate_expert_traj
from baselines.common.atari_wrappers import *

# 環境の生成
env = gym.make('BowlingNoFrameskip-v0')
env = MaxAndSkipEnv(env, skip=4) # 4フレームごとに行動を選択
env = WarpFrame(env) # 画面イメージを84x84のグレースケールに変換
env = DummyVecEnv([lambda: env])

# デモデータの読み込み
dataset = ExpertDataset(expert_path='bowling_demo.npz',verbose=1)

# モデルの生成
model = PPO2('CnnPolicy', env, verbose=1)

# モデルの読み込み
# model = PPO2.load('bowling_model', env=env)

# モデルの事前訓練
model.pretrain(dataset, n_epochs=1000)

# モデルの学習
# model.learn(total_timesteps=256000)

# モデルの保存
model.save('bowling_model')

# モデルのテスト
state = env.reset()
total_reward = 0
while True:
env.render() # 環境の描画
time.sleep(1/60) # スリープ
action, _ = model.predict(state) # モデルの推論
state, reward, done, info = env.step(action) # 1ステップ実行
total_reward += reward[0]
if done:
print('reward:', total_reward)
state = env.reset()
total_reward = 0

人間の操作したデータを事前学習するにはmodel.pretrain関数(24行目)を使います。
引数の意味は下記の通りです。

  • dataset(ExpertDataset型)
    データセット
  • n_epochs(int型)
    学習の反復回数
  • learning_rate(float型)
    学習率
  • adam_epsilon(float型)
    Adamオプティマイザーのε(エプシロン)
  • val_interval(int型)
    nエポック毎に学習と検証の損失を出力

また模倣学習を行った後、さらに強化学習を行う場合はmodel.learn関数(27行目)をコメントアウトします。

実行

実行すると、スコアは「120.2」になりました。(人間のデモ操作によって結果は変わります。)

模倣学習と合わせて強化学習も合わせて実行した結果や、強化学習のみで実行した結果も調査していきたいと思います。

atari - 模倣学習① 人間のデモ収集

ランダム行動では報酬を見つけにくい環境に対応するために模倣学習を試してみます。

Atari環境の1つであるボーリングゲーム(Bowling)を実行環境とします。

(Windowsではうまく動作しなかったので、Ubuntu 19.10で動作確認しています。)

インストール

下記のコマンドを実行し、実行環境をインストールします。

1
2
3
4
5
6
pip3 install gym
apt install cmake libopenmpi-dev python3-dev zlib1g-dev
pip3 install stable_baselines[mpi]
pip3 install tensorflow==1.14.0
pip3 install imageio
pip3 install baselines

人間のデモ収集

人間のデモ収集を行うコードは下記になります。

[コード]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import random
import pyglet
import gym
import time
from pyglet.window import key
from stable_baselines.gail import generate_expert_traj
from baselines.common.atari_wrappers import *

# 環境を作成
env = gym.make('BowlingNoFrameskip-v0')
env = MaxAndSkipEnv(env, skip=4) # 4フレームごとに行動を選択
env = WarpFrame(env) # 画面イメージを84x84のグレースケールに変換
env.render()

# キーイベント用のウィンドウ作成
win = pyglet.window.Window(width=300, height=100, vsync=False)
key_handler = pyglet.window.key.KeyStateHandler()
win.push_handlers(key_handler)
pyglet.app.platform_event_loop.start()

# キー状態の取得
def get_key_state():
key_state = set()
win.dispatch_events()
for key_code, pressed in key_handler.items():
if pressed:
key_state.add(key_code)
return key_state

# キー入力待ち
while len(get_key_state()) == 0:
time.sleep(1.0/30.0)

# 人間のデモを収集するコールバック
def human_expert(_state):
key_state = get_key_state() # キー状態の取得
action = 0 # 行動の選択

if key.SPACE in key_state:
action = 1
elif key.UP in key_state:
action = 2
elif key.DOWN in key_state:
action = 3

time.sleep(1.0/30.0) # スリープ
env.render() # 環境の描画
return action # 行動の選択

# 人間のデモの収集
generate_expert_traj(human_expert, 'bowling_demo', env, n_episodes=1)

デモ収集にはgenerate_expert_trajを使います。引数の意味は下記の通りです。

  • model(モデルまたはコールバック型)
    モデルまたはコールバック
  • save_path(str型)
    保存先のデモファイルのパス(拡張子なし)
  • env(gym.Env型)
    環境
  • n_timesteps(int型)
    モデルの学習ステップ数
  • n_episodes(int型)
    記録するエピソード数
  • image_folder(str型)
    画像を使用する場合の保存フォルダ

返値はデモ demo(dict型)となります。

実行

実行すると、次のような画面が表示されます。右側のウィンドウにフォーカスをあてるとゲームを操作することができます。

実行結果

updownで位置を選択し、fireでボールを投げます。
ボールを投げた後にupdownでボールの起動を曲げることができます。

10ゲーム(1エピソード)の人間の操作が収集され、bowling_demo.npzファイルrecorded_imagesフォルダが出力されます。

  • bowling_demo.npzファイル
    Pythonの辞書形式で保存されます。
    キーとしてactionsepisode_returnsrewardsobsepisode_startsがあり、obsには画像への相対パスが格納されます。
  • recorded_imagesフォルダ
    各状態の画像が保存されます。

次回は、今回収集した人間のデモデータを使って事前学習を行います。


Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×