Kaggle - ConnectX(4) - 4枚そろえるボードゲーム

Connect Xコンペに関する4回目の記事です。

Connect X

機械学習ではない、ルールベースのエージェントがありましたので試しに実行してみます。

ルールベースのエージェント

ConnectX Rule-Based

エージェントの実装だけ抽出してみます。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def my_agent(obs, conf):
def get_results(x, y, mark, multiplier):
""" get list of points, lowest cells and "in air" cells of a board[x][y] cell considering mark """
# set board[x][y] as mark
board[x][y] = mark
results = []
# if some points in axis already found - axis blocked
blocked = [False, False, False, False]
# i is amount of marks required to add points
for i in range(conf.inarow, 2, -1):
# points
p = 0
# lowest cell
lc = 0
# "in air" points
ap = 0
# axis S -> N, only if one mark required for victory
if i == conf.inarow and blocked[0] is False:
(p, lc, ap, blocked[0]) = process_results(p, lc, ap,
check_axis(mark, i, x, lambda z : z, y + inarow_m1, lambda z : z - 1))
# axis SW -> NE
if blocked[1] is False:
(p, lc, ap, blocked[1]) = process_results(p, lc, ap,
check_axis(mark, i, x - inarow_m1, lambda z : z + 1, y + inarow_m1, lambda z : z - 1))
# axis E -> W
if blocked[2] is False:
(p, lc, ap, blocked[2]) = process_results(p, lc, ap,
check_axis(mark, i, x + inarow_m1, lambda z : z - 1, y, lambda z : z))
# axis SE -> NW
if blocked[3] is False:
(p, lc, ap, blocked[3]) = process_results(p, lc, ap,
check_axis(mark, i, x + inarow_m1, lambda z : z - 1, y + inarow_m1, lambda z : z - 1))
results.append((p * multiplier, lc, ap))
# restore board[x][y] original value
board[x][y] = 0
return results

def check_axis(mark, inarow, x, x_fun, y, y_fun):
""" check axis (NE -> SW etc.) for lowest cell and amounts of points and "in air" cells """
(x, y, axis_max_range) = get_x_y_and_axis_max_range(x, x_fun, y, y_fun)
zeros_allowed = conf.inarow - inarow
#lowest_cell = y
# lowest_cell calculation turned off
lowest_cell = 0
for i in range(axis_max_range):
x_temp = x
y_temp = y
zeros_remained = zeros_allowed
marks = 0
# amount of empty cells that are "in air" (don't have board bottom or mark under them)
in_air = 0
for j in range(conf.inarow):
if board[x_temp][y_temp] != mark and board[x_temp][y_temp] != 0:
break
elif board[x_temp][y_temp] == mark:
marks += 1
# board[x_temp][y_temp] is 0
else:
zeros_remained -= 1
if (y_temp + 1) < conf.rows and board[x_temp][y_temp + 1] == 0:
in_air -= 1
# if y_temp > lowest_cell:
# lowest_cell = y_temp
if marks == inarow and zeros_remained == 0:
return (sp, lowest_cell, in_air, True)
x_temp = x_fun(x_temp)
y_temp = y_fun(y_temp)
if y_temp < 0 or y_temp >= conf.rows or x_temp < 0 or x_temp >= conf.columns:
return (0, 0, 0, False)
x = x_fun(x)
y = y_fun(y)
return (0, 0, 0, False)

def get_x_y_and_axis_max_range(x, x_fun, y, y_fun):
""" set x and y inside board boundaries and get max range of axis """
axis_max_range = conf.inarow
while y < 0 or y >= conf.rows or x < 0 or x >= conf.columns:
x = x_fun(x)
y = y_fun(y)
axis_max_range -= 1
return (x, y, axis_max_range)

def process_results(p, lc, ap, axis_check_results):
""" process results of check_axis function, return lowest cell and sums of points and "in air" cells """
(points, lowest_cell, in_air, blocked) = axis_check_results
if points > 0:
if lc < lowest_cell:
lc = lowest_cell
ap += in_air
p += points
return (p, lc, ap, blocked)

def get_best_cell(best_cell, current_cell):
""" get best cell by comparing factors of cells """
for i in range(len(current_cell["factors"])):
# index 0 = points, 1 = lowest cell, 2 = "in air" cells
for j in range(3):
# if value of best cell factor is smaller than value of
# the same factor in the current cell
# best cell = current cell and break the loop,
# don't compare lower priority factors
if best_cell["factors"][i][j] < current_cell["factors"][i][j]:
return current_cell
# if value of best cell factor is bigger than value of
# the same factor in the current cell
# break loop and don't compare lower priority factors
if best_cell["factors"][i][j] > current_cell["factors"][i][j]:
return best_cell
return best_cell

