
Commit 9f72da3

Author: YuHang
Parent: ff4afa8

hotfix: 'train' phase batch_norm causes unstable prediction

File tree: 7 files changed, +57 −46 lines

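In brief, per the commit title and the diffs below: the prediction path (Network.run_many) now casts inputs to float32, pins batch normalization to inference mode (model.training: False) and feeds a temperature of 1. (model.temp); the MCTS code gains a shift_node method that re-roots the search tree on the chosen child so its statistics are reused across moves (in self-play, rival simulation, and GTP play); and the self-play hyper-parameters --num_playouts and --selfplay_games_per_epoch move from hard-coded values into command-line flags.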

Network.py (5 additions, 1 deletion)

@@ -103,8 +103,11 @@ def close(self):
     usage: queue prediction, self-play
     '''
     def run_many(self,imgs):
+        imgs = np.asarray(imgs).astype(np.float32)
         imgs[:][...,16] = (imgs[:][...,16]-0.5)*2
-        feed_dict = {self.imgs:imgs,self.model.training: False}
+        # set high temperature to counter strong move bias?
+        # set model batch_norm to inference mode
+        feed_dict = {self.imgs:imgs,self.model.training: False ,self.model.temp: 1.}
         move_probabilities,value = self.sess.run([self.model.prediction,self.model.value],feed_dict=feed_dict)
 
         # with multi-gpu, probs and values are separated in each output

@@ -195,6 +198,7 @@ def test(self,test_data, proportion=0.1,force_save_model=False):
             test_acc += ac
             test_result_acc += result_acc
             n_batch += 1
+            logger.debug(f'Test accuracy: {test_acc}')
 
         tot_test_loss = test_loss / (n_batch-1e-2)
         tot_test_acc = test_acc / (n_batch-1e-2)
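Background on the commit title: with batch normalization, the training phase normalizes each activation with the current mini-batch's mean and variance, so a prediction can change depending on what else happens to be in the batch; inference mode uses the accumulated moving averages instead, giving deterministic outputs. A minimal TF 1.x sketch of the distinction (illustrative only; these names are not from this repo):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 19, 19, 17])  # 17 feature planes, as in run_many
is_training = tf.placeholder(tf.bool)               # analogous to model.training

conv = tf.layers.conv2d(x, 256, 3, padding='same')
# training=True  -> normalize with per-batch statistics (output depends on
#                   the rest of the batch, hence unstable predictions)
# training=False -> normalize with stored moving averages (deterministic)
h = tf.nn.relu(tf.layers.batch_normalization(conv, training=is_training))

# during training these update ops must run so the moving averages stay current
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)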

main.py (5 additions, 4 deletions)

@@ -40,12 +40,13 @@
 parser.add_argument('--n_resid_units', type=int, default=6)
 parser.add_argument('--n_gpu', type=int, default=1)
 parser.add_argument('--dataset', dest='processed_dir',default='./processed_data')
-parser.add_argument('--model_path',dest='load_model_path',default='./savedmodels')
+parser.add_argument('--model_path',dest='load_model_path',default='./savedmodels')#'./savedmodels'
 parser.add_argument('--model_type',dest='model',default='full',\
                     help='choose residual block architecture {original,elu,full}')
 parser.add_argument('--optimizer',dest='opt',default='adam')
 parser.add_argument('--gtp_policy',dest='gpt_policy',default='mctspolicy',help='choose gtp bot player')#random,mctspolicy
-parser.add_argument('--num_playouts',type=int,dest='num_playouts',default=200,help='The number of MC search per move, the more the better.')
+parser.add_argument('--num_playouts',type=int,dest='num_playouts',default=1600,help='The number of MC search per move, the more the better.')
+parser.add_argument('--selfplay_games_per_epoch',type=int,dest='selfplay_games_per_epoch',default=25000)
 parser.add_argument('--mode',dest='MODE', default='train',help='among selfplay, gtp and train')
 FLAGS = parser.parse_args()
 

@@ -133,7 +134,7 @@ def selfplay(flags=FLAGS,hps=HPS):
     """set the batch size to -1==None"""
     flags.n_batch = -1
     net = Network(flags,hps)
-    Worker = SelfPlayWorker(net)
+    Worker = SelfPlayWorker(net,flags)
 
     def train(epoch:int):
         lr = schedule_lrn_rate(epoch)

@@ -160,7 +161,7 @@ def evaluate_testset():
         #evaluate_testset()
 
         """Evaluate against best model"""
-        evaluate_generations()
+        #evaluate_generations()
 
         logger.info(f'Global epoch {g_epoch} finish.')
         logger.info('Now, I am the Master! 现在,请叫我棋霸!')

(The Chinese in the last log line reads roughly: "Now, please call me the Go champion!")
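With these flags in place, a self-play run can be driven entirely from the command line, e.g. (illustrative invocation):

python main.py --mode selfplay --num_playouts 1600 --selfplay_games_per_epoch 25000 --model_path ./savedmodels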

model/APV_MCTS.py (19 additions, 16 deletions)

@@ -65,6 +65,10 @@ async def prediction_worker(self):
         bulk_features = np.asarray([item.feature for item in item_list])
         policy_ary, value_ary = self.run_many(bulk_features)
         for p, v, item in zip(policy_ary, value_ary, item_list):
+            '''
+            greedy_move = divmod(np.argmax(p),go.N)
+            logger.debug(f'Greedy move: {greedy_move}')
+            '''
             item.future.set_result((p, v))
 
     async def push_queue(self, features):

@@ -121,7 +125,7 @@ def virtual_loss_undo(self):
         self.W += 3
 
     def is_expanded(self):
-        return self.position is not None
+        return self.children.get(None) is not None
 
     #@profile
     def compute_position(self):

@@ -161,29 +165,26 @@ def move_prob(self):
         prob /= np.sum(prob) # ensure 1.
         return prob
 
+    def shift_node(self,move,pos=None):
+        child = self.children[move]
+        self.parent,self.move,self.prior,self.position,\
+        self.children,self.U,self.N,self.W = self, child.move,\
+        child.prior,child.position if pos is None else pos,\
+        child.children,child.U,\
+        child.N,child.W
+
     def suggest_move(self, position):
 
         move_prob = self.suggest_move_prob(position)
 
-        '''
-        logger.debug(bulk_extract_features([position]).shape)
-        move_prob,value = self.api.run_many(bulk_extract_features([position]))
-        move_prob = move_prob[0]
-        value = value[0]
-        greedy_move = divmod(np.argmax(move_prob),go.N)
-        prob = move_prob[np.argmax(move_prob)]
-        logger.debug(f'Greedy move is: {greedy_move} with prob {prob} at game step {position.n}')
-        win_rate = (value)/2+0.5
-        '''
-
         on_board_move_prob = np.reshape(move_prob[:-1],(go.N,go.N))
         if position.n < 30:
             move = select_weighted_random(position, on_board_move_prob)
         else:
             move = select_most_likely(position, on_board_move_prob)
 
         player = 'B' if position.to_play==1 else 'W'
-        win_rate = self.children[move].Q
+        win_rate = self.children[move].Q/2+0.5
         logger.info(f'Win rate for player {player} is {win_rate:2f}')
 
         return move

@@ -197,12 +198,12 @@ def suggest_move_prob(self, position):
         logger.debug(f'Expanding Root Node...')
 
         move_probs,_ = self.api.run_many(bulk_extract_features([position]))
-
+        '''
         move_prob = move_probs[0]
         greedy_move = divmod(np.argmax(move_prob),go.N)
         prob = move_prob[np.argmax(move_prob)]
         logger.debug(f'Greedy move is: {greedy_move} with prob {prob} at game step {position.n}')
-
+        '''
         self.position = position
         self.expand(move_probs[0])
 

@@ -231,7 +232,9 @@ async def start_tree_search(self):
         now_expanding.add(self)
 
         # compute leaf node position on the fly
-        pos = self.compute_position()
+        pos = self.position
+        if pos is None:
+            pos = self.compute_position()
 
         if pos is None:
             #print("illegal move!", file=sys.stderr)
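Two changes here deserve a note. First, shift_node re-roots the tree on the chosen child, carrying over its children, visit count N, total value W, and prior, so search statistics survive from one move to the next; with position now potentially set on un-expanded nodes, is_expanded switches to testing whether expand() has populated children (the None key being the pass move). Second, the value scale: the network's value head evidently lives in [-1, 1] (see the removed win_rate = (value)/2+0.5 line), so logging Q/2+0.5 reports a win rate in [0, 1]. A self-contained sketch of that mapping:

# Sketch of the value-to-win-rate mapping used above
# (assumes Q, like the value head's output, lies in [-1, 1]).
def q_to_win_rate(q: float) -> float:
    return q / 2 + 0.5

assert q_to_win_rate(-1.0) == 0.0  # certain loss
assert q_to_win_rate(0.0) == 0.5   # even game
assert q_to_win_rate(1.0) == 1.0   # certain win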

model/APV_MCTS_C.pyx (21 additions, 17 deletions)

@@ -65,6 +65,10 @@ class NetworkAPI(object):
         bulk_features = np.asarray([item.feature for item in item_list])
         policy_ary, value_ary = self.run_many(bulk_features)
         for p, v, item in zip(policy_ary, value_ary, item_list):
+            '''
+            greedy_move = divmod(np.argmax(p),go.N)
+            logger.debug(f'Greedy move: {greedy_move}')
+            '''
             item.future.set_result((p, v))
 
     async def push_queue(self, features):

@@ -121,7 +125,7 @@ class MCTSPlayerMixin(object):
         self.W += 3
 
     def is_expanded(self):
-        return self.position is not None
+        return self.children.get(None) is not None
 
     #@profile
     def compute_position(self):

@@ -161,30 +165,28 @@ class MCTSPlayerMixin(object):
         prob /= np.sum(prob) # ensure 1.
         return prob
 
+    def shift_node(self,move,pos=None):
+        child = self.children[move]
+        self.parent,self.move,self.prior,self.position,\
+        self.children,self.U,self.N,self.W = self, child.move,\
+        child.prior,child.position if pos is None else pos,\
+        child.children,child.U,\
+        child.N,child.W
+
     def suggest_move(self, position):
 
         move_prob = self.suggest_move_prob(position)
 
-        '''
-        logger.debug(bulk_extract_features([position]).shape)
-        move_prob,value = self.api.run_many(bulk_extract_features([position]))
-        move_prob = move_prob[0]
-        value = value[0]
-        greedy_move = divmod(np.argmax(move_prob),go.N)
-        prob = move_prob[np.argmax(move_prob)]
-        logger.debug(f'Greedy move is: {greedy_move} with prob {prob} at game step {position.n}')
-        win_rate = (value)/2+0.5
-        '''
-
         on_board_move_prob = np.reshape(move_prob[:-1],(go.N,go.N))
         if position.n < 30:
             move = select_weighted_random(position, on_board_move_prob)
         else:
             move = select_most_likely(position, on_board_move_prob)
 
         player = 'B' if position.to_play==1 else 'W'
-        win_rate = self.children[move].Q
-        logger.info(f'Win rate for player {player} is {win_rate:2f}')
+        win_rate = self.children[move].Q/2+0.5
+        win_rate = 'New Visit' if win_rate == 0 else win_rate
+        logger.info(f'Win rate for player {player} is {win_rate}')
 
         return move
 

@@ -197,12 +199,12 @@ class MCTSPlayerMixin(object):
         logger.debug(f'Expanding Root Node...')
 
         move_probs,_ = self.api.run_many(bulk_extract_features([position]))
-
+        '''
         move_prob = move_probs[0]
         greedy_move = divmod(np.argmax(move_prob),go.N)
         prob = move_prob[np.argmax(move_prob)]
         logger.debug(f'Greedy move is: {greedy_move} with prob {prob} at game step {position.n}')
-
+        '''
         self.position = position
         self.expand(move_probs[0])
 

@@ -231,7 +233,9 @@ class MCTSPlayerMixin(object):
         now_expanding.add(self)
 
         # compute leaf node position on the fly
-        pos = self.compute_position()
+        pos = self.position
+        if pos is None:
+            pos = self.compute_position()
 
         if pos is None:
             #print("illegal move!", file=sys.stderr)
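The start_tree_search change in both MCTS variants follows from the same tree reuse: a node carried over by shift_node may already hold a position, so the leaf position is only recomputed when it is genuinely missing. The pattern in isolation (node is a hypothetical stand-in for self):

# Lazy-position lookup: prefer a position already on the node
# (e.g. carried over by shift_node), else materialize it on the fly.
pos = node.position
if pos is None:
    pos = node.compute_position()  # may still return None for an illegal move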

model/SelfPlayWorker.py (4 additions, 4 deletions)

@@ -20,15 +20,15 @@ def timer(message):
 
 class SelfPlayWorker(object):
 
-    def __init__(self,net):
+    def __init__(self,net,flags):
         self.net = net
         self.N_games_per_train = 1#10
-        self.N_games = 1#25000
-        self.playouts = 200#1600
+        self.N_games = flags.selfplay_games_per_epoch#25000
+        self.playouts = flags.num_playouts#1600
         self.position = go.Position(to_play=go.BLACK)
         self.final_position_collections = []
         self.dicard_game_threshold = 30 # number of moves that is considered to resign too early
-        self.resign_threshold = -0.75
+        self.resign_threshold = -0.8
         self.resign_delta = 0.05
         self.total_resigned_games = 0
         self.total_false_resigned_games = 0
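These fields suggest the usual self-play resign safeguard: resign once the searched evaluation falls below resign_threshold, now -0.8 on the [-1, 1] value scale (about a 10% win rate), while total_false_resigned_games presumably counts games that would have been resigned but were eventually won, so the threshold can be re-tuned in resign_delta steps. A sketch under those assumptions (the function is hypothetical, not from the repo):

# Hypothetical resign check implied by the fields above.
def should_resign(value: float, resign_threshold: float = -0.8) -> bool:
    # value is the searched evaluation in [-1, 1];
    # -0.8 maps to a win rate of (-0.8)/2 + 0.5 = 0.10.
    return value < resign_threshold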

utils/gtp_wrapper.py (1 addition, 0 deletions)

@@ -51,6 +51,7 @@ def make_move(self, color, vertex):
         self.accomodate_out_of_turn(color)
         try:
             self.position.play_move(coords,mutate=True, color=translate_gtp_colors(color))
+            self.shift_node(move=coords,pos=self.position)
         except:
             return False
         return True
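Wiring shift_node into make_move keeps the engine's tree in sync with moves arriving over GTP: when the controller plays a stone, the tree is re-rooted on that move with the already-mutated position passed as pos, so the next search starts from the retained subtree rather than a cold root. An illustrative exchange (standard GTP commands; engine behavior as wired above):

# play black D4   ->  make_move applies the stone and re-roots the tree on D4
# genmove white   ->  the search resumes from the reused subtree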

utils/strategies.py (2 additions, 4 deletions)

@@ -148,8 +148,7 @@ def simulate_rival_games_mcts(policy1, policy2, positions):
         move = mc_root.suggest_move(pos)
         pos.play_move(move, mutate=True, move_prob=policy.move_prob())
         # shift to child node
-        mc_root = mc_root.children[move]
-        mc_root.parent = None
+        mc_root.shift_node(move,pos)
 
     # TODO: implement proper end game
     for pos in positions:

@@ -183,9 +182,8 @@ def game_end_condition():
         move = mc_root.suggest_move(position)
         position.play_move(move, mutate=True, move_prob=mc_root.move_prob())
         # shift to child node
-        mc_root=mc_root.children[move]
+        mc_root.shift_node(move,position)
         logger.debug(f'Move at step {position.n} is {move}')
-        mc_root.parent = None
 
         # check resign
         if resign_condition():
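For clarity, the two tree-reuse styles side by side. They are not mechanically identical: the old code rebound the name mc_root to the chosen child and detached its parent, while shift_node mutates the existing root object in place, copying the child's position, prior, children, and statistics into it:

# before this commit: rebind the name to the chosen child, detach its parent
mc_root = mc_root.children[move]
mc_root.parent = None

# after this commit: mutate the existing root in place with the child's state
mc_root.shift_node(move, pos)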
