-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNode.cpp
140 lines (127 loc) · 3.71 KB
/
Node.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
//
// Created by zhc7 on 24-5-16.
//
#include <cmath>
#include <algorithm>
#include "Node.h"
#include "Judge.h"
#include "mem.h"
// Storage for the static member Node::root; it starts out null.
Node *Node::root{nullptr};
void Node::handleMustWin(char winner) {
state.mustWin = winner;
}
/// Upper-confidence score of this node as seen from its parent.
///
/// @param sqrtLogParentVisit exploration factor precomputed by the caller
///        (Node::select passes UCB_C * sqrtf(2 * log(visits)) of the parent).
/// @return winRate plus the exploration factor scaled by revSqrtVisit.
float Node::ucbValue(const float sqrtLogParentVisit) const {
    return revSqrtVisit * sqrtLogParentVisit + winRate;
}
/// Chooses the child to descend into, or resolves this node's forced outcome.
///
/// Children whose state.mustWin is non-zero are never descended into. If some
/// child is already decided in favour of the player to move here
/// (state.nextTurn), this node is marked as decided for that player and
/// `this` is returned. If no undecided child was picked but at least one
/// child exists, this node is marked with the largest mustWin code seen among
/// its children and `this` is returned.
///
/// @return the undecided child with the highest ucbValue; `this` when the
///         node's outcome was resolved; nullptr when there is no non-null child.
Node *Node::select() {
    // Exploration factor shared by all children of this node.
    const float explorationScale = UCB_C * sqrtf(2 * log(visits));

    Node *best = nullptr;
    float bestScore = -1;
    char strongestOutcome = -1;

    for (Node *candidate: children) {
        if (candidate == nullptr) {
            continue;
        }
        // Track the maximum mustWin code; per the original author's note:
        // have win -> must win; all lose -> must lose; lose and tie -> must tie.
        if (strongestOutcome < candidate->state.mustWin) {
            strongestOutcome = candidate->state.mustWin;
        }
        if (candidate->state.mustWin == state.nextTurn) {
            // The player to move can reach a decided win through this child.
            handleMustWin(state.nextTurn);
            return this;
        }
        if (candidate->state.mustWin != 0) {
            // Already decided (and not in our favour): nothing to explore.
            continue;
        }
        const float score = candidate->ucbValue(explorationScale);
        if (score > bestScore) {
            best = candidate;
            bestScore = score;
        }
    }

    if (best != nullptr || strongestOutcome == -1) {
        return best;
    }
    // Every existing child is decided: inherit the strongest outcome.
    handleMustWin(strongestOutcome);
    return this;
}
/// Expands one randomly chosen action that has no child yet.
///
/// An action index i is a candidate when bit i of state.avail is set and
/// children[i] is still null. When exactly one candidate remains, isLeaf is
/// cleared before expanding it.
///
/// NOTE(review): if there is no candidate at all, `random() % 0` is
/// undefined behaviour — callers presumably guarantee at least one; confirm.
///
/// @return whatever expandAction returns for the chosen action.
int Node::expand() {
    const short availMask = state.avail;
    int candidates[12];
    int count = 0;
    for (int action = 0; action < State::N; ++action) {
        const bool playable = (availMask & (1 << action)) != 0;
        if (playable && children[action] == nullptr) {
            candidates[count] = action;
            ++count;
        }
    }
    if (count == 1) {
        // The last unexpanded action is about to be consumed.
        isLeaf = false;
    }
    const int chosen = candidates[random() % count];
    return expandAction(chosen);
}
// Creates the child for action `target`, stores it in children[target],
// evaluates it once and feeds the result to child->update().
//
// Evaluation is one of:
//   * an immediate win() for the player moving here (state.nextTurn),
//   * a playout via child->state.simulate(), or
//   * a recursive expandAction() on the child when exactly one reply was
//     flagged by HeavyBoard::mustWin for the other side.
//
// @param target index of the action to expand (used to index state.top and
//        children; presumably a board column — confirm against State).
// @return the winner code produced by the evaluation; the same value is
//         passed to child->update().
int Node::expandAction(const int target) {
// getNode() supplies the Node used as the child (declared outside this file).
const auto child = getNode();
// Read from state.top before state.step() runs; passed to win() as its first
// coordinate (presumably the row the new piece lands on — TODO confirm).
const int x = state.top[target] - 1;
// Apply the move, writing the resulting position into child->state.
state.step(target, child->state);
children[target] = child;
// check if must win
HeavyBoard b(child->state.board);
if (win(x, target, b, state.nextTurn)) {
// The mover wins on the spot: mark both child and this node as decided.
child->state.mustWin = state.nextTurn;
state.mustWin = state.nextTurn;
child->update(state.nextTurn);
return state.nextTurn;
}
// `turn` is the player to move in the child position; 3 - turn is the other.
const auto turn = child->state.nextTurn;
// Index of the single reply flagged for the other side, or -1.
char only_child = -1;
for (int i = 0; i < State::N; i++) {
// Shadows the outer x: this one is derived from the child's top[i].
const int x = child->state.top[i] - 1;
if (x >= 0) {
if (b.mustWin(x, i, turn)) {
// Flagged for the child's mover: child is decided for `turn`; stop scanning.
child->state.mustWin = turn;
break;
}
if (b.mustWin(x, i, 3 - turn)) {
// First flag for the other side is remembered; a further flag while one
// is remembered clears it and marks the child as decided for 3 - turn.
// NOTE(review): only_child toggles, so a third flag would set it again
// even though mustWin is already assigned — confirm this is intended.
if (only_child == -1) {
only_child = i;
} else {
only_child = -1; // use normal simulation branch later
child->state.mustWin = 3 - turn;
}
}
}
}
int sim_winner;
if (only_child == -1 || child->state.mustWin == turn) {
// No single remembered reply, or the child's mover is already decided
// to win: evaluate with a playout.
sim_winner = child->state.simulate(b, target);
} else {
// Exactly one reply is remembered: expand it directly under the child.
sim_winner = child->expandAction(only_child);
child->isLeaf = false;
}
child->update(sim_winner);
// If the child ended up decided for this node's mover, so is this node.
if (child->state.mustWin == state.nextTurn) {
state.mustWin = state.nextTurn;
}
return sim_winner;
}
/// Folds one evaluation result into this node's statistics.
///
/// A `winner` of 2 adds a full point to playerWins and a `winner` of 3 adds
/// half a point (presumably 3 encodes a tie — confirm); any other value adds
/// nothing. visits, winRate and revSqrtVisit are then refreshed.
///
/// @param winner winner code of the evaluation being recorded.
void Node::update(const int winner) {
    ++visits;

    const bool fullPoint = winner == 2;
    const bool halfPoint = winner == 3;
    playerWins += fullPoint;
    playerWins += halfPoint * 0.5;

    // Cached so ucbValue() does not recompute the square root.
    revSqrtVisit = 1 / sqrtf(visits);

    winRate = playerWins / visits;
    if (state.nextTurn == 2) {
        // The parent is choosing for the other side (original note: parent's
        // next turn is opponent's turn), so store the complementary rate.
        winRate = 1 - winRate;
    }
}
/// Keeps only the child at index `y` and releases all the others.
///
/// Every other slot is handed to freeNode() (including slots that are
/// currently null) and then reset to nullptr.
///
/// NOTE(review): `y` is not range-checked here — presumably callers pass a
/// valid index in [0, State::N); confirm.
///
/// @param y index of the child to keep.
/// @return children[y], which may be nullptr if that child was never created.
Node *Node::pick(int y) {
    for (int slot = 0; slot < State::N; ++slot) {
        if (slot == y) {
            continue;
        }
        freeNode(children[slot]);
        children[slot] = nullptr;
    }
    return children[y];
}