def get_factors(results):
""" get list of factors represented by results and ordered by priority from highest to lowest """
factors = []
for i in range(conf.inarow - 2):
if i == 1:
# my checker in this cell means my victory two times
factors.append(results[0][0][i] if results[0][0][i][0] > st else (0, 0, 0))
# opponent's checker in this cell means my defeat two times
factors.append(results[0][1][i] if results[0][1][i][0] > st else (0, 0, 0))
# if there are results of a cell one row above current
if len(results) > 1:
# opponent's checker in cell one row above current means my defeat two times
factors.append(results[1][1][i] if -results[1][1][i][0] > st else (0, 0, 0))
# my checker in cell one row above current means my victory two times
factors.append(results[1][0][i] if -results[1][0][i][0] > st else (0, 0, 0))
else:
for j in range(2):
factors.append((0, 0, 0))
else:
for j in range(2):
factors.append((0, 0, 0))
for j in range(2):
factors.append((0, 0, 0))
# consider only if there is no "in air" cells
if results[0][1][i][2] == 0:
# placing opponent's checker in this cell means opponent's victory
factors.append(results[0][1][i])
else:
factors.append((0, 0, 0))
# placing my checker in this cell means my victory
factors.append(results[0][0][i])
# central column priority
factors.append((1 if i == 1 and shift == 0 else 0, 0, 0))
# if there are results of a cell one row above current
if len(results) > 1:
# opponent's checker in cell one row above current means my defeat
factors.append(results[1][1][i])
# my checker in cell one row above current means my victory
factors.append(results[1][0][i])
else:
for j in range(2):
factors.append((0, 0, 0))
# if there are results of a cell two rows above current
if len(results) > 2:
for i in range(conf.inarow - 2):
# my checker in cell two rows above current means my victory
factors.append(results[2][0][i])
# opponent's checker in cell two rows above current means my defeat
factors.append(results[2][1][i])
else:
for i in range(conf.inarow - 2):
for j in range(2):
factors.append((0, 0, 0))
return factors


# define my mark and opponent's mark
my_mark = obs.mark
opp_mark = 2 if my_mark == 1 else 1

# define board as two dimensional array
board = []
for column in range(conf.columns):
board.append([])
for row in range(conf.rows):
board[column].append(obs.board[conf.columns * row + column])

best_cell = None
board_center = conf.columns // 2
inarow_m1 = conf.inarow - 1

# standard amount of points
sp = 1
# "seven" pattern threshold points
st = 1

# start searching for best_cell from board center
x = board_center

# shift to right or left from board center
shift = 0

# searching for best_cell
while x >= 0 and x < conf.columns:
# find first empty cell starting from bottom of the column
y = conf.rows - 1
while y >= 0 and board[x][y] != 0:
y -= 1
# if column is not full
if y >= 0:
# results of current cell and cells above it
results = []
results.append((get_results(x, y, my_mark, 1), get_results(x, y, opp_mark, 1)))
# if possible, get results of a cell one row above current
if (y - 1) >= 0:
results.append((get_results(x, y - 1, my_mark, -1), get_results(x, y - 1, opp_mark, -1)))
# if possible, get results of a cell two rows above current
if (y - 2) >= 0:
results.append((get_results(x, y - 2, my_mark, 1), get_results(x, y - 2, opp_mark, 1)))

# list of factors represented by results
# ordered by priority from highest to lowest
factors = get_factors(results)

# if best_cell is not yet found
if best_cell is None:
best_cell = {
"column": x,
"factors": factors
}
# compare values of factors in best cell and current cell
else:
current_cell = {
"column": x,
"factors": factors
}
best_cell = get_best_cell(best_cell, current_cell)

# shift x to right or left from board center
if shift >= 0: shift += 1
shift *= -1
x = board_center + shift

# return index of the best cell column
return best_cell["column"]

エージェントの評価

ランダム選択の相手との結果と、NegaMax法の相手との結果(平均報酬)を表示します。

[ソース]

1
2
3
4
5
6
def mean_reward(rewards):
return sum(r[0] for r in rewards) / float(len(rewards))

# Run multiple episodes to estimate its performance.
print("My Agent vs Random Agent:", mean_reward(evaluate("connectx", [my_agent, "random"], num_episodes=10)))
print("My Agent vs Negamax Agent:", mean_reward(evaluate("connectx", [my_agent, "negamax"], num_episodes=10)))

[結果]

ランダム相手には完勝しており、NegaMax法の相手との結果もかなり勝ち越しています。

なかなか強いエージェントかとも思ったのですが、Kaggleに提出したところスコア600でした。

強いエージェントとは言えないようです。

Kaggle - ConnectX(3) - 4枚そろえるボードゲーム

Connect Xコンペに関する3回目の記事です。

Connect X

今回は、スコア1029.1を叩き出している「Cell Swarm」というノートブックのエージェントを参考にさせて頂きました。

Cell Swarm

Cell Swarmノートブックのエージェント

エージェントの実装だけ抽出してみます。

Swarmは「群れ」という意味で、Cell Swarmだと「セルの群れ」とか「セルの集まり」とかいう意味でしょうか。

処理はコメントをご参照ください。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
def my_agent(obs, conf):

def evaluate_cell(cell):
""" evaluate qualities of the cell """
# セルの品質を評価。パターンを取得して、そのセルのポイント付けをしているみたい。
cell = get_patterns(cell)
cell = calculate_points(cell)
for i in range(1, conf.rows):
cell = explore_cell_above(cell, i)
return cell

def get_patterns(cell):
""" get swarm and opponent's patterns of each axis of the cell """
# 群れと対戦相手のセルの各軸パターンを取得。
ne = get_pattern(cell["x"], lambda z : z + 1, cell["y"], lambda z : z - 1, conf.inarow)
sw = get_pattern(cell["x"], lambda z : z - 1, cell["y"], lambda z : z + 1, conf.inarow)[::-1]
cell["swarm_patterns"]["NE_SW"] = sw + [{"mark": swarm_mark}] + ne
cell["opp_patterns"]["NE_SW"] = sw + [{"mark": opp_mark}] + ne
e = get_pattern(cell["x"], lambda z : z + 1, cell["y"], lambda z : z, conf.inarow)
w = get_pattern(cell["x"], lambda z : z - 1, cell["y"], lambda z : z, conf.inarow)[::-1]
cell["swarm_patterns"]["E_W"] = w + [{"mark": swarm_mark}] + e
cell["opp_patterns"]["E_W"] = w + [{"mark": opp_mark}] + e
se = get_pattern(cell["x"], lambda z : z + 1, cell["y"], lambda z : z + 1, conf.inarow)
nw = get_pattern(cell["x"], lambda z : z - 1, cell["y"], lambda z : z - 1, conf.inarow)[::-1]
cell["swarm_patterns"]["SE_NW"] = nw + [{"mark": swarm_mark}] + se
cell["opp_patterns"]["SE_NW"] = nw + [{"mark": opp_mark}] + se
s = get_pattern(cell["x"], lambda z : z, cell["y"], lambda z : z + 1, conf.inarow)
n = get_pattern(cell["x"], lambda z : z, cell["y"], lambda z : z - 1, conf.inarow)[::-1]
cell["swarm_patterns"]["S_N"] = n + [{"mark": swarm_mark}] + s
cell["opp_patterns"]["S_N"] = n + [{"mark": opp_mark}] + s
return cell

def get_pattern(x, x_fun, y, y_fun, cells_remained):
""" get pattern of marks in direction """
# ある方向へのマークパターンを取得
pattern = []
x = x_fun(x)
y = y_fun(y)
# if cell is inside swarm's borders
# セルが群れの境界内にある場合
if y >= 0 and y < conf.rows and x >= 0 and x < conf.columns:
pattern.append({
"mark": swarm[x][y]["mark"]
})
# amount of cells to explore in this direction
# ある方向へのセルの総数
cells_remained -= 1
if cells_remained > 1:
pattern.extend(get_pattern(x, x_fun, y, y_fun, cells_remained))
return pattern

def calculate_points(cell):
""" calculate amounts of swarm's and opponent's correct patterns and add them to cell's points """
for i in range(conf.inarow - 1):
# inarow = amount of marks in pattern to consider that pattern as correct
inarow = conf.inarow - i
swarm_points = 0
opp_points = 0
# calculate swarm's points and depth
# 群れのポイントと深さを計算
swarm_points = evaluate_pattern(swarm_points, cell["swarm_patterns"]["E_W"], swarm_mark, inarow)
swarm_points = evaluate_pattern(swarm_points, cell["swarm_patterns"]["NE_SW"], swarm_mark, inarow)
swarm_points = evaluate_pattern(swarm_points, cell["swarm_patterns"]["SE_NW"], swarm_mark, inarow)
swarm_points = evaluate_pattern(swarm_points, cell["swarm_patterns"]["S_N"], swarm_mark, inarow)
# calculate opponent's points and depth
# 対戦相手のポイントと深さを計算
opp_points = evaluate_pattern(opp_points, cell["opp_patterns"]["E_W"], opp_mark, inarow)
opp_points = evaluate_pattern(opp_points, cell["opp_patterns"]["NE_SW"], opp_mark, inarow)
opp_points = evaluate_pattern(opp_points, cell["opp_patterns"]["SE_NW"], opp_mark, inarow)
opp_points = evaluate_pattern(opp_points, cell["opp_patterns"]["S_N"], opp_mark, inarow)
# if more than one mark required for victory
# 勝つために1つ以上のマークが必要かどうか
if i > 0:
# swarm_mark or opp_mark priority
# 自分のマークと対戦相手のマークの優先順位
if swarm_points > opp_points:
cell["points"].append(swarm_points)
cell["points"].append(opp_points)
else:
cell["points"].append(opp_points)
cell["points"].append(swarm_points)
else:
cell["points"].append(swarm_points)
cell["points"].append(opp_points)
return cell

def evaluate_pattern(points, pattern, mark, inarow):
""" get amount of points, if pattern has required amounts of marks and zeros """
# saving enough cells for required amounts of marks and zeros
# マーク数と非マーク数の総数を保存する
for i in range(len(pattern) - (conf.inarow - 1)):
marks = 0
zeros = 0
# check part of pattern for required amounts of marks and zeros
# マーク数と非マーク数の総数をチェックする
for j in range(conf.inarow):
if pattern[i + j]["mark"] == mark:
marks += 1
elif pattern[i + j]["mark"] == 0:
zeros += 1
if marks >= inarow and (marks + zeros) == conf.inarow:
return points + 1
return points

def explore_cell_above(cell, i):
""" add positive or negative points from cell above (if it exists) to points of current cell """
# ポジティブなポイントかネガティブなポイントを追加する
if (cell["y"] - i) >= 0:
cell_above = swarm[cell["x"]][cell["y"] - i]
cell_above = get_patterns(cell_above)
cell_above = calculate_points(cell_above)
# points will be positive or negative
# ポイントがポジティブかネガティブか
n = -1 if i & 1 else 1
# if it is first cell above
# 最初のセルの上かどうか
if i == 1:
# add first 4 points of cell_above["points"] to cell["points"]
# 最初の4ポイントを追加する
cell["points"][2:2] = [n * cell_above["points"][1], n * cell_above["points"][0]]
# if it is not potential "seven" pattern in cell and cell_above has more points
if abs(cell["points"][4]) < 2 and abs(cell["points"][4]) < cell_above["points"][2]:
cell["points"][4:4] = [n * cell_above["points"][2]]
# if it is not potential "seven" pattern in cell and cell_above has more points
if abs(cell["points"][5]) < 2 and abs(cell["points"][5]) < cell_above["points"][3]:
cell["points"][5:5] = [n * cell_above["points"][3]]
else:
cell["points"][7:7] = [n * cell_above["points"][3]]
else:
cell["points"][6:6] = [n * cell_above["points"][2], n * cell_above["points"][3]]
cell["points"].append(n * cell_above["points"][4])
cell["points"].append(n * cell_above["points"][5])
else:
cell["points"].extend(map(lambda z : z * n, cell_above["points"]))
else:
cell["points"].extend([0, 0, 0, 0, 0, 0])
return cell

