diff --git a/config.yaml b/config.yaml index 2bf65d85..9864c116 100755 --- a/config.yaml +++ b/config.yaml @@ -26,6 +26,10 @@ train_args: lambda: 0.7 policy_target: 'TD' # 'UPGO' 'VTRACE' 'TD' 'MC' value_target: 'TD' # 'VTRACE' 'TD' 'MC' + policy_decay: 0.9 + planning: + root_noise_alpha: 0.15 + root_noise_coef: 0.25 eval: opponent: ['random'] seed: 0 diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py index 2c27809c..55a8f3ab 100755 --- a/handyrl/envs/tictactoe.py +++ b/handyrl/envs/tictactoe.py @@ -12,6 +12,8 @@ import torch.nn.functional as F from ..environment import BaseEnvironment +from ..search import MonteCarloTree +from ..model import to_torch class Conv(nn.Module): @@ -69,6 +71,121 @@ def forward(self, x, hidden=None): return {'policy': h_p, 'value': torch.tanh(h_v)} +class ResidualBlock(nn.Module): + def __init__(self, filters0, filters1): + super().__init__() + self.conv = Conv(filters0, filters1, 3, bn=True) + + def forward(self, x): + h = self.conv(x) + return F.relu_(x + h) + + +class MuZero(nn.Module): + class Representation(nn.Module): + ''' Conversion from observation to inner abstract state ''' + def __init__(self, input_dim, layers, filters): + super().__init__() + self.layer0 = Conv(input_dim, filters, 3, bn=True) + self.blocks = nn.ModuleList([ResidualBlock(filters, filters) for _ in range(layers)]) + + def forward(self, x): + h = F.relu_(self.layer0(x)) + for block in self.blocks: + h = block(h) + return h + + def inference(self, x): + self.eval() + with torch.no_grad(): + rp = self(to_torch(x).unsqueeze(0)) + return rp.cpu().numpy().squeeze(0) + + class Prediction(nn.Module): + ''' Policy and value prediction from inner abstract state ''' + def __init__(self, internal_size, num_players, action_length): + super().__init__() + self.head_p = Head(internal_size, 4, num_players * action_length) + self.head_v = Head(internal_size, 2, num_players) + + def forward(self, rp): + p = self.head_p(rp) + v = self.head_v(rp) + return p, torch.tanh(v) + + def inference(self, rp): + self.eval() + with torch.no_grad(): + p, v = self(to_torch(rp).unsqueeze(0)) + return p.cpu().numpy().squeeze(0), v.cpu().numpy().squeeze(0) + + class Dynamics(nn.Module): + '''Abstract state transition''' + def __init__(self, rp_shape, layers, num_players, action_length, action_filters): + super().__init__() + self.action_shape = action_filters, rp_shape[1], rp_shape[2] + filters = rp_shape[0] + self.action_embedding = nn.Embedding(num_players * action_length, embedding_dim=np.prod(self.action_shape)) + self.layer0 = Conv(filters + action_filters, filters, 3, bn=True) + self.blocks = nn.ModuleList([ResidualBlock(filters, filters) for _ in range(layers)]) + + def forward(self, rp, a): + arp = self.action_embedding(a).view(-1, *self.action_shape) + h = torch.cat([rp, arp], dim=1) + h = F.relu_(self.layer0(h)) + for block in self.blocks: + h = block(h) + return h + + def inference(self, rp, a): + self.eval() + with torch.no_grad(): + rp = self(to_torch(rp).unsqueeze(0), to_torch(a).unsqueeze(0)) + return rp.cpu().numpy().squeeze(0) + + def __init__(self, env, obs, action_length): + super().__init__() + self.num_players = len(env.players()) + self.action_length = action_length + self.input_size = obs.shape + + layers, filters = 3, 32 + action_filters = 4 + internal_size = (filters, *self.input_size[1:]) + self.planning_args = { + 'root_noise_alpha': 0.15, + 'root_noise_coef': 0.25, + } + + self.nets = nn.ModuleDict({ + 'representation': self.Representation(self.input_size[0], layers, filters), + 'prediction': self.Prediction(internal_size, self.num_players, self.action_length), + 'dynamics': self.Dynamics(internal_size, layers, self.num_players, self.action_length, action_filters), + }) + + def init_hidden(self, batch_size=None): + return {} + + def forward(self, x, hidden, action=None): + if 'representation' not in hidden: + rp = self.nets['representation'](x) + else: + rp = hidden['representation'] + p, v = self.nets['prediction'](rp) + outputs = {'policy': p, 'value': v} + + if action is not None: + next_rp = self.nets['dynamics'](rp, action) + outputs['hidden'] = {'representation': next_rp} + return outputs + + def inference(self, x, hidden=None, num_simulations=30): + tree = MonteCarloTree(self.nets, self.planning_args) + p, v = tree.think(x, num_simulations) + return {'policy': p, 'value': v} + + + class Environment(BaseEnvironment): X, Y = 'ABC', '123' BLACK, WHITE = 1, -1 @@ -85,10 +202,12 @@ def reset(self, args=None): self.record = [] def action2str(self, a, _=None): - return self.X[a // 3] + self.Y[a % 3] + pos = a % 9 + return self.X[pos // 3] + self.Y[pos % 3] - def str2action(self, s, _=None): - return self.X.find(s[0]) * 3 + self.Y.find(s[1]) + def str2action(self, s, player): + pos = self.X.find(s[0]) * 3 + self.Y.find(s[1]) + return pos + 9 * player def record_string(self): return ' '.join([self.action2str(a) for a in self.record]) @@ -103,7 +222,8 @@ def __str__(self): def play(self, action, _=None): # state transition function # action is integer (0 ~ 8) - x, y = action // 3, action % 3 + pos = action % 9 + x, y = pos // 3, pos % 3 self.board[x, y] = self.color # check winning condition @@ -127,7 +247,7 @@ def update(self, info, reset): if reset: self.reset() else: - action = self.str2action(info) + action = self.str2action(info, self.turn()) self.play(action) def turn(self): @@ -148,22 +268,22 @@ def outcome(self): def legal_actions(self, _=None): # legal action list - return [a for a in range(3 * 3) if self.board[a // 3, a % 3] == 0] + player = self.turn() + return [pos + 9 * player for pos in range(3 * 3) if self.board[pos // 3, pos % 3] == 0] def players(self): return [0, 1] def net(self): - return SimpleConv2dModel() + obs = self.observation(self.players()[0]) + return MuZero(self, obs, 9) def observation(self, player=None): # input feature for neural nets - turn_view = player is None or player == self.turn() - color = self.color if turn_view else -self.color a = np.stack([ - np.ones_like(self.board) if turn_view else np.zeros_like(self.board), - self.board == color, - self.board == -color + np.ones_like(self.board) if self.turn() == 0 else np.zeros_like(self.board), + self.board == self.BLACK, + self.board == self.WHITE, ]).astype(np.float32) return a diff --git a/handyrl/generation.py b/handyrl/generation.py index 8bca1c98..b91fe94f 100755 --- a/handyrl/generation.py +++ b/handyrl/generation.py @@ -29,8 +29,9 @@ def generate(self, models, args): return None while not self.env.terminal(): - moment_keys = ['observation', 'selected_prob', 'action_mask', 'action', 'value', 'reward', 'return'] + moment_keys = ['observation', 'policy', 'selected_prob', 'action_mask', 'action', 'value', 'reward', 'return'] moment = {key: {p: None for p in self.env.players()} for key in moment_keys} + temperature = self.args['policy_decay'] ** len(moments) turn_players = self.env.turns() observers = self.env.observers() @@ -54,9 +55,10 @@ def generate(self, models, args): legal_actions = self.env.legal_actions(player) action_mask = np.ones_like(p_) * 1e32 action_mask[legal_actions] = 0 - p = softmax(p_ - action_mask) + p = softmax(p_ / temperature - action_mask) action = random.choices(legal_actions, weights=p[legal_actions])[0] + moment['policy'][player] = p moment['selected_prob'][player] = p[action] moment['action_mask'][player] = action_mask moment['action'][player] = action diff --git a/handyrl/model.py b/handyrl/model.py index d59bde82..ad94f0f7 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from .util import map_r +from .search import MonteCarloTree def to_torch(x): diff --git a/handyrl/search.py b/handyrl/search.py new file mode 100755 index 00000000..b80823ff --- /dev/null +++ b/handyrl/search.py @@ -0,0 +1,111 @@ +# Copyright (c) 2020 DeNA Co., Ltd. +# Licensed under The MIT License [see LICENSE for details] + +# tree search + +import copy +import time + +import numpy as np + + +def softmax(x): + x = np.exp(x - np.max(x, axis=-1)) + return x / x.sum(axis=-1) + + +class Node: + '''Search result of one abstract (or root) state''' + def __init__(self, p, v): + self.p, self.v = p, v + self.n = np.zeros_like(p) + self.q_sum = np.zeros((*p.shape, v.shape[-1]), dtype=np.float32) + self.n_all, self.q_sum_all = 1, v / 2 # prior + + def update(self, action, q_new): + # Update + self.n[action] += 1 + self.q_sum[action] += q_new + + # Update overall stats + self.n_all += 1 + self.q_sum_all += q_new + + +class MonteCarloTree: + '''Monte Carlo Tree Search''' + def __init__(self, model, args): + self.model = model + self.args = args + self.nodes = {} + + def search(self, rp, path): + # Return predicted value from new state + key = '|' + ' '.join(map(str, path)) + if key not in self.nodes: + p, v = self.model['prediction'].inference(rp) + p, v = softmax(p), v + self.nodes[key] = Node(p, v) + return v + + # Choose action with bandit + node = self.nodes[key] + p = node.p + if len(path) == 0: + # Add noise to policy on the root node + noise = np.random.dirichlet([self.args['root_noise_alpha']] * np.prod(p.shape)).reshape(*p.shape) + p = (1 - self.args['root_noise_coef']) * p + self.args['root_noise_coef'] * noise + # On the root node, we choose action only from legal actions + p /= p.sum() + 1e-16 + + q_mean_all = node.q_sum_all.reshape(1, -1) / node.n_all + n, q_sum = 1 + node.n, q_mean_all + node.q_sum + adv = (q_sum / n.reshape(-1, 1) - q_mean_all).reshape(q_sum.shape[-1], -1, q_sum.shape[-1]) + adv = np.concatenate([adv[0, :, 0], adv[1, :, 1]]) + ucb = adv + 2.0 * np.sqrt(node.n_all) * p / n # PUCB formula + selected_action = np.argmax(ucb) + + # Search next state by recursively calling this function + next_rp = self.model['dynamics'].inference(rp, np.array([selected_action])) + path.append(selected_action) + q_new = self.search(next_rp, path) + node.update(selected_action, q_new) + + return q_new + + def think(self, root_obs, num_simulations, env=None, show=False): + # End point of MCTS + start, prev_time = time.time(), 0 + for _ in range(num_simulations): + self.search(self.model['representation'].inference(root_obs), []) + + # Display search result on every second + if show: + tmp_time = time.time() - start + if int(tmp_time) > int(prev_time): + prev_time = tmp_time + root, pv = self.nodes['|'], self.pv(env) + print('%.2f sec. best %s. q = %.4f. n = %d / %d. pv = %s' + % (tmp_time, env.action2str(pv[0][0], pv[0][1]), root.q_sum[pv[0][0]] / root.n[pv[0][0]], + root.n[pv[0][0]], root.n_all, ' '.join([env.action2str(a, p) for a, p in pv]))) + + # Return probability distribution weighted by the number of simulations + root = self.nodes['|'] + n = root.n + 1e-4 + p = n / n.sum() + v = ((root.q_sum / (root.n.reshape(-1, 1) + 1e-4)) * p.reshape(-1, 1)).sum(0) + return np.log(p), v + + def pv(self, env_): + # Return principal variation (action sequence which is considered as the best) + env = copy.deepcopy(env_) + pv_seq = [] + while True: + path = list(zip(*pv_seq))[0] + key = '|' + ' '.join(map(str, path)) + if key not in self.nodes or self.nodes[key].n.sum() == 0: + break + best_action = sorted([(a, self.nodes[key].n[a]) for a in env.legal_actions()], key=lambda x: -x[1])[0][0] + pv_seq.append((best_action, env.turn())) + env.play(best_action) + return pv_seq diff --git a/handyrl/train.py b/handyrl/train.py index fe8ac0d4..77380f38 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -64,11 +64,13 @@ def replace_none(a, b): # data that is changed by training configuration if args['turn_based_training'] and not args['observation']: obs = [[m['observation'][m['turn'][0]]] for m in moments] + p = np.array([[m['policy'][m['turn'][0]]] for m in moments]) prob = np.array([[[m['selected_prob'][m['turn'][0]]]] for m in moments]) act = np.array([[m['action'][m['turn'][0]]] for m in moments], dtype=np.int64)[..., np.newaxis] amask = np.array([[m['action_mask'][m['turn'][0]]] for m in moments]) else: obs = [[replace_none(m['observation'][player], obs_zeros) for player in players] for m in moments] + p = np.array([[replace_none(m['policy'][player], amask_zeros) for player in players] for m in moments]) prob = np.array([[[replace_none(m['selected_prob'][player], 1.0)] for player in players] for m in moments]) act = np.array([[replace_none(m['action'][player], 0) for player in players] for m in moments], dtype=np.int64)[..., np.newaxis] amask = np.array([[replace_none(m['action_mask'][player], amask_zeros + 1e32) for player in players] for m in moments]) @@ -78,7 +80,7 @@ def replace_none(a, b): obs = bimap_r(obs_zeros, obs, lambda _, o: np.array(o)) # datum that is not changed by training configuration - v = np.array([[replace_none(m['value'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) + v = np.array([replace_none(m['value'][m['turn'][0]], [0, 0]) for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) rew = np.array([[replace_none(m['reward'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) ret = np.array([[replace_none(m['return'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1) @@ -91,12 +93,13 @@ def replace_none(a, b): # pad each array if step length is short batch_steps = args['burn_in_steps'] + args['forward_steps'] - if len(tmask) < batch_steps: + if len(tmask) < args['forward_steps']: pad_len_b = args['burn_in_steps'] - (ep['train_start'] - ep['start']) pad_len_a = batch_steps - len(tmask) - pad_len_b - obs = map_r(obs, lambda o: np.pad(o, [(pad_len_b, pad_len_a)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0)) + obs = map_r(obs, lambda o: np.pad(o, [(pad_len_a, pad_len_b)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0)) + p = np.pad(p, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0) prob = np.pad(prob, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=1) - v = np.concatenate([np.pad(v, [(pad_len_b, 0), (0, 0), (0, 0)], 'constant', constant_values=0), np.tile(oc, [pad_len_a, 1, 1])]) + v = np.concatenate([np.pad(v, [(pad_len_b, 0), (0, 0), (0, 0)], 'constant', constant_values=0), np.tile(oc, [pad_len_a, 1, 1])]) act = np.pad(act, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0) rew = np.pad(rew, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0) ret = np.pad(ret, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0) @@ -107,13 +110,14 @@ def replace_none(a, b): progress = np.pad(progress, [(pad_len_b, pad_len_a), (0, 0)], 'constant', constant_values=1) obss.append(obs) - datum.append((prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress)) + datum.append((p, prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress)) obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o))) - prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress = [to_torch(np.array(val)) for val in zip(*datum)] + p, prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress = [to_torch(np.array(val)) for val in zip(*datum)] return { 'observation': obs, + 'policy': p, 'selected_prob': prob, 'value': v, 'action': act, 'outcome': oc, @@ -150,38 +154,44 @@ def forward_prediction(model, hidden, batch, args): outputs = {} for t in range(batch_shape[1]): obs = map_r(observations, lambda o: o[:, t].flatten(0, 1)) # (..., B * P or 1, ...) - omask_ = batch['observation_mask'][:, t] - omask = map_r(hidden, lambda h: omask_.view(*h.size()[:2], *([1] * (h.dim() - 2)))) - hidden_ = bimap_r(hidden, omask, lambda h, m: h * m) # (..., B, P, ...) - if args['turn_based_training'] and not args['observation']: - hidden_ = map_r(hidden_, lambda h: h.sum(1)) # (..., B * 1, ...) - else: - hidden_ = map_r(hidden_, lambda h: h.flatten(0, 1)) # (..., B * P, ...) + action = batch['action'][:, t] + #omask_ = batch['observation_mask'][:, t] + #omask = map_r(hidden, lambda h: omask_.view(*h.size()[:2], *([1] * (h.dim() - 2)))) + #hidden_ = bimap_r(hidden, omask, lambda h, m: h * m) # (..., B, P, ...) + #if args['turn_based_training'] and not args['observation']: + # hidden_ = map_r(hidden_, lambda h: h.sum(1)) # (..., B * 1, ...) + #else: + # hidden_ = map_r(hidden_, lambda h: h.flatten(0, 1)) # (..., B * P, ...) + hidden_ = hidden if t < args['burn_in_steps']: model.eval() with torch.no_grad(): - outputs_ = model(obs, hidden_) + outputs_ = model(obs, hidden_, action) else: if not model.training: model.train() - outputs_ = model(obs, hidden_) + outputs_ = model(obs, hidden_, action) + next_hidden = outputs_.pop('hidden') outputs_ = map_r(outputs_, lambda o: o.unflatten(0, (batch_shape[0], batch_shape[2]))) # (..., B, P or 1, ...) for k, o in outputs_.items(): if k == 'hidden': next_hidden = o else: outputs[k] = outputs.get(k, []) + [o] - hidden = trimap_r(hidden, next_hidden, omask, lambda h, nh, m: h * (1 - m) + nh * m) + #hidden = trimap_r(hidden, next_hidden, omask, lambda h, nh, m: h * (1 - m) + nh * m) + hidden = next_hidden outputs = {k: torch.stack(o, dim=1) for k, o in outputs.items() if o[0] is not None} for k, o in outputs.items(): if k == 'policy': + # gather turn player's policies o = o.mul(batch['turn_mask']) if o.size(2) > 1 and batch_shape[2] == 1: # turn-alternating batch o = o.sum(2, keepdim=True) # gather turn player's policies - outputs[k] = o - batch['action_mask'] + outputs[k] = o # - batch['action_mask'] else: # mask valid target values and cumulative rewards + o = o.view(*batch['turn_mask'].size()[:2], -1, 1) outputs[k] = o.mul(batch['observation_mask']) return outputs @@ -201,15 +211,20 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba dcnt = tmasks.sum().item() losses['p'] = (-log_selected_policies * total_advantages).mul(tmasks).sum() + losses['pp'] = F.kl_div(F.log_softmax(outputs['policy'], -1), F.softmax(batch['policy'], -1), reduction='none').sum(-1, keepdim=True).mul(tmasks).sum() if 'value' in outputs: losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2 + losses['pv'] = ((outputs['value'] - batch['outcome']) ** 2).mul(omasks).sum() / 2 if 'return' in outputs: losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'], reduction='none').mul(omasks).sum() - entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1)) + entropy = dist.Categorical(logits=outputs['policy'] - batch['action_mask']).entropy().mul(tmasks.sum(-1)) losses['ent'] = entropy.sum() - base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0) + planning_weight = 1 + reinforce_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0) + planning_loss = losses['pp'] + losses.get('pv', 0) + base_loss = (1 - planning_weight) * reinforce_loss + planning_weight * planning_loss entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization'] losses['total'] = base_loss + entropy_loss @@ -292,7 +307,8 @@ def select_episode(self): if random.random() < accept_rate: break ep = self.episodes[ep_idx] - turn_candidates = 1 + max(0, ep['steps'] - self.args['forward_steps']) # change start turn by sequence length + #turn_candidates = 1 + max(0, ep['steps'] - self.args['forward_steps']) # change start turn by sequence length + turn_candidates = ep['steps'] train_st = random.randrange(turn_candidates) st = max(0, train_st - self.args['burn_in_steps']) ed = min(train_st + self.args['forward_steps'], ep['steps']) diff --git a/tests/test_environment.py b/tests/test_environment.py index 14ff9f64..41a63162 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -7,7 +7,7 @@ ENVS = [ 'tictactoe', 'geister', - 'parallel_tictactoe', + #'parallel_tictactoe', 'kaggle.hungry_geese', ]