-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconnect_four.py
60 lines (44 loc) · 1.57 KB
/
connect_four.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import game_utils
class ConnectFour(object):
    """Stateful Connect Four game wrapping the helpers in ``game_utils``.

    The board lives in ``self.state`` (a NumPy array, judging by the
    ``.view()``, ``.shape`` and column slicing used below). Every move made
    through :meth:`step` is recorded in ``self.ledger`` as a
    ``(row_id, col_id, player_id)`` tuple.
    """

    def __init__(self):
        # Board layout and cell encoding are defined by game_utils.initialize().
        self.state = game_utils.initialize()
        # Move history: one (row_id, col_id, player_id) tuple per step() call.
        self.ledger = []

    def __str__(self):
        """Return the string form of the underlying board array."""
        return str(self.state)

    def is_win(self):
        """Return ``game_utils.is_win`` evaluated on the current board."""
        return game_utils.is_win(self.state)

    def is_end(self):
        """Return ``game_utils.is_end`` evaluated on the current board."""
        return game_utils.is_end(self.state)

    def step(self, action):
        """Play one move and report the outcome.

        Args:
            action: ``(col_id, player_id)`` pair giving the column to play in
                and the player making the move.

        Returns:
            ``(observation, reward, terminated)``:

            * ``observation`` is whatever ``game_utils.step`` returns.
            * ``reward`` is ``1`` if ``game_utils.is_win`` reports a win after
              the move, ``0.5`` if no valid column remains (draw), and ``0``
              otherwise. ``0`` means the game simply continues; it is *not* a
              loss signal.
            * ``terminated`` is true on a win or a draw.
        """
        col_id, player_id = action
        # in_place=True: game_utils.step updates self.state, which the ledger
        # lookup below reads back.
        observation = game_utils.step(self.state, col_id, player_id, in_place=True)

        has_won = self.is_win()
        # len() instead of truthiness: get_valid_col_id may return an array,
        # whose truth value is ambiguous.
        no_more_actions = len(self.get_valid_col_id()) == 0
        terminated = has_won or no_more_actions

        if has_won:
            reward = 1  # win
        elif no_more_actions:
            reward = 0.5  # draw
        else:
            reward = 0  # game continues (not a loss)

        # NOTE(review): takes the smallest row index in this column holding
        # player_id. That is the piece just played only if row 0 is the top
        # of the board (pieces settle toward higher row indices). Confirm
        # against game_utils.
        row_id = np.where(self.state[:, col_id] == player_id)[0][0]
        self.ledger.append((row_id, col_id, player_id))
        return observation, reward, terminated

    def get_state(self):
        """Return a read-only view of the board.

        The view shares memory with ``self.state``, so it reflects later
        moves, but it cannot be written through.
        """
        read_only_state = self.state.view()
        # Clearing the flag on the view leaves self.state itself writeable.
        read_only_state.flags.writeable = False
        return read_only_state

    def size(self):
        """Return the board's shape tuple."""
        return self.state.shape

    def get_valid_col_id(self):
        """Return ``game_utils.get_valid_col_id`` for the current board."""
        return game_utils.get_valid_col_id(self.state)

    def is_valid_col_id(self, col_id):
        """Return ``game_utils.is_valid_col_id`` for ``col_id`` on this board."""
        return game_utils.is_valid_col_id(self.state, col_id)

    def get_cell(self, row_id, col_id):
        """Return the value stored at ``(row_id, col_id)``."""
        # Tuple indexing avoids materialising an intermediate row view.
        return self.state[row_id, col_id]

    def get_ledger_actions(self):
        """Return the column played on each recorded move, in order."""
        return [col_id for _, col_id, _ in self.ledger]