def choose_best_cell(best_cell, current_cell):
""" compare two cells and return the best one """
# 2つのセルを比較しベストなセルを返す
if best_cell is not None:
for i in range(len(best_cell["points"])):
# compare amounts of points of two cells
# 2つのセルの総ポイントを比較する
if best_cell["points"][i] < current_cell["points"][i]:
best_cell = current_cell
break
if best_cell["points"][i] > current_cell["points"][i]:
break
# if ["points"][i] of cells are equal, compare distance to swarm's center of each cell
# もし["points"][i]セルが等しい場合、各セルの群れの中心への距離を比較する
if best_cell["points"][i] > 0:
if best_cell["distance_to_center"] > current_cell["distance_to_center"]:
best_cell = current_cell
break
if best_cell["distance_to_center"] < current_cell["distance_to_center"]:
break
else:
best_cell = current_cell
return best_cell

###############################################################################
# define swarm's and opponent's marks
# 群れと対戦相手のマークを定義
swarm_mark = obs.mark
opp_mark = 2 if swarm_mark == 1 else 1
# define swarm's center
# 群れの中央位置を定義
swarm_center_horizontal = conf.columns // 2
swarm_center_vertical = conf.rows // 2

# define swarm as two dimensional array of cells
# セルの2次元配列として群れを定義
swarm = []
for column in range(conf.columns):
swarm.append([])
for row in range(conf.rows):
cell = {
"x": column,
"y": row,
"mark": obs.board[conf.columns * row + column],
"swarm_patterns": {},
"opp_patterns": {},
"distance_to_center": abs(row - swarm_center_vertical) + abs(column - swarm_center_horizontal),
"points": []
}
swarm[column].append(cell)

best_cell = None
# start searching for best_cell from swarm center
# 群れの中央から最適なセル位置を検索開始
x = swarm_center_horizontal
# shift to right or left from swarm center
# 群れの中央から右か左にシフト
shift = 0

# searching for best_cell
# 最適なセル位置を検索
while x >= 0 and x < conf.columns:
# find first empty cell starting from bottom of the column
# カラムの底位置からマークされていない最初の位置を見つける
y = conf.rows - 1
while y >= 0 and swarm[x][y]["mark"] != 0:
y -= 1
# if column is not full
# カラムがフルでない場合
if y >= 0:
# current cell evaluates its own qualities
# 現在のセルの評価
current_cell = evaluate_cell(swarm[x][y])
# current cell compares itself against best cell
# 現在のセルと最適なセル位置を比較
best_cell = choose_best_cell(best_cell, current_cell)

# shift x to right or left from swarm center
# 中央から右か左にずらす
if shift >= 0:
shift += 1
shift *= -1
x = swarm_center_horizontal + shift

# return index of the best cell column
# 最適なカラム位置のインデックスを返す
return best_cell["x"]

エージェントの評価

ランダム選択の相手との結果と、NegaMax法の相手との結果(平均報酬)を表示します。

[ソース]

1
2
3
4
5
6
def mean_reward(rewards):
return sum(r[0] for r in rewards) / float(len(rewards))

# Run multiple episodes to estimate its performance.
print("My Agent vs Random Agent:", mean_reward(evaluate("connectx", [my_agent, "random"], num_episodes=10)))
print("My Agent vs Negamax Agent:", mean_reward(evaluate("connectx", [my_agent, "negamax"], num_episodes=10)))

[結果]

ランダム相手には完勝しており、NegaMax法の相手との結果も勝ち越しています。

なかなか強いエージェントみたいです。

Kaggle - ConnectX(2) - 4枚そろえるボードゲーム

Connect Xコンペに関する2回目の記事です。

Connect X

今回はエージェントが受け取る情報を確認したいと思います。

環境ルール(Environment Rules)

サンプルソースを見ますと下記のようなコメントが記述されていました。

[サンプルソース]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def agent(observation, configuration):
# Number of Columns on the Board.(ボードのカラム数)
columns = configuration.columns
# Number of Rows on the Board.(ボードの行数)
rows = configuration.rows
# Number of Checkers "in a row" needed to win.(勝つために並べるコインの数)
inarow = configuration.inarow
# The current serialized Board (rows x columns).(現在のボードの状態。1次元として)
board = observation.board
# Which player the agent is playing as (1 or 2).(どちらのプレイヤーか)
mark = observation.mark

# Return which column to drop a checker (action).(どのカラムに落とすかを返す)
return 0

確認のために、前回作成したソースにエージェントが受け取る情報を表示する処理を追加し実行してみます。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
def my_agent(observation, configuration):
print('observation', observation)
print('configuration', configuration)
print('------------')
from random import choice
return choice([c for c in range(configuration.columns) if observation.board[c] == 0])

env.reset()
# Play as the first agent against default "random" agent.
env.run([my_agent, "random"])
env.render(mode="ipython", width=500, height=450)

[結果]
(一部表示)

サンプルソースの説明に合致した情報が格納されていることが分かります。

そのほかの情報(stepやtimeoutなど)もありますが、必要に応じて参照すればいいと思います。

ポイントとしては、observation.boardでしょうか。

6行7列の配列を1次元配列で表しています。この方が機械学習のデータとして扱いやすいですからね。

Kaggle - ConnectX(1) - 4枚そろえるボードゲーム

今回からConnect Xコンペに参加したいと思います。

Connect X

Connect Xは、ボードゲームの一種で上からコインを交互に落として、縦か横か斜めに4枚コインをそろえた方が勝ちというルールになっています。

ConnectX Getting Startedというスタート練習用のノートブックがありますのでこれを試しに実行してみます。

Kaggle環境のインストール

下記バージョンのKaggle環境をインストールする必要があるようです。

Connect Xはバージョンに依存する環境なんでしょうかね。

