From 4b1d168ee71b2ee8b07b3fa0990bf75003ad611a Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 23 Jan 2021 21:17:04 +0900 Subject: [PATCH 01/22] experiment: MuZero implementation --- handyrl/generation.py | 2 + handyrl/model.py | 97 +++++++++++++++++++++++++++++++++++++++++++ handyrl/train.py | 44 +++++++++++++------- 3 files changed, 128 insertions(+), 15 deletions(-) diff --git a/handyrl/generation.py b/handyrl/generation.py index 0564e771..7c7ccabf 100755 --- a/handyrl/generation.py +++ b/handyrl/generation.py @@ -48,6 +48,7 @@ def generate(self, models, args): action_mask = np.ones_like(p) * 1e32 action_mask[legal_actions] = 0 p_turn = p - action_mask + sp_turn = model.inference(obs)['policy'] moment['observation'][player] = obs moment['value'][player] = v @@ -59,6 +60,7 @@ def softmax(x): moment['policy'] = p_turn moment['action_mask'] = action_mask + moment['supervised_policy'] = sp_turn moment['turn'] = self.env.turn() moment['action'] = action moments.append(moment) diff --git a/handyrl/model.py b/handyrl/model.py index 8d88d330..b62ad517 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -11,6 +11,7 @@ import torch.nn.functional as F from .util import map_r +from .search import MonteCarloTree def to_torch(x, transpose=False, unsqueeze=None): @@ -271,3 +272,99 @@ def forward(self, x, hidden=None): h_v = self.head_v(h) return {'policy': h_p, 'value': torch.tanh(h_v)} + + +class MuZero(BaseModel): + class Representation(nn.Module): + ''' Conversion from observation to inner abstract state ''' + def __init__(self, input_dim, layers, filters): + super().__init__() + self.layer0 = Conv(input_dim, filters, 3, bn=True) + self.blocks = nn.ModuleList([WideResidualBlock(filters, 3, bn=True) for _ in range(layers)]) + + def forward(self, x): + h = F.relu(self.layer0(x)) + for block in self.blocks: + h = block(h) + return h + + def inference(self, x): + self.eval() + with torch.no_grad(): + rp = self(torch.from_numpy(x).unsqueeze(0)) + return rp.cpu().numpy().squeeze(0) + + class Prediction(nn.Module): + ''' Policy and value prediction from inner abstract state ''' + def __init__(self, internal_size, action_length): + super().__init__() + self.head_p = Head(internal_size, 4, action_length) + self.head_v = Head(internal_size, 4, 1) + + def forward(self, rp): + p = self.head_p(rp) + v = self.head_v(rp) + return p, torch.tanh(v) + + def inference(self, rp): + self.eval() + with torch.no_grad(): + p, v = self(torch.from_numpy(rp).unsqueeze(0)) + return p.cpu().numpy().squeeze(0), v.cpu().numpy().squeeze(0) + + class Dynamics(nn.Module): + '''Abstract state transition''' + def __init__(self, rp_shape, action_length, action_filters, layers): + super().__init__() + self.action_shape = action_filters, rp_shape[1], rp_shape[2] + filters = rp_shape[0] + self.action_embedding = nn.Embedding(action_length, embedding_dim=np.prod(self.action_shape)) + self.layer0 = Conv(rp_shape[0] + self.action_shape[0], filters, 3, bn=True) + self.blocks = nn.ModuleList([WideResidualBlock(filters, 3, bn=True) for _ in range(layers)]) + + def forward(self, rp, a): + arp = self.action_embedding(a).view(-1, *self.action_shape) + h = torch.cat([rp, arp], dim=1) + h = self.layer0(h) + for block in self.blocks: + h = block(h) + return h + + def inference(self, rp, a): + self.eval() + with torch.no_grad(): + rp = self(torch.from_numpy(rp).unsqueeze(0), torch.from_numpy(a).unsqueeze(0)) + return rp.cpu().numpy().squeeze(0) + + def __init__(self, env, args={}): + super().__init__(env, args) + self.input_size = env.observation().shape + layers, filters = args.get('layers', 3), args.get('filters', 32) + internal_size = (filters, *self.input_size[1:]) + + self.nets = nn.ModuleDict({ + 'representation': self.Representation(self.input_size[0], layers, filters), + 'prediction': self.Prediction(internal_size, self.action_length), + 'dynamics': self.Dynamics(internal_size, self.action_length, 2, layers), + }) + + def init_hidden(self, batch_size=None): + return {} + + def forward(self, x, hidden, action=None): + if 'representation' not in hidden: + rp = self.nets['representation'](x) + else: + rp = hidden['representation'] + p, v = self.nets['prediction'](rp) + outputs = {'policy': p, 'value': v} + + if action is not None: + next_rp = self.nets['dynamics'](rp, action) + outputs['hidden'] = {'representation': next_rp} + return outputs + + def inference(self, x, hidden=None, num_simulations=30): + tree = MonteCarloTree(self.nets, {}) + p, v = tree.think(x, num_simulations) + return {'policy': p, 'value': v} \ No newline at end of file diff --git a/handyrl/train.py b/handyrl/train.py index c86da9ef..8e6c7336 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -22,7 +22,8 @@ from .environment import prepare_env, make_env from .util import map_r, bimap_r, trimap_r, rotate, type_r from .model import to_torch, to_gpu_or_not, RandomModel -from .model import SimpleConv2DModel as DefaultModel +#from .model import SimpleConv2DModel as DefaultModel +from .model import MuZero as DefaultModel from .losses import compute_target from .connection import MultiProcessWorkers from .connection import accept_socket_connections @@ -64,6 +65,7 @@ def make_batch(episodes, args): # datum that is not changed by training configuration p = np.array([m['policy'] for m in moments]) + sp = np.array([m['supervised_policy'] for m in moments]) v = np.array( [[m['value'][player] or 0 for player in players] for m in moments], dtype=np.float32 @@ -84,7 +86,6 @@ def make_batch(episodes, args): amask = np.array([m['action_mask'] for m in moments]) # action mask act = np.array([m['action'] for m in moments]).reshape(-1, 1) - progress = np.arange(ep['start'], ep['end'], dtype=np.float32) / ep['total'] # pad each array if step length is short @@ -93,6 +94,7 @@ def make_batch(episodes, args): obs = map_r(obs, lambda o: np.pad(o, [(0, pad_len)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0)) p = np.pad(p, [(0, pad_len), (0, 0)], 'constant', constant_values=0) v = np.concatenate([v, np.tile(oc, [pad_len, 1])]) + sp = np.pad(sp, [(0, pad_len), (0, 0)], 'constant', constant_values=0) act = np.pad(act, [(0, pad_len), (0, 0)], 'constant', constant_values=0) rew = np.pad(rew, [(0, pad_len), (0, 0)], 'constant', constant_values=0) ret = np.pad(ret, [(0, pad_len), (0, 0)], 'constant', constant_values=0) @@ -103,13 +105,14 @@ def make_batch(episodes, args): progress = np.pad(progress, [(0, pad_len)], 'constant', constant_values=1) obss.append(obs) - datum.append((p, v, act, oc, rew, ret, tmask, omask, amask, progress)) + datum.append((p, v, sp, act, oc, rew, ret, tmask, omask, amask, progress)) - p, v, act, oc, rew, ret, tmask, omask, amask, progress = zip(*datum) + p, v, sp, act, oc, rew, ret, tmask, omask, amask, progress = zip(*datum) obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o))) p = to_torch(np.array(p)) v = to_torch(np.array(v)) + sp = to_torch(np.array(sp)) act = to_torch(np.array(act)) oc = to_torch(np.array(oc)) rew = to_torch(np.array(rew)) @@ -123,6 +126,7 @@ def make_batch(episodes, args): return { 'observation': obs, 'policy': p, 'value': v, + 'supervised_policy': sp, 'action': act, 'outcome': oc, 'reward': rew, 'return': ret, 'episode_mask': emask, @@ -155,21 +159,25 @@ def forward_prediction(model, hidden, batch, obs_mode): outputs = {} for t in range(batch['turn_mask'].size(1)): obs = map_r(observations, lambda o: o[:, t].reshape(-1, *o.size()[3:])) # (..., B * P, ...) - omask_ = batch['observation_mask'][:, t] - omask = map_r(hidden, lambda h: omask_.view(*h.size()[:2], *([1] * (len(h.size()) - 2)))) - hidden_ = bimap_r(hidden, omask, lambda h, m: h * m) # (..., B, P, ...) - if obs_mode: - hidden_ = map_r(hidden_, lambda h: h.view(-1, *h.size()[2:])) # (..., B * P, ...) - else: - hidden_ = map_r(hidden_, lambda h: h.sum(1)) # (..., B * 1, ...) - outputs_ = model(obs, hidden_) + action = batch['action'][:, t] + #omask_ = batch['observation_mask'][:, t] + #omask_ = omask_.sum(-1, keepdim=True) # common hidden state + #omask = map_r(hidden, lambda h: omask_.view(*h.size()[:2], *([1] * (len(h.size()) - 2)))) + #hidden_ = bimap_r(hidden, omask, lambda h, m: h * m) # (..., B, P, ...) + #if obs_mode: + # hidden_ = map_r(hidden_, lambda h: h.view(-1, *h.size()[2:])) # (..., B * P, ...) + #else: + # hidden_ = map_r(hidden_, lambda h: h.sum(1)) # (..., B * 1, ...) + hidden_ = hidden + outputs_ = model(obs, hidden_, action) for k, o in outputs_.items(): if k == 'hidden': next_hidden = outputs_['hidden'] else: outputs[k] = outputs.get(k, []) + [o] - next_hidden = bimap_r(next_hidden, hidden, lambda nh, h: nh.view(h.size(0), -1, *h.size()[2:])) # (..., B, P or 1, ...) - hidden = trimap_r(hidden, next_hidden, omask, lambda h, nh, m: h * (1 - m) + nh * m) + #next_hidden = bimap_r(next_hidden, hidden, lambda nh, h: nh.view(h.size(0), -1, *h.size()[2:])) # (..., B, P or 1, ...) + #hidden = trimap_r(hidden, next_hidden, omask, lambda h, nh, m: h * (1 - m) + nh * m) + hidden = next_hidden outputs = {k: torch.stack(o, dim=1) for k, o in outputs.items() if o[0] is not None} for k, o in outputs.items(): @@ -199,15 +207,21 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba turn_advantages = total_advantages.mul(tmasks).sum(-1, keepdim=True) losses['p'] = (-log_selected_policies * turn_advantages).sum() + spolicies = batch['supervised_policy'] + losses['sp'] = (spolicies * (torch.clamp(spolicies, 1e-10, 1).log() - F.log_softmax(outputs['policy'], dim=-1))).sum(-1, keepdim=True).mul(tmasks).sum() if 'value' in outputs: losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2 + losses['sv'] = ((outputs['value'] - batch['outcome']) ** 2).mul(omasks).sum() / 2 if 'return' in outputs: losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'], reduction='none').mul(omasks).sum() entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1)) losses['ent'] = entropy.sum() - base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0) + teacher_weight = 1 + reinforce_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0) + supervised_loss = losses['sp'] + losses.get('sv', 0) + base_loss = (1 - teacher_weight) * reinforce_loss + teacher_weight * supervised_loss entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization'] losses['total'] = base_loss + entropy_loss From 0638bfb627d799cfbd47eaa4b4923b0a1030adee Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 23 Jan 2021 21:20:32 +0900 Subject: [PATCH 02/22] fix: add handyrl/search.py --- handyrl/search.py | 105 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100755 handyrl/search.py diff --git a/handyrl/search.py b/handyrl/search.py new file mode 100755 index 00000000..a33c40b1 --- /dev/null +++ b/handyrl/search.py @@ -0,0 +1,105 @@ +# Copyright (c) 2020 DeNA Co., Ltd. +# Licensed under The MIT License [see LICENSE for details] + +# tree search + +import copy +import time + +import numpy as np + + +def softmax(x): + x = np.exp(x - np.max(x, axis=-1)) + return x / x.sum(axis=-1) + + +class Node: + '''Search result of one abstract (or root) state''' + def __init__(self, p, v): + self.p, self.v = p, v + self.n, self.q_sum = np.zeros_like(p), np.zeros_like(p) + self.n_all, self.q_sum_all = 1, v / 2 # prior + + def update(self, action, q_new): + # Update + self.n[action] += 1 + self.q_sum[action] += q_new + + # Update overall stats + self.n_all += 1 + self.q_sum_all += q_new + + +class MonteCarloTree: + '''Monte Carlo Tree Search''' + def __init__(self, model, args): + self.model = model + self.nodes = {} + self.args = {} + + def search(self, rp, path): + # Return predicted value from new state + key = '|' + ' '.join(map(str, path)) + if key not in self.nodes: + p, v = self.model['prediction'].inference(rp) + p, v = softmax(p), v[0] + self.nodes[key] = Node(p, v) + return v + + # State transition by an action selected from bandit + node = self.nodes[key] + p = node.p + if len(path) == 0: + # Add noise to policy on the root node + p = 0.75 * p + 0.25 * np.random.dirichlet([0.15] * len(p)) + # On the root node, we choose action only from legal actions + p /= p.sum() + 1e-16 + + n, q_sum = 1 + node.n, node.q_sum_all / node.n_all + node.q_sum + ucb = q_sum / n + 2.0 * np.sqrt(node.n_all) * p / n # PUCB formula + best_action = np.argmax(ucb) + + # Search next state by recursively calling this function + next_rp = self.model['dynamics'].inference(rp, np.array([best_action])) + path.append(best_action) + q_new = -self.search(next_rp, path) # With the assumption of changing player by turn + node.update(best_action, q_new) + + return q_new + + def think(self, root_obs, num_simulations, temperature=0, env=None, show=False): + # End point of MCTS + start, prev_time = time.time(), 0 + for _ in range(num_simulations): + self.search(self.model['representation'].inference(root_obs), []) + + # Display search result on every second + if show: + tmp_time = time.time() - start + if int(tmp_time) > int(prev_time): + prev_time = tmp_time + root, pv = self.nodes['|'], self.pv(env) + print('%.2f sec. best %s. q = %.4f. n = %d / %d. pv = %s' + % (tmp_time, env.action2str(pv[0]), root.q_sum[pv[0]] / root.n[pv[0]], + root.n[pv[0]], root.n_all, ' '.join([env.action2str(a) for a in pv]))) + + # Return probability distribution weighted by the number of simulations + root = self.nodes['|'] + n = root.n + 1 + n = (n / np.max(n)) ** (1 / (temperature + 1e-8)) + p = n / n.sum() + v = root.q_sum_all / root.n_all + return p, v + + def pv(self, env): + # Return principal variation (action sequence which is considered as the best) + s, pv_seq = copy.deepcopy(env), [] + while True: + key = '|' + if key not in self.nodes or self.nodes[key].n.sum() == 0: + break + best_action = sorted([(a, self.nodes[key].n[a]) for a in env.legal_actions()], key=lambda x: -x[1])[0][0] + pv_seq.append(best_action) + s.play(best_action) + return pv_seq \ No newline at end of file From ee33e50629c9d14e2d758f23ebb47df1c00a0ab8 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sun, 24 Jan 2021 22:23:30 +0900 Subject: [PATCH 03/22] fix: policy target and action mask for target policy --- handyrl/search.py | 5 ++--- handyrl/train.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/handyrl/search.py b/handyrl/search.py index a33c40b1..aca8e5b5 100755 --- a/handyrl/search.py +++ b/handyrl/search.py @@ -68,7 +68,7 @@ def search(self, rp, path): return q_new - def think(self, root_obs, num_simulations, temperature=0, env=None, show=False): + def think(self, root_obs, num_simulations, temperature=1.0, env=None, show=False): # End point of MCTS start, prev_time = time.time(), 0 for _ in range(num_simulations): @@ -87,8 +87,7 @@ def think(self, root_obs, num_simulations, temperature=0, env=None, show=False): # Return probability distribution weighted by the number of simulations root = self.nodes['|'] n = root.n + 1 - n = (n / np.max(n)) ** (1 / (temperature + 1e-8)) - p = n / n.sum() + p = np.log(n / n.sum()) * (1 / (temperature + 1e-8)) v = root.q_sum_all / root.n_all return p, v diff --git a/handyrl/train.py b/handyrl/train.py index 8e6c7336..ddf142d5 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -184,7 +184,7 @@ def forward_prediction(model, hidden, batch, obs_mode): if k == 'policy': # gather turn player's policies o = o.view(*batch['turn_mask'].size()[:2], -1, o.size(-1)) - outputs[k] = o.mul(batch['turn_mask'].unsqueeze(-1)).sum(-2) - batch['action_mask'] + outputs[k] = o.mul(batch['turn_mask'].unsqueeze(-1)).sum(-2)# - batch['action_mask'] else: # mask valid target values and cumulative rewards outputs[k] = o.view(*batch['turn_mask'].size()[:2], -1).mul(batch['observation_mask']) @@ -207,8 +207,8 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba turn_advantages = total_advantages.mul(tmasks).sum(-1, keepdim=True) losses['p'] = (-log_selected_policies * turn_advantages).sum() - spolicies = batch['supervised_policy'] - losses['sp'] = (spolicies * (torch.clamp(spolicies, 1e-10, 1).log() - F.log_softmax(outputs['policy'], dim=-1))).sum(-1, keepdim=True).mul(tmasks).sum() + spolicies = batch['supervised_policy'] - batch['action_mask'] + losses['sp'] = (F.softmax(spolicies, -1) * (F.log_softmax(spolicies, -1) - F.log_softmax(outputs['policy'], -1))).sum(-1, keepdim=True).mul(tmasks).sum() if 'value' in outputs: losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2 losses['sv'] = ((outputs['value'] - batch['outcome']) ** 2).mul(omasks).sum() / 2 From 4f031509dba1e7fe97c158d69ffffd8cbc9a2bbd Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 25 Jan 2021 03:59:49 +0900 Subject: [PATCH 04/22] fix: enable to use MuZero in evaluation mode --- handyrl/environments/tictactoe.py | 4 ++++ handyrl/evaluation.py | 7 ++++--- handyrl/train.py | 3 +-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/handyrl/environments/tictactoe.py b/handyrl/environments/tictactoe.py index 1839efd1..a4c366b6 100755 --- a/handyrl/environments/tictactoe.py +++ b/handyrl/environments/tictactoe.py @@ -9,6 +9,7 @@ import numpy as np from ..environment import BaseEnvironment +from ..model import MuZero class Environment(BaseEnvironment): @@ -101,6 +102,9 @@ def action_length(self): def players(self): return [0, 1] + def net(self): + return MuZero + def observation(self, player=None): # input feature for neural nets turn_view = player is None or player == self.turn() diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index dc650ff2..3c0b6423 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -77,11 +77,11 @@ def plan(self, obs): def action(self, env, player, show=False): outputs = self.plan(env.observation(player)) actions = env.legal_actions() - p = outputs['policy'] + p_ = outputs['policy'] v = outputs.get('value', None) - mask = np.ones_like(p) + mask = np.ones_like(p_) mask[actions] = 0 - p -= mask * 1e32 + p = p_ - mask * 1e32 def softmax(x): x = np.exp(x - np.max(x, axis=-1)) @@ -89,6 +89,7 @@ def softmax(x): if show: view(env, player=player) + print('p_ = %s' % (softmax(p_) * 1000).astype(int)) print_outputs(env, softmax(p), v) if self.temperature == 0: diff --git a/handyrl/train.py b/handyrl/train.py index ddf142d5..76506e06 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -22,8 +22,7 @@ from .environment import prepare_env, make_env from .util import map_r, bimap_r, trimap_r, rotate, type_r from .model import to_torch, to_gpu_or_not, RandomModel -#from .model import SimpleConv2DModel as DefaultModel -from .model import MuZero as DefaultModel +from .model import SimpleConv2DModel as DefaultModel from .losses import compute_target from .connection import MultiProcessWorkers from .connection import accept_socket_connections From c7adee00d0517ea9c54e8a4ba2ea37c56f770a73 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 25 Jan 2021 04:02:36 +0900 Subject: [PATCH 05/22] feature: consider legal action set in entropy loss --- handyrl/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/train.py b/handyrl/train.py index 76506e06..77223885 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -214,7 +214,7 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba if 'return' in outputs: losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'], reduction='none').mul(omasks).sum() - entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1)) + entropy = dist.Categorical(logits=outputs['policy'] - batch['action_mask']).entropy().mul(tmasks.sum(-1)) losses['ent'] = entropy.sum() teacher_weight = 1 From 6376e8330b8d1f088e1763b0d04af6635df881c4 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 25 Jan 2021 04:29:26 +0900 Subject: [PATCH 06/22] feature: training after episode --- handyrl/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 77223885..fdf5f885 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -94,12 +94,12 @@ def make_batch(episodes, args): p = np.pad(p, [(0, pad_len), (0, 0)], 'constant', constant_values=0) v = np.concatenate([v, np.tile(oc, [pad_len, 1])]) sp = np.pad(sp, [(0, pad_len), (0, 0)], 'constant', constant_values=0) - act = np.pad(act, [(0, pad_len), (0, 0)], 'constant', constant_values=0) + act = np.concatenate([act, [[random.randrange(len(p[0]))] for _ in range(pad_len)]]) rew = np.pad(rew, [(0, pad_len), (0, 0)], 'constant', constant_values=0) ret = np.pad(ret, [(0, pad_len), (0, 0)], 'constant', constant_values=0) - emask = np.pad(emask, [(0, pad_len), (0, 0)], 'constant', constant_values=0) - tmask = np.pad(tmask, [(0, pad_len), (0, 0)], 'constant', constant_values=0) - omask = np.pad(omask, [(0, pad_len), (0, 0)], 'constant', constant_values=0) + emask = np.pad(emask, [(0, pad_len), (0, 0)], 'constant', constant_values=1) + tmask = np.pad(tmask, [(0, pad_len), (0, 0)], 'constant', constant_values=1) + omask = np.pad(omask, [(0, pad_len), (0, 0)], 'constant', constant_values=1) amask = np.pad(amask, [(0, pad_len), (0, 0)], 'constant', constant_values=1e32) progress = np.pad(progress, [(0, pad_len)], 'constant', constant_values=1) @@ -306,7 +306,8 @@ def select_episode(self): if random.random() < accept_rate: break ep = self.episodes[ep_idx] - turn_candidates = 1 + max(0, ep['steps'] - self.args['forward_steps']) # change start turn by sequence length + #turn_candidates = 1 + max(0, ep['steps'] - self.args['forward_steps']) # change start turn by sequence length + turn_candidates = ep['steps'] st = random.randrange(turn_candidates) ed = min(st + self.args['forward_steps'], ep['steps']) st_block = st // self.args['compress_steps'] From 256fd1a4c7d1aed52d4a22bd7da31860f5216d27 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 25 Jan 2021 06:07:56 +0900 Subject: [PATCH 07/22] feature: upadte search.py --- handyrl/search.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/handyrl/search.py b/handyrl/search.py index aca8e5b5..a3aad1a7 100755 --- a/handyrl/search.py +++ b/handyrl/search.py @@ -81,8 +81,8 @@ def think(self, root_obs, num_simulations, temperature=1.0, env=None, show=False prev_time = tmp_time root, pv = self.nodes['|'], self.pv(env) print('%.2f sec. best %s. q = %.4f. n = %d / %d. pv = %s' - % (tmp_time, env.action2str(pv[0]), root.q_sum[pv[0]] / root.n[pv[0]], - root.n[pv[0]], root.n_all, ' '.join([env.action2str(a) for a in pv]))) + % (tmp_time, env.action2str(pv[0][0], pv[0][1]), root.q_sum[pv[0][0]] / root.n[pv[0][0]], + root.n[pv[0][0]], root.n_all, ' '.join([env.action2str(a, p) for a, p in pv]))) # Return probability distribution weighted by the number of simulations root = self.nodes['|'] @@ -91,14 +91,16 @@ def think(self, root_obs, num_simulations, temperature=1.0, env=None, show=False v = root.q_sum_all / root.n_all return p, v - def pv(self, env): + def pv(self, env_): # Return principal variation (action sequence which is considered as the best) - s, pv_seq = copy.deepcopy(env), [] + env = copy.deepcopy(env_) + pv_seq = [] while True: - key = '|' + path = list(zip(*pv_seq))[0] + key = '|' + ' '.join(map(str, path)) if key not in self.nodes or self.nodes[key].n.sum() == 0: break best_action = sorted([(a, self.nodes[key].n[a]) for a in env.legal_actions()], key=lambda x: -x[1])[0][0] - pv_seq.append(best_action) - s.play(best_action) + pv_seq.append((best_action, env.turn())) + env.play(best_action) return pv_seq \ No newline at end of file From 482fa60414709ec5a2215820a11a3fe72b3319ae Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 25 Jan 2021 06:57:44 +0900 Subject: [PATCH 08/22] feature: TicTacToe from objective side --- handyrl/environments/tictactoe.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/handyrl/environments/tictactoe.py b/handyrl/environments/tictactoe.py index 1839efd1..624e2df5 100755 --- a/handyrl/environments/tictactoe.py +++ b/handyrl/environments/tictactoe.py @@ -27,10 +27,12 @@ def reset(self, args=None): self.record = [] def action2str(self, a, _=None): - return self.X[a // 3] + self.Y[a % 3] + pos = a % 9 + return self.X[pos // 3] + self.Y[pos % 3] - def str2action(self, s, _=None): - return self.X.find(s[0]) * 3 + self.Y.find(s[1]) + def str2action(self, s, player): + pos = self.X.find(s[0]) * 3 + self.Y.find(s[1]) + return pos + 9 * player def record_string(self): return ' '.join([self.action2str(a) for a in self.record]) @@ -47,10 +49,11 @@ def play(self, action): # action is integer (0 ~ 8) or string (sequence) if isinstance(action, str): for astr in action.split(): - self.play(self.str2action(astr)) + self.play(self.str2action(astr, self.turn())) return - x, y = action // 3, action % 3 + pos = action % 9 + x, y = pos // 3, pos % 3 self.board[x, y] = self.color # check winning condition @@ -92,23 +95,22 @@ def outcome(self): def legal_actions(self): # legal action list - return [a for a in range(3 * 3) if self.board[a // 3, a % 3] == 0] + player = self.turn() + return [pos + 9 * player for pos in range(3 * 3) if self.board[pos // 3, pos % 3] == 0] def action_length(self): # maximum size of policy (it determines output size of policy function) - return 3 * 3 + return 3 * 3 * 2 def players(self): return [0, 1] def observation(self, player=None): # input feature for neural nets - turn_view = player is None or player == self.turn() - color = self.color if turn_view else -self.color a = np.stack([ - np.ones_like(self.board) if turn_view else np.zeros_like(self.board), - self.board == color, - self.board == -color + np.ones_like(self.board) if self.turn() == 0 else np.zeros_like(self.board), + self.board == self.BLACK, + self.board == self.WHITE, ]).astype(np.float32) return a From 9e0224b0d92e64925306d41ea2a40fb7550b2d07 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 25 Jan 2021 08:38:25 +0900 Subject: [PATCH 09/22] feature: turn-free MuZero implementation --- handyrl/model.py | 6 +++--- handyrl/search.py | 16 ++++++++++------ handyrl/train.py | 5 ++++- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/handyrl/model.py b/handyrl/model.py index b62ad517..5cdb7bbc 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -296,10 +296,10 @@ def inference(self, x): class Prediction(nn.Module): ''' Policy and value prediction from inner abstract state ''' - def __init__(self, internal_size, action_length): + def __init__(self, internal_size, action_length, player_count): super().__init__() self.head_p = Head(internal_size, 4, action_length) - self.head_v = Head(internal_size, 4, 1) + self.head_v = Head(internal_size, 4, player_count) def forward(self, rp): p = self.head_p(rp) @@ -344,7 +344,7 @@ def __init__(self, env, args={}): self.nets = nn.ModuleDict({ 'representation': self.Representation(self.input_size[0], layers, filters), - 'prediction': self.Prediction(internal_size, self.action_length), + 'prediction': self.Prediction(internal_size, self.action_length, len(env.players())), 'dynamics': self.Dynamics(internal_size, self.action_length, 2, layers), }) diff --git a/handyrl/search.py b/handyrl/search.py index a3aad1a7..a62b423b 100755 --- a/handyrl/search.py +++ b/handyrl/search.py @@ -18,7 +18,8 @@ class Node: '''Search result of one abstract (or root) state''' def __init__(self, p, v): self.p, self.v = p, v - self.n, self.q_sum = np.zeros_like(p), np.zeros_like(p) + self.n = np.zeros_like(p) + self.q_sum = np.zeros((*p.shape, 2)) self.n_all, self.q_sum_all = 1, v / 2 # prior def update(self, action, q_new): @@ -43,7 +44,7 @@ def search(self, rp, path): key = '|' + ' '.join(map(str, path)) if key not in self.nodes: p, v = self.model['prediction'].inference(rp) - p, v = softmax(p), v[0] + p, v = softmax(p), v self.nodes[key] = Node(p, v) return v @@ -52,18 +53,21 @@ def search(self, rp, path): p = node.p if len(path) == 0: # Add noise to policy on the root node - p = 0.75 * p + 0.25 * np.random.dirichlet([0.15] * len(p)) + p = 0.75 * p + 0.25 * np.random.dirichlet([0.15] * np.prod(p.shape)).reshape(*p.shape) # On the root node, we choose action only from legal actions p /= p.sum() + 1e-16 - n, q_sum = 1 + node.n, node.q_sum_all / node.n_all + node.q_sum - ucb = q_sum / n + 2.0 * np.sqrt(node.n_all) * p / n # PUCB formula + q_mean_all = node.q_sum_all.reshape(1, -1) / node.n_all + n, q_sum = 1 + node.n, q_mean_all + node.q_sum + adv = (q_sum / n.reshape(-1, 1) - q_mean_all).reshape(2, -1, 2) + adv = np.concatenate([adv[0, :, 0], adv[1, :, 1]]) + ucb = adv + 2.0 * np.sqrt(node.n_all) * p / n # PUCB formula best_action = np.argmax(ucb) # Search next state by recursively calling this function next_rp = self.model['dynamics'].inference(rp, np.array([best_action])) path.append(best_action) - q_new = -self.search(next_rp, path) # With the assumption of changing player by turn + q_new = self.search(next_rp, path) node.update(best_action, q_new) return q_new diff --git a/handyrl/train.py b/handyrl/train.py index fdf5f885..e68a53ef 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -46,6 +46,9 @@ def make_batch(episodes, args): obss, datum = [], [] + def replace_none(a, b): + return a if a is not None else b + for ep in episodes: # target player and turn index moments_ = sum([pickle.loads(bz2.decompress(ms)) for ms in ep['moment']], []) @@ -66,7 +69,7 @@ def make_batch(episodes, args): p = np.array([m['policy'] for m in moments]) sp = np.array([m['supervised_policy'] for m in moments]) v = np.array( - [[m['value'][player] or 0 for player in players] for m in moments], + [replace_none(m['value'][m['turn']], [0, 0]) for m in moments], dtype=np.float32 ).reshape(-1, len(players)) rew = np.array( From 5848c78debd74b9491abb0379fa55826d4ba785e Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 25 Jan 2021 09:10:08 +0900 Subject: [PATCH 10/22] chore: change MuZero head size --- handyrl/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/model.py b/handyrl/model.py index 5cdb7bbc..c792f7ae 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -299,7 +299,7 @@ class Prediction(nn.Module): def __init__(self, internal_size, action_length, player_count): super().__init__() self.head_p = Head(internal_size, 4, action_length) - self.head_v = Head(internal_size, 4, player_count) + self.head_v = Head(internal_size, 2, player_count) def forward(self, rp): p = self.head_p(rp) From 2b78a361c375ae944708f7e3a24153b29de1c7ae Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 25 Jan 2021 11:22:20 +0900 Subject: [PATCH 11/22] fix: output multidimensional value --- handyrl/evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index 3c0b6423..bac07a06 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -54,7 +54,7 @@ def print_outputs(env, prob, v): if hasattr(env, 'print_outputs'): env.print_outputs(prob, v) else: - print('v = %f' % v) + print('v = %s' % v) print('p = %s' % (prob * 1000).astype(int)) From 6a620290820c298af9a8b730a80b0c0963fbf86d Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 1 Feb 2021 01:05:10 +0900 Subject: [PATCH 12/22] chore: updates planning losses notation --- handyrl/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 50626528..48ccf8cb 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -198,8 +198,8 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba turn_advantages = total_advantages.mul(tmasks).sum(2, keepdim=True) losses['p'] = (-log_selected_policies * turn_advantages).sum() - spolicies = batch['policy'] - batch['action_mask'] - losses['pp'] = (F.softmax(spolicies, -1) * (F.log_softmax(spolicies, -1) - F.log_softmax(outputs['policy'], -1))).sum(-1, keepdim=True).mul(tmasks).sum() + target_policies = batch['policy'] - batch['action_mask'] + losses['pp'] = (F.softmax(target_policies, -1) * (F.log_softmax(target_policies, -1) - F.log_softmax(outputs['policy'], -1))).sum(-1, keepdim=True).mul(tmasks).sum() if 'value' in outputs: losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2 losses['pv'] = ((outputs['value'] - batch['outcome']) ** 2).mul(omasks).sum() / 2 @@ -209,10 +209,10 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba entropy = dist.Categorical(logits=outputs['policy'] - batch['action_mask']).entropy().mul(tmasks.sum(-1)) losses['ent'] = entropy.sum() - teacher_weight = 1 + planning_weight = 1 reinforce_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0) planning_loss = losses['pp'] + losses.get('pv', 0) - base_loss = (1 - teacher_weight) * reinforce_loss + teacher_weight * planning_loss + base_loss = (1 - planning_weight) * reinforce_loss + planning_weight * planning_loss entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization'] losses['total'] = base_loss + entropy_loss From c40fba9738dfff4a108f813cd0da1dbbf201cc98 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 1 Feb 2021 02:18:35 +0900 Subject: [PATCH 13/22] chore: apply legal mask for policy in generator --- handyrl/generation.py | 2 +- handyrl/train.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/handyrl/generation.py b/handyrl/generation.py index 5082aeb1..f6629f02 100755 --- a/handyrl/generation.py +++ b/handyrl/generation.py @@ -53,7 +53,7 @@ def generate(self, models, args): legal_actions = self.env.legal_actions() action_mask = np.ones_like(p_) * 1e32 action_mask[legal_actions] = 0 - p = p_ / temperature + p = p_ / temperature - action_mask moment['policy'][player] = p moment['action_mask'][player] = action_mask diff --git a/handyrl/train.py b/handyrl/train.py index 48ccf8cb..8cb9236d 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -198,8 +198,7 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba turn_advantages = total_advantages.mul(tmasks).sum(2, keepdim=True) losses['p'] = (-log_selected_policies * turn_advantages).sum() - target_policies = batch['policy'] - batch['action_mask'] - losses['pp'] = (F.softmax(target_policies, -1) * (F.log_softmax(target_policies, -1) - F.log_softmax(outputs['policy'], -1))).sum(-1, keepdim=True).mul(tmasks).sum() + losses['pp'] = (F.softmax(batch['policy'], -1) * (F.log_softmax(batch['policy'], -1) - F.log_softmax(outputs['policy'], -1))).sum(-1, keepdim=True).mul(tmasks).sum() if 'value' in outputs: losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2 losses['pv'] = ((outputs['value'] - batch['outcome']) ** 2).mul(omasks).sum() / 2 From 96d48b0ebe93b07c33aef0c41271a76b201c052f Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 1 Feb 2021 03:05:19 +0900 Subject: [PATCH 14/22] chore: update muzero model and search notation --- config.yaml | 3 +++ handyrl/model.py | 15 ++++++++------- handyrl/search.py | 27 ++++++++++++++------------- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/config.yaml b/config.yaml index 2223d73e..c3e17556 100755 --- a/config.yaml +++ b/config.yaml @@ -26,6 +26,9 @@ train_args: policy_target: 'TD' # 'UGPO' 'VTRACE' 'TD' 'MC' value_target: 'TD' # 'VTRACE' 'TD' 'MC' policy_decay: 0.9 + planning: + root_noise_alpha: 0.15 + root_noise_coef: 0.25 seed: 0 restart_epoch: 0 diff --git a/handyrl/model.py b/handyrl/model.py index e54609b5..69349d24 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -281,7 +281,7 @@ def forward(self, x): def inference(self, x): self.eval() with torch.no_grad(): - rp = self(torch.from_numpy(x).unsqueeze(0)) + rp = self(to_torch(x, unsqueeze=0)) return rp.cpu().numpy().squeeze(0) class Prediction(nn.Module): @@ -299,17 +299,17 @@ def forward(self, rp): def inference(self, rp): self.eval() with torch.no_grad(): - p, v = self(torch.from_numpy(rp).unsqueeze(0)) + p, v = self(to_torch(rp, unsqueeze=0)) return p.cpu().numpy().squeeze(0), v.cpu().numpy().squeeze(0) class Dynamics(nn.Module): '''Abstract state transition''' - def __init__(self, rp_shape, action_length, action_filters, layers): + def __init__(self, rp_shape, layers, action_length, action_filters): super().__init__() self.action_shape = action_filters, rp_shape[1], rp_shape[2] filters = rp_shape[0] self.action_embedding = nn.Embedding(action_length, embedding_dim=np.prod(self.action_shape)) - self.layer0 = Conv(rp_shape[0] + self.action_shape[0], filters, 3, bn=True) + self.layer0 = Conv(filters + action_filters, filters, 3, bn=True) self.blocks = nn.ModuleList([WideResidualBlock(filters, 3, bn=True) for _ in range(layers)]) def forward(self, rp, a): @@ -323,7 +323,7 @@ def forward(self, rp, a): def inference(self, rp, a): self.eval() with torch.no_grad(): - rp = self(torch.from_numpy(rp).unsqueeze(0), torch.from_numpy(a).unsqueeze(0)) + rp = self(to_torch(rp, unsqueeze=0), to_torch(a, unsqueeze=0)) return rp.cpu().numpy().squeeze(0) def __init__(self, env, args={}): @@ -331,11 +331,12 @@ def __init__(self, env, args={}): self.input_size = env.observation().shape layers, filters = args.get('layers', 3), args.get('filters', 32) internal_size = (filters, *self.input_size[1:]) + self.planning_args = args['planning'] self.nets = nn.ModuleDict({ 'representation': self.Representation(self.input_size[0], layers, filters), 'prediction': self.Prediction(internal_size, self.action_length, len(env.players())), - 'dynamics': self.Dynamics(internal_size, self.action_length, 2, layers), + 'dynamics': self.Dynamics(internal_size, layers, self.action_length, 2), }) def init_hidden(self, batch_size=None): @@ -355,6 +356,6 @@ def forward(self, x, hidden, action=None): return outputs def inference(self, x, hidden=None, num_simulations=30): - tree = MonteCarloTree(self.nets, {}) + tree = MonteCarloTree(self.nets, self.planning_args) p, v = tree.think(x, num_simulations) return {'policy': p, 'value': v} diff --git a/handyrl/search.py b/handyrl/search.py index a62b423b..0ec75af5 100755 --- a/handyrl/search.py +++ b/handyrl/search.py @@ -20,7 +20,7 @@ def __init__(self, p, v): self.p, self.v = p, v self.n = np.zeros_like(p) self.q_sum = np.zeros((*p.shape, 2)) - self.n_all, self.q_sum_all = 1, v / 2 # prior + self.n_all, self.q_sum_all = 1, v / 2 # prior def update(self, action, q_new): # Update @@ -36,8 +36,8 @@ class MonteCarloTree: '''Monte Carlo Tree Search''' def __init__(self, model, args): self.model = model + self.args = args self.nodes = {} - self.args = {} def search(self, rp, path): # Return predicted value from new state @@ -48,12 +48,13 @@ def search(self, rp, path): self.nodes[key] = Node(p, v) return v - # State transition by an action selected from bandit + # Choose action with bandit node = self.nodes[key] p = node.p if len(path) == 0: # Add noise to policy on the root node - p = 0.75 * p + 0.25 * np.random.dirichlet([0.15] * np.prod(p.shape)).reshape(*p.shape) + noise = np.random.dirichlet([self.args['root_noise_alpha']] * np.prod(p.shape)).reshape(*p.shape) + p = (1 - self.args['root_noise_coef']) * p + self.args['root_noise_coef'] * noise # On the root node, we choose action only from legal actions p /= p.sum() + 1e-16 @@ -62,17 +63,17 @@ def search(self, rp, path): adv = (q_sum / n.reshape(-1, 1) - q_mean_all).reshape(2, -1, 2) adv = np.concatenate([adv[0, :, 0], adv[1, :, 1]]) ucb = adv + 2.0 * np.sqrt(node.n_all) * p / n # PUCB formula - best_action = np.argmax(ucb) + selected_action = np.argmax(ucb) # Search next state by recursively calling this function - next_rp = self.model['dynamics'].inference(rp, np.array([best_action])) - path.append(best_action) + next_rp = self.model['dynamics'].inference(rp, np.array([selected_action])) + path.append(selected_action) q_new = self.search(next_rp, path) - node.update(best_action, q_new) + node.update(selected_action, q_new) return q_new - def think(self, root_obs, num_simulations, temperature=1.0, env=None, show=False): + def think(self, root_obs, num_simulations, env=None, show=False): # End point of MCTS start, prev_time = time.time(), 0 for _ in range(num_simulations): @@ -90,9 +91,9 @@ def think(self, root_obs, num_simulations, temperature=1.0, env=None, show=False # Return probability distribution weighted by the number of simulations root = self.nodes['|'] - n = root.n + 1 - p = np.log(n / n.sum()) * (1 / (temperature + 1e-8)) - v = root.q_sum_all / root.n_all + n = root.n + 0.1 + p = np.log(n / n.sum()) + v = (root.q_sum * p.reshape(-1, 1)).sum(0) return p, v def pv(self, env_): @@ -107,4 +108,4 @@ def pv(self, env_): best_action = sorted([(a, self.nodes[key].n[a]) for a in env.legal_actions()], key=lambda x: -x[1])[0][0] pv_seq.append((best_action, env.turn())) env.play(best_action) - return pv_seq \ No newline at end of file + return pv_seq From 537c204b18613d926059376c0eb0c6cdbf98d974 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 1 Feb 2021 03:21:48 +0900 Subject: [PATCH 15/22] fix: make_batch() for MuZero --- handyrl/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 8cb9236d..4eec1bd7 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -73,7 +73,7 @@ def replace_none(a, b): oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1) emask = np.ones((len(moments), 1, 1), dtype=np.float32) # episode mask - amask = np.array([[m['action_mask'][m['turn']]] for m in moments]) + amask = np.array([[m['action_mask'][m['turn']]] for m in moments], dtype=np.int64) tmask = np.array([[[m['policy'][player] is not None] for player in players] for m in moments], dtype=np.float32) omask = np.array([[[m['value'][player] is not None] for player in players] for m in moments], dtype=np.float32) @@ -87,12 +87,12 @@ def replace_none(a, b): obs = map_r(obs, lambda o: np.pad(o, [(0, pad_len)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0)) p = np.pad(p, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) v = np.concatenate([v, np.tile(oc, [pad_len, 1, 1])]) - act = np.pad(act, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) + act = np.concatenate([act, [[[random.randrange(len(p[0]))]] for _ in range(pad_len)]]) rew = np.pad(rew, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) ret = np.pad(ret, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - emask = np.pad(emask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - tmask = np.pad(tmask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - omask = np.pad(omask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) + emask = np.pad(emask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=1) + tmask = np.pad(tmask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=1) + omask = np.pad(omask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=1) amask = np.pad(amask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=1e32) progress = np.pad(progress, [(0, pad_len), (0, 0)], 'constant', constant_values=1) From 6cbe4c150519b30738f4aede0f461d9c6dd21a28 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 1 Feb 2021 03:27:10 +0900 Subject: [PATCH 16/22] fix: fix of make_batch() for MuZero --- handyrl/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 4eec1bd7..d81499d8 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -73,11 +73,11 @@ def replace_none(a, b): oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1) emask = np.ones((len(moments), 1, 1), dtype=np.float32) # episode mask - amask = np.array([[m['action_mask'][m['turn']]] for m in moments], dtype=np.int64) + amask = np.array([[m['action_mask'][m['turn']]] for m in moments]) tmask = np.array([[[m['policy'][player] is not None] for player in players] for m in moments], dtype=np.float32) omask = np.array([[[m['value'][player] is not None] for player in players] for m in moments], dtype=np.float32) - act = np.array([[m['action']] for m in moments])[..., np.newaxis] + act = np.array([[m['action']] for m in moments], dtype=np.int64)[..., np.newaxis] progress = np.arange(ep['start'], ep['end'], dtype=np.float32)[..., np.newaxis] / ep['total'] From cd57e86226ffc56f64fa7a5350e411fd9f6394f5 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sun, 14 Feb 2021 19:06:51 +0900 Subject: [PATCH 17/22] feature: player count in BaseModel --- handyrl/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/handyrl/model.py b/handyrl/model.py index dd51d83c..a6a793a5 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -221,6 +221,7 @@ class BaseModel(nn.Module): def __init__(self, env, args=None, action_length=None): super().__init__() self.action_length = env.action_length() if action_length is None else action_length + self.num_players = len(env.players()) def init_hidden(self, batch_size=None): return None From 42270fd98b6f3fc5ed9ca3adefadb425de7481d0 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Tue, 9 Mar 2021 22:44:15 +0900 Subject: [PATCH 18/22] fix: remove parallel tic-tac-toe from tests --- handyrl/envs/tictactoe.py | 2 +- tests/test_environment.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py index e432c4bd..a6ec2d04 100755 --- a/handyrl/envs/tictactoe.py +++ b/handyrl/envs/tictactoe.py @@ -73,7 +73,7 @@ def update(self, info, reset): if reset: self.reset() else: - action = self.str2action(info) + action = self.str2action(info, self.turn()) self.play(action) def turn(self): diff --git a/tests/test_environment.py b/tests/test_environment.py index c137db3d..f8a1970d 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -7,7 +7,7 @@ ENVS = [ 'tictactoe', 'geister', - 'parallel_tictactoe', + #'parallel_tictactoe', 'kaggle.hungry_geese', ] From 753d83451eb0aa49e13ea7333c58c90da311e02e Mon Sep 17 00:00:00 2001 From: YuriCat Date: Tue, 9 Mar 2021 23:28:52 +0900 Subject: [PATCH 19/22] chore: all actions = num of players x action length --- handyrl/envs/tictactoe.py | 2 +- handyrl/model.py | 19 +++++++++++-------- handyrl/search.py | 4 ++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py index a6ec2d04..b9bf64a8 100755 --- a/handyrl/envs/tictactoe.py +++ b/handyrl/envs/tictactoe.py @@ -99,7 +99,7 @@ def legal_actions(self, _=None): def action_length(self): # maximum size of policy (it determines output size of policy function) - return 3 * 3 * 2 + return 3 * 3 def players(self): return [0, 1] diff --git a/handyrl/model.py b/handyrl/model.py index 82d003f9..901f9c16 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -239,7 +239,10 @@ def inference(self, x, hidden, **kwargs): class RandomModel(BaseModel): def inference(self, x=None, hidden=None): - return {'policy': np.zeros(self.action_length, dtype=np.float32), 'value': np.zeros(2, dtype=np.float32)} + return { + 'policy': np.zeros(self.num_players * self.action_length, dtype=np.float32), + 'value': np.zeros(self.num_players, dtype=np.float32) + } class SimpleConv2dModel(BaseModel): @@ -287,10 +290,10 @@ def inference(self, x): class Prediction(nn.Module): ''' Policy and value prediction from inner abstract state ''' - def __init__(self, internal_size, action_length, player_count): + def __init__(self, internal_size, num_players, action_length): super().__init__() - self.head_p = Head(internal_size, 4, action_length) - self.head_v = Head(internal_size, 2, player_count) + self.head_p = Head(internal_size, 4, num_players * action_length) + self.head_v = Head(internal_size, 2, num_players) def forward(self, rp): p = self.head_p(rp) @@ -305,11 +308,11 @@ def inference(self, rp): class Dynamics(nn.Module): '''Abstract state transition''' - def __init__(self, rp_shape, layers, action_length, action_filters): + def __init__(self, rp_shape, layers, num_players, action_length, action_filters): super().__init__() self.action_shape = action_filters, rp_shape[1], rp_shape[2] filters = rp_shape[0] - self.action_embedding = nn.Embedding(action_length, embedding_dim=np.prod(self.action_shape)) + self.action_embedding = nn.Embedding(num_players * action_length, embedding_dim=np.prod(self.action_shape)) self.layer0 = Conv(filters + action_filters, filters, 3, bn=True) self.blocks = nn.ModuleList([WideResidualBlock(filters, 3, bn=True) for _ in range(layers)]) @@ -336,8 +339,8 @@ def __init__(self, env, args={}): self.nets = nn.ModuleDict({ 'representation': self.Representation(self.input_size[0], layers, filters), - 'prediction': self.Prediction(internal_size, self.action_length, len(env.players())), - 'dynamics': self.Dynamics(internal_size, layers, self.action_length, 2), + 'prediction': self.Prediction(internal_size, self.num_players, self.action_length), + 'dynamics': self.Dynamics(internal_size, layers, self.num_players, self.action_length, 2), }) def init_hidden(self, batch_size=None): diff --git a/handyrl/search.py b/handyrl/search.py index 0ec75af5..b8cffdd8 100755 --- a/handyrl/search.py +++ b/handyrl/search.py @@ -19,7 +19,7 @@ class Node: def __init__(self, p, v): self.p, self.v = p, v self.n = np.zeros_like(p) - self.q_sum = np.zeros((*p.shape, 2)) + self.q_sum = np.zeros((*p.shape, v.shape[-1]), dtype=np.float32) self.n_all, self.q_sum_all = 1, v / 2 # prior def update(self, action, q_new): @@ -60,7 +60,7 @@ def search(self, rp, path): q_mean_all = node.q_sum_all.reshape(1, -1) / node.n_all n, q_sum = 1 + node.n, q_mean_all + node.q_sum - adv = (q_sum / n.reshape(-1, 1) - q_mean_all).reshape(2, -1, 2) + adv = (q_sum / n.reshape(-1, 1) - q_mean_all).reshape(q_sum.shape[-1], -1, q_sum.shape[-1]) adv = np.concatenate([adv[0, :, 0], adv[1, :, 1]]) ucb = adv + 2.0 * np.sqrt(node.n_all) * p / n # PUCB formula selected_action = np.argmax(ucb) From fc76e52955318c51a05e7c19218b8ed77e9bc18b Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 19 Nov 2021 05:16:53 +0900 Subject: [PATCH 20/22] feature: use F.kl_loss() to compute KL divergence --- handyrl/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/train.py b/handyrl/train.py index 2aa3c70c..d96b08d1 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -204,7 +204,7 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba turn_advantages = total_advantages.mul(tmasks).sum(2, keepdim=True) losses['p'] = (-log_selected_policies * turn_advantages).sum() - losses['pp'] = (F.softmax(batch['policy'], -1) * (F.log_softmax(batch['policy'], -1) - F.log_softmax(outputs['policy'], -1))).sum(-1, keepdim=True).mul(tmasks).sum() + losses['pp'] = F.kl_div(F.log_softmax(outputs['policy'], -1), F.softmax(batch['policy'], -1), reduction='none').sum(-1, keepdim=True).mul(tmasks).sum() if 'value' in outputs: losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2 losses['pv'] = ((outputs['value'] - batch['outcome']) ** 2).mul(omasks).sum() / 2 From aca6f6f6121bf2c14c59b8fe45156f6bd762a0d0 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 22 Nov 2021 10:18:25 +0900 Subject: [PATCH 21/22] fix: outputted value of search function --- handyrl/search.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/handyrl/search.py b/handyrl/search.py index b8cffdd8..b80823ff 100755 --- a/handyrl/search.py +++ b/handyrl/search.py @@ -91,10 +91,10 @@ def think(self, root_obs, num_simulations, env=None, show=False): # Return probability distribution weighted by the number of simulations root = self.nodes['|'] - n = root.n + 0.1 - p = np.log(n / n.sum()) - v = (root.q_sum * p.reshape(-1, 1)).sum(0) - return p, v + n = root.n + 1e-4 + p = n / n.sum() + v = ((root.q_sum / (root.n.reshape(-1, 1) + 1e-4)) * p.reshape(-1, 1)).sum(0) + return np.log(p), v def pv(self, env_): # Return principal variation (action sequence which is considered as the best) From ccdd3adc743a8fdac06c2d249ece56e59dc733ce Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 24 Nov 2021 01:19:09 +0900 Subject: [PATCH 22/22] feature: residual net for MuZero --- handyrl/envs/tictactoe.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py index 41a8f8b5..05d1c514 100755 --- a/handyrl/envs/tictactoe.py +++ b/handyrl/envs/tictactoe.py @@ -71,16 +71,26 @@ def forward(self, x, hidden=None): return {'policy': h_p, 'value': torch.tanh(h_v)} +class ResidualBlock(nn.Module): + def __init__(self, filters0, filters1): + super().__init__() + self.conv = Conv(filters0, filters1, 3, bn=True) + + def forward(self, x): + h = self.conv(x) + return F.relu_(x + h) + + class MuZero(nn.Module): class Representation(nn.Module): ''' Conversion from observation to inner abstract state ''' def __init__(self, input_dim, layers, filters): super().__init__() self.layer0 = Conv(input_dim, filters, 3, bn=True) - self.blocks = nn.ModuleList([Conv(filters, filters, 3, bn=True) for _ in range(layers)]) + self.blocks = nn.ModuleList([ResidualBlock(filters, filters) for _ in range(layers)]) def forward(self, x): - h = F.relu(self.layer0(x)) + h = F.relu_(self.layer0(x)) for block in self.blocks: h = block(h) return h @@ -117,12 +127,12 @@ def __init__(self, rp_shape, layers, num_players, action_length, action_filters) filters = rp_shape[0] self.action_embedding = nn.Embedding(num_players * action_length, embedding_dim=np.prod(self.action_shape)) self.layer0 = Conv(filters + action_filters, filters, 3, bn=True) - self.blocks = nn.ModuleList([Conv(filters, filters, 3, bn=True) for _ in range(layers)]) + self.blocks = nn.ModuleList([ResidualBlock(filters, filters) for _ in range(layers)]) def forward(self, rp, a): arp = self.action_embedding(a).view(-1, *self.action_shape) h = torch.cat([rp, arp], dim=1) - h = self.layer0(h) + h = F.relu_(self.layer0(h)) for block in self.blocks: h = block(h) return h @@ -140,6 +150,7 @@ def __init__(self, env, obs, action_length): self.input_size = obs.shape layers, filters = 3, 32 + action_filters = 4 internal_size = (filters, *self.input_size[1:]) self.planning_args = { 'root_noise_alpha': 0.15, @@ -149,7 +160,7 @@ def __init__(self, env, obs, action_length): self.nets = nn.ModuleDict({ 'representation': self.Representation(self.input_size[0], layers, filters), 'prediction': self.Prediction(internal_size, self.num_players, self.action_length), - 'dynamics': self.Dynamics(internal_size, layers, self.num_players, self.action_length, 2), + 'dynamics': self.Dynamics(internal_size, layers, self.num_players, self.action_length, action_filters), }) def init_hidden(self, batch_size=None):