[ソース]

1
!pip install 'kaggle-environments>=0.1.6'

[結果]

問題なくインストールすることができました。

Connect X環境の作成

Connect X環境を作成します。

[ソース]

1
2
3
4
from kaggle_environments import evaluate, make, utils

env = make("connectx", debug=True)
env.render()

エージェントの作成

エージェントを作成します。

ここではサンプルとしてランダムにコインを落とす場所を決めているようです。

今後はこのロジックを実装していき勝率を上げていけばいいんですね。

[ソース]

1
2
3
4
# This agent random chooses a non-empty column.
def my_agent(observation, configuration):
from random import choice
return choice([c for c in range(configuration.columns) if observation.board[c] == 0])

エージェントのテスト

上記で作成したエージェントのテストを行います。

ただ相手のロジックもランダムにコインを落とすようなので・・・・今回はただの動作確認用です。

[ソース]

1
2
3
4
env.reset()
# Play as the first agent against default "random" agent.
env.run([my_agent, "random"])
env.render(mode="ipython", width=500, height=450)

[結果]

上記のようなボードが現れてコインが次々に落とされ、コインが4つそろったら終了になります。

アニメーションとして動くのでちょっとおもしろいです。

エージェントのテストと訓練

エージェントのテストと訓練を行います。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
# Play as first position against random agent.
trainer = env.train([None, "random"])

observation = trainer.reset()

while not env.done:
my_action = my_agent(observation, env.configuration)
print("My Action", my_action)
observation, reward, done, info = trainer.step(my_action)
# env.render(mode="ipython", width=100, height=90, header=False, controls=False)
env.render()

[結果]

一回の動作(どの位置にコインを落とすのか)ごとに、その動作がデバッグ表示されます。

エージェントの評価

作成したエージェントを評価します。

[ソース]

1
2
3
4
5
6
def mean_reward(rewards):
return sum(r[0] for r in rewards) / float(len(rewards))

# Run multiple episodes to estimate its performance.
print("My Agent vs Random Agent:", mean_reward(evaluate("connectx", [my_agent, "random"], num_episodes=10)))
print("My Agent vs Negamax Agent:", mean_reward(evaluate("connectx", [my_agent, "negamax"], num_episodes=10)))

[結果]

ランダム選択の相手との結果と、NegaMax法の相手との結果(平均報酬)が表示されます。

ゲーム終了時に勝つと 1 が、負けると 0 が、どちらでもない場合 (引き分け・勝負がついていない) だと 0.5 が報酬として得られるとのことなので、ランダム相手にはたまたま勝ち越し、NegaMax相手には全敗したということになります。

エージェントと対戦

手動でエージェントとの対戦ができます。

[ソース]

1
2
# "None" represents which agent you'll manually play as (first or second player).
env.play([None, "negamax"], width=500, height=450)

[結果]

マス目をクリックしてみたのですが、「Processing…」と表示されたまま動作しませんでした。

提出ファイルの書き出し

提出用のファイルを出力します。

[ソース]

1
2
3
4
5
6
7
8
9
import inspect
import os

def write_agent_to_file(function, file):
with open(file, "a" if os.path.exists(file) else "w") as f:
f.write(inspect.getsource(function))
print(function, "written to", file)

write_agent_to_file(my_agent, "submission.py")

[結果]

問題なく出力されました。

提出ファイルのチェック

提出用のエージェント同士で対戦させて、提出ファイルの妥当性をチェックするようです。

なぜ妥当性をチェックする必要があるかというと、「完全にカプセル化されていてリモート実行できることを確認するため」???とのことでした。

[ソース]

1
2
3
4
5
6
7
8
9
10
# Note: Stdout replacement is a temporary workaround.
import sys
out = sys.stdout
submission = utils.read_file("/kaggle/working/submission.py")
agent = utils.get_last_callable(submission)
sys.stdout = out

env = make("connectx", debug=True)
env.run([agent, agent])
print("Success!" if env.state[0].status == env.state[1].status == "DONE" else "Failed...")

[結果]

エラーになってしまいました。

メソッドがないという意味かと思いますが、仕様の変更があったのでしょうか。

おいおい調査していきたいと思います。

一通りConnect Xの動作方法が分かったので、これから少しずつ調査・実装・改善を行っていきます。

Kaggle - 災害ツイートについての自然言語処理(6) - Baselineモデルでの予測

Natural Language Processing with Disaster Tweetsrに関する6回目の記事です。

Natural Language Processing with Disaster Tweets

今回はBaselineモデルを使って予測を行い、最後にKaggleに結果を提出します。

Baselineモデルの準備

Baselineモデルを作成します。

最適化アルゴリズムとしてはAdamを使います。

Adamは、移動平均で振動を抑制するモーメンタム と 学習率を調整して振動を抑制するRMSProp を組み合わせています。

また、embedding_matrixは単語ごとにベクター値を設定したものです。(前回記事をご参照ください)

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
model = Sequential()

embedding=Embedding(num_words, 100, embeddings_initializer=Constant(embedding_matrix),
input_length=MAX_LEN, trainable=False)

model.add(embedding)
model.add(SpatialDropout1D(0.2))
model.add(LSTM(64, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))

optimzer = Adam(learning_rate=1e-5)

model.compile(loss='binary_crossentropy',optimizer=optimzer,metrics=['accuracy'])

model.summary()

[結果]

データ分割

ツイートデータを配列化したものを(tweet_pad)、訓練用のデータ(正解ラベルとそれ以外)と検証用のデータ(正解ラベルとそれ以外)に分割します。

[ソース]

1
2
3
4
5
6
train=tweet_pad[:tweet.shape[0]]
test=tweet_pad[tweet.shape[0]:]

X_train,X_test,y_train,y_test = train_test_split(train,tweet['target'].values,test_size=0.15)
print('Shape of train',X_train.shape)
print("Shape of Validation ",X_test.shape)

[結果]

学習

分割したデータを使って学習を行います。(少々時間がかかります。)

[ソース]

1
history=model.fit(X_train,y_train,batch_size=4,epochs=15,validation_data=(X_test,y_test),verbose=2)

[結果]

最終的な正解率(検証用)は78.02%となりました。

提出用ファイルの作成

提出用のCSVファイルを作成します。

提出のサンプルファイル(sample_submission.cs)を一旦読み込んで、targetに予測した結果(災害ツイートかどうか)を上書いています。

[ソース]

1
2
3
4
5
6
sample_sub=pd.read_csv('../input/nlp-getting-started/sample_submission.csv')

y_pre = model.predict(test)
y_pre = np.round(y_pre).astype(int).reshape(3263)
sub = pd.DataFrame({'id':sample_sub['id'].values.tolist(), 'target':y_pre})
sub.to_csv('submission.csv',index=False)

[結果]

正解率は78.60%となりました。

それなりの結果かと思いますが、やはりいつもの8割の壁というものを感じてしまいます。

Kaggle - 災害ツイートについての自然言語処理(5) - 単語ベクター化

Natural Language Processing with Disaster Tweetsrに関する5回目の記事です。

Natural Language Processing with Disaster Tweets

今回は、単語ベクター化モデルの一つであるGloVeを試してみます。

単語分割

まずはツイートを単語に分割します。

単語分割する際にword_tokenizeメソッドを使うと、カンマや疑問符といった記号、アポストロフィによる短縮形にもうまく対応することができます。

またisalphaメソッドは文字列中のすべての文字が英字で、かつ 1 文字以上ある場合に真を返します。

[ソース]

1
2
3
4
5
6
7
8
def create_corpus(df):
corpus=[]
for tweet in tqdm(df['text']):
words=[word.lower() for word in word_tokenize(tweet) if((word.isalpha()==1) & (word not in stop))]
corpus.append(words)
return corpus

corpus = create_corpus(df)

[結果]

単語ベクター化

GloVeの学習済みモデルを準備します。3つの次元(50 D ,100 D, 200 D)が用意されていますが、今回は100 Dを使います。

処理後半のpad_sequencesメソッドでは、要素の合わない配列に対して、0 で埋めることで配列のサイズを一致させています。

paddingは前後どちらを埋めるか、truncatingは長いシーケンスの前後どちらを切り詰めるかを指定する引数で、今回は’post’(後ろ)を指定しています。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
embedding_dict={}
with open('../input/glove-global-vectors-for-word-representation/glove.6B.100d.txt','r') as f:
for line in f:
values = line.split()
word = values[0]
vectors = np.asarray(values[1:],'float32')
embedding_dict[word] =vectors
f.close()

MAX_LEN = 50
tokenizer_obj = Tokenizer()
tokenizer_obj.fit_on_texts(corpus)
sequences = tokenizer_obj.texts_to_sequences(corpus)

tweet_pad = pad_sequences(sequences, maxlen=MAX_LEN, truncating='post', padding='post')

word_index = tokenizer_obj.word_index
print('Number of unique words:',len(word_index))

[結果]

ユニークな単語数は20342となりました。

単語ベクター配列作成

単語ごとのベクター値を取得し、結果をembedding_matrixに格納します。

[ソース]

1
2
3
4
5
6
7
8
9
10
num_words = len(word_index) + 1
embedding_matrix = np.zeros((num_words, 100))

for word,i in tqdm(word_index.items()):
if i > num_words:
continue

emb_vec = embedding_dict.get(word)
if emb_vec is not None:
embedding_matrix[i] = emb_vec

[結果]

次回はBaseline Modelを使って、災害ツイートどうかの判定を行い、結果を提出してみます。

Kaggle - 災害ツイートについての自然言語処理(4) - データクレンジング

Natural Language Processing with Disaster Tweetsrに関する4回目の記事です。

Natural Language Processing with Disaster Tweets

今回はデータクレンジングを行います。

学習データと検証データの結合

まずは一括でデータクレンジングするために、学習データと検証データを結合します。

[ソース]

1
2
df = pd.concat([tweet,test])
df.shape

URLの排除

URLの排除を行います。正規表現を使ってURLパターンに合致したものを排除します。

[ソース]

1
2
3
4
5
6
7
8
example = "New competition launched :https://www.kaggle.com/c/nlp-getting-started"

def remove_URL(text):
url = re.compile(r'https?://\S+|www\.\S+')
return url.sub(r'', text)

remove_URL(example)
df['text'] = df['text'].apply(lambda x : remove_URL(x))

[結果]

HTMLタグの排除

HTMLタグも正規表現を使って排除します。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
example = """<div>
<h1>Real or Fake</h1>
<p>Kaggle </p>
<a href="https://www.kaggle.com/c/nlp-getting-started">getting started</a>
</div>"""

def remove_html(text):
html = re.compile(r'<.*?>')
return html.sub(r'',text)
print(remove_html(example))

df['text'] = df['text'].apply(lambda x : remove_html(x))

[結果]

絵文字の排除

絵文字も正規表現を使って排除します。

複数パターンある場合は、リスト型を使ってまとめて指定することができます。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Reference : https://gist.github.com/slowkow/7a7f61f495e3dbb7e3d767f97bd7304b
def remove_emoji(text):
emoji_pattern = re.compile("["
u"\U0001F600-\U0001F64F" # emoticons
u"\U0001F300-\U0001F5FF" # symbols & pictographs
u"\U0001F680-\U0001F6FF" # transport & map symbols
u"\U0001F1E0-\U0001F1FF" # flags (iOS)
u"\U00002702-\U000027B0"
u"\U000024C2-\U0001F251"
"]+", flags=re.UNICODE)
return emoji_pattern.sub(r'', text)

remove_emoji("Omg another Earthquake 😔😔")

df['text'] = df['text'].apply(lambda x: remove_emoji(x))

[結果]

句読点の排除

句読点を排除します。下記の方法で文字の変換を行っています。

  1. str.maketrans()でstr.translate()に使える変換テーブルを作成する。
  2. str.translate()で文字列内の文字を変換する。

[ソース]

1
2
3
4
5
6
7
8
def remove_punct(text):
table = str.maketrans('','', string.punctuation)
return text.translate(table)

example = "I am a #king"
print(remove_punct(example))

df['text'] = df['text'].apply(lambda x : remove_punct(x))

[結果]

スペル修正

最後にスペルの修正を行います。

pyspellcheckerというライブラリを使うので、まずこれをインストールしておきます。

[コマンド]

1
!pip install pyspellchecker

[結果]

 

スペルチャックは一旦単語ごとに分解し、スペルミスがあれば正しいスペルに修正し、最後に修正した単語を含めて結合し直しています。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from spellchecker import SpellChecker

spell = SpellChecker()
def correct_spellings(text):
corrected_text = []
misspelled_words = spell.unknown(text.split())
for word in text.split():
if word in misspelled_words:
corrected_text.append(spell.correction(word))
else:
corrected_text.append(word)
return " ".join(corrected_text)

text = "corect me plese"
correct_spellings(text)

[結果]


今回は、英語文字列のクレンジングを行いました。

次回は、単語ベクター化モデルの一つであるGloVeを試してみます。

Kaggle - 災害ツイートについての自然言語処理(3)

Natural Language Processing with Disaster Tweetsrに関する3回目の記事です。

Natural Language Processing with Disaster Tweets

今回は単語ごとの解析を行います。

単語解析

ツイートを単語ごとに分割する関数を定義します。

引数のtargetには災害関連(=1)か災害に関係ない(=0)かを渡します。

[ソース]

1
2
3
4
5
6
7
def create_corpus(target):
corpus=[]

for x in tweet[tweet['target']==target]['text'].str.split():
for i in x:
corpus.append(i)
return corpus

災害に関係のないツイート(target=0)を単語ごとに分割し、出現頻度が多い順にグラフ化します。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
corpus = create_corpus(0)

dic=defaultdict(int)
for word in corpus:
if word in stop:
dic[word] += 1

top = sorted(dic.items(), key=lambda x:x[1], reverse=True)[:10]

x,y=zip(*top)
plt.bar(x,y)

[結果]

the、a、toという冠詞、前置詞の出現が多いようです。


次に災害に関係のあるツイート(target=1)を単語ごとに分割し、出現頻度が多い順にグラフ化します。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
corpus=create_corpus(1)

dic=defaultdict(int)
for word in corpus:
if word in stop:
dic[word]+=1

top=sorted(dic.items(), key=lambda x:x[1],reverse=True)[:10]

x,y=zip(*top)
plt.bar(x,y)

[結果]

こちらもthe、in、ofという冠詞、前置詞の出現が多いようです。

句読点

今度は句読点について調べていきます。

string.punctuation(6行目)は、英数字以外のアスキー文字(句読点含む)を表します。

災害に関係のあるツイートの句読点出現回数をグラフで表示します。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
plt.figure(figsize=(10,5))
corpus = create_corpus(1)

dic = defaultdict(int)
import string
special = string.punctuation
for i in (corpus):
if i in special:
dic[i] += 1

x,y = zip(*dic.items())
plt.bar(x,y)

[結果]


災害に関係のないツイートの句読点出現回数をグラフで表示します。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
plt.figure(figsize=(10,5))
corpus = create_corpus(0)

dic = defaultdict(int)
import string
special = string.punctuation
for i in (corpus):
if i in special:
dic[i] += 1

x, y = zip(*dic.items())
plt.bar(x,y,color='green')

[結果]

どちらも1番目が-(ハイフン)、2番目が|(パイプ)と同じですが、3番目は:(コロン)と?(クエスチョン)と少し違いがあるようです。

共通する単語(ストップワード以外)

最後にストップワードを含まない共通する単語を調べます。

ストップワードとは、自然言語を処理するにあたって処理対象外とする単語のことです。

「at」「of」などの前置詞や、「a」「an」「the」などの冠詞、「I」「He」「She」などの代名詞がストップワードとされます。

ストップワードに含まれない単語を抽出し、グラフ化します。

[ソース]

1
2
3
4
5
6
7
8
9
10
counter = Counter(corpus)
most = counter.most_common()
x=[]
y=[]
for word,count in most[:40]:
if (word not in stop) :
x.append(word)
y.append(count)

sns.barplot(x=y, y=x)

[結果]

まだハイフンやアンパサンドなどがあるので、さらにデータクレンジングをする必要がありそうです。

N-gram解析

N-gram解析とは、対象となるテキストの中で、連続するN個の表記単位(gram)の出現頻度を求める手法です。

そうすることによって、テキスト中の任意の長さの表現の出現頻度パターンなどを知ることができるようになります。

N=2(bigram)として、N-gram解析を行います。

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
def get_top_tweet_bigrams(corpus, n=None):
vec = CountVectorizer(ngram_range=(2, 2)).fit(corpus)
bag_of_words = vec.transform(corpus)
sum_words = bag_of_words.sum(axis=0)
words_freq = [(word, sum_words[0, idx]) for word, idx in vec.vocabulary_.items()]
words_freq =sorted(words_freq, key = lambda x: x[1], reverse=True)
return words_freq[:n]

plt.figure(figsize=(10,5))
top_tweet_bigrams=get_top_tweet_bigrams(tweet['text'])[:10]
x,y=map(list,zip(*top_tweet_bigrams))
sns.barplot(x=y,y=x)

[結果]

これに関しても、URLに関するものが多かったり、前置詞・冠詞の組み合わせが多かったりと、まだまだデータクレンジングが必要そうです。

というわけで、次回はデータクレンジングを行っていきます。

Kaggle - 災害ツイートについての自然言語処理(2)

Natural Language Processing with Disaster Tweetsrに関する2回目の記事です。

Natural Language Processing with Disaster Tweets

今回はツイート(text)に関するEDA(探索的データ解析)を行います。

ツイートの探索的データ解析

基本的なテキスト分析として次の3点を調べます。

  • 文字数
  • 単語数
  • 平均単語レングス

まずは文字数をカウントしグラフ化します。

災害関連かどうかでグラフを分けています。

[ソース]

1
2
3
4
5
6
7
8
9
fig,(ax1,ax2)=plt.subplots(1,2,figsize=(10,5))
tweet_len = tweet[tweet['target']==1]['text'].str.len()
ax1.hist(tweet_len,color='red')
ax1.set_title('disaster tweets')
tweet_len=tweet[tweet['target']==0]['text'].str.len()
ax2.hist(tweet_len,color='green')
ax2.set_title('Not disaster tweets')
fig.suptitle('Characters in tweets')
plt.show()

[結果]

両グラフともほとんど同じ分布になっています。

120語から140語付近がもっとも度数が多いようです。


次に、単語数をカウントしグラフ化します。

[ソース]

1
2
3
4
5
6
7
8
9
fig,(ax1,ax2)=plt.subplots(1,2,figsize=(10,5))
tweet_len=tweet[tweet['target']==1]['text'].str.split().map(lambda x: len(x))
ax1.hist(tweet_len,color='red')
ax1.set_title('disaster tweets')
tweet_len=tweet[tweet['target']==0]['text'].str.split().map(lambda x: len(x))
ax2.hist(tweet_len,color='green')
ax2.set_title('Not disaster tweets')
fig.suptitle('Words in a tweet')
plt.show()

[結果]

こちらも似た分布にはなっていますが、単語数15のところの災害関連ツイートが妙に少なくなっていることが分かります。


最後に、平均単語レングスをカウントしグラフ化します。

[ソース]

1
2
3
4
5
6
7
8
fig,(ax1,ax2)=plt.subplots(1,2,figsize=(10,5))
word=tweet[tweet['target']==1]['text'].str.split().apply(lambda x : [len(i) for i in x])
sns.distplot(word.map(lambda x: np.mean(x)),ax=ax1,color='red')
ax1.set_title('disaster')
word=tweet[tweet['target']==0]['text'].str.split().apply(lambda x : [len(i) for i in x])
sns.distplot(word.map(lambda x: np.mean(x)),ax=ax2,color='green')
ax2.set_title('Not disaster')
fig.suptitle('Average word length in each tweet')

[結果]

こちらもほぼ同じ分布になっているように見えます。

文字数、単語数、平均レングスでは災害関連かどうかを判断するのは難しいのかもしれません。

次回は語尾の単語や句読点、よく使われる単語などを調べていきます。

Kaggle - 災害ツイートについての自然言語処理(1)

今回からはNatural Language Processing with Disaster Tweetsrコンペに参加していきたいと思います。

Natural Language Processing with Disaster Tweets

簡単に説明すると、ツイートが実際の災害に関するものか否かを判定するコンペとのことです。

ちなみに自然言語処理とは、「人間が日常的に使っている自然言語をコンピュータに処理させる一連の技術であり、人工知能と言語学の一分野」ということです。

探索的データ解析 (Exploratory data analysis)

探索的データ分析を行います。つまりどんなデータかを確認していきたいと思います。

まずは必要なライブラリをインポートします。

(Kaggleノートブックで動作確認しています。)

[ソース]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from nltk.corpus import stopwords
from nltk.util import ngrams
from sklearn.feature_extraction.text import CountVectorizer
from collections import defaultdict
from collections import Counter
plt.style.use('ggplot')
stop=set(stopwords.words('english'))
import re
from nltk.tokenize import word_tokenize
import gensim
import string
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from tqdm import tqdm
from keras.models import Sequential
from keras.layers import Embedding,LSTM,Dense,SpatialDropout1D
from keras.initializers import Constant
from sklearn.model_selection import train_test_split
from keras.optimizers import Adam

訓練データと検証データを読み込みます。

訓練データの一部を表示してデータを確認します。

[ソース]

1
2
3
tweet= pd.read_csv('../input/nlp-getting-started/train.csv')
test=pd.read_csv('../input/nlp-getting-started/test.csv')
tweet.head(10)

[結果]

keywordとlocationが欠損値ばかりです。

targetが予測すべき項目(1が災害関連。0が災害に無関係)なので、ほとんど残りのtext(ツイート)から予測することになりそうです。

検証データの一部も表示してみます、

[ソース]

1
test.head(10)

[結果]

keyword,location,text項目はありますが、target項目がありませんので、このデータを予測する必要があることが確認できました。

分布確認

targetごとのデータ数を分布図で確認してみます。

[ソース]

1
2
3
x=tweet.target.value_counts()
sns.barplot(x.index,x)
plt.gca().set_ylabel('samples')

[結果]

災害関連ツイート(=1)よりも災害に関連しないツイート(=0)の方が多いことが分かります。

次回からはツイート(text)に関するEDA(探索的データ解析)を行っていきます。


Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×