Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
4b1d168
experiment: MuZero implementation
YuriCat Jan 23, 2021
0638bfb
fix: add handyrl/search.py
YuriCat Jan 23, 2021
ee33e50
fix: policy target and action mask for target policy
YuriCat Jan 24, 2021
4f03150
fix: enable to use MuZero in evaluation mode
YuriCat Jan 24, 2021
c7adee0
feature: consider legal action set in entropy loss
YuriCat Jan 24, 2021
6376e83
feature: training after episode
YuriCat Jan 24, 2021
256fd1a
feature: upadte search.py
YuriCat Jan 24, 2021
482fa60
feature: TicTacToe from objective side
YuriCat Jan 24, 2021
ea48a45
Merge branch 'experiment/objective_tic_tac_toe' into experiment/objec…
YuriCat Jan 24, 2021
9e0224b
feature: turn-free MuZero implementation
YuriCat Jan 24, 2021
5848c78
chore: change MuZero head size
YuriCat Jan 25, 2021
2b78a36
fix: output multidimensional value
YuriCat Jan 25, 2021
2779f74
Merge develop
YuriCat Jan 31, 2021
6a62029
chore: updates planning losses notation
YuriCat Jan 31, 2021
c40fba9
chore: apply legal mask for policy in generator
YuriCat Jan 31, 2021
96d48b0
chore: update muzero model and search notation
YuriCat Jan 31, 2021
1a15619
Merge fix/random_model_outputs_float32_p_v
YuriCat Jan 31, 2021
537c204
fix: make_batch() for MuZero
YuriCat Jan 31, 2021
6cbe4c1
fix: fix of make_batch() for MuZero
YuriCat Jan 31, 2021
d87fad4
Merge branch 'develop' into experiment/muzero
YuriCat Feb 2, 2021
cd57e86
feature: player count in BaseModel
YuriCat Feb 14, 2021
9cea444
Merge develop
YuriCat Feb 14, 2021
1c35ff0
Merge branch 'feature/num_players_in_random_model' into experiment/mu…
YuriCat Feb 14, 2021
ec780a2
Merge develop
YuriCat Mar 4, 2021
18b0026
Merge develop
YuriCat Mar 8, 2021
c35ccf7
Merge branch 'develop' into experiment/muzero
YuriCat Mar 9, 2021
42270fd
fix: remove parallel tic-tac-toe from tests
YuriCat Mar 9, 2021
753d834
chore: all actions = num of players x action length
YuriCat Mar 9, 2021
32b2d5b
Merge develop
YuriCat Nov 3, 2021
8ce78cd
Merge develop
YuriCat Nov 18, 2021
83ee2dc
Merge feature/create_random_model
YuriCat Nov 18, 2021
fc76e52
feature: use F.kl_loss() to compute KL divergence
YuriCat Nov 18, 2021
aca6f6f
fix: outputted value of search function
YuriCat Nov 22, 2021
ccdd3ad
feature: residual net for MuZero
YuriCat Nov 23, 2021
141ca33
Merge develop
YuriCat Nov 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ train_args:
lambda: 0.7
policy_target: 'TD' # 'UPGO' 'VTRACE' 'TD' 'MC'
value_target: 'TD' # 'VTRACE' 'TD' 'MC'
policy_decay: 0.9
planning:
root_noise_alpha: 0.15
root_noise_coef: 0.25
eval:
opponent: ['random']
seed: 0
Expand Down
144 changes: 132 additions & 12 deletions handyrl/envs/tictactoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import torch.nn.functional as F

from ..environment import BaseEnvironment
from ..search import MonteCarloTree
from ..model import to_torch


class Conv(nn.Module):
Expand Down Expand Up @@ -69,6 +71,121 @@ def forward(self, x, hidden=None):
return {'policy': h_p, 'value': torch.tanh(h_v)}


class ResidualBlock(nn.Module):
def __init__(self, filters0, filters1):
super().__init__()
self.conv = Conv(filters0, filters1, 3, bn=True)

def forward(self, x):
h = self.conv(x)
return F.relu_(x + h)


class MuZero(nn.Module):
class Representation(nn.Module):
''' Conversion from observation to inner abstract state '''
def __init__(self, input_dim, layers, filters):
super().__init__()
self.layer0 = Conv(input_dim, filters, 3, bn=True)
self.blocks = nn.ModuleList([ResidualBlock(filters, filters) for _ in range(layers)])

def forward(self, x):
h = F.relu_(self.layer0(x))
for block in self.blocks:
h = block(h)
return h

def inference(self, x):
self.eval()
with torch.no_grad():
rp = self(to_torch(x).unsqueeze(0))
return rp.cpu().numpy().squeeze(0)

class Prediction(nn.Module):
''' Policy and value prediction from inner abstract state '''
def __init__(self, internal_size, num_players, action_length):
super().__init__()
self.head_p = Head(internal_size, 4, num_players * action_length)
self.head_v = Head(internal_size, 2, num_players)

def forward(self, rp):
p = self.head_p(rp)
v = self.head_v(rp)
return p, torch.tanh(v)

def inference(self, rp):
self.eval()
with torch.no_grad():
p, v = self(to_torch(rp).unsqueeze(0))
return p.cpu().numpy().squeeze(0), v.cpu().numpy().squeeze(0)

class Dynamics(nn.Module):
'''Abstract state transition'''
def __init__(self, rp_shape, layers, num_players, action_length, action_filters):
super().__init__()
self.action_shape = action_filters, rp_shape[1], rp_shape[2]
filters = rp_shape[0]
self.action_embedding = nn.Embedding(num_players * action_length, embedding_dim=np.prod(self.action_shape))
self.layer0 = Conv(filters + action_filters, filters, 3, bn=True)
self.blocks = nn.ModuleList([ResidualBlock(filters, filters) for _ in range(layers)])

def forward(self, rp, a):
arp = self.action_embedding(a).view(-1, *self.action_shape)
h = torch.cat([rp, arp], dim=1)
h = F.relu_(self.layer0(h))
for block in self.blocks:
h = block(h)
return h

def inference(self, rp, a):
self.eval()
with torch.no_grad():
rp = self(to_torch(rp).unsqueeze(0), to_torch(a).unsqueeze(0))
return rp.cpu().numpy().squeeze(0)

def __init__(self, env, obs, action_length):
super().__init__()
self.num_players = len(env.players())
self.action_length = action_length
self.input_size = obs.shape

layers, filters = 3, 32
action_filters = 4
internal_size = (filters, *self.input_size[1:])
self.planning_args = {
'root_noise_alpha': 0.15,
'root_noise_coef': 0.25,
}

self.nets = nn.ModuleDict({
'representation': self.Representation(self.input_size[0], layers, filters),
'prediction': self.Prediction(internal_size, self.num_players, self.action_length),
'dynamics': self.Dynamics(internal_size, layers, self.num_players, self.action_length, action_filters),
})

def init_hidden(self, batch_size=None):
return {}

def forward(self, x, hidden, action=None):
if 'representation' not in hidden:
rp = self.nets['representation'](x)
else:
rp = hidden['representation']
p, v = self.nets['prediction'](rp)
outputs = {'policy': p, 'value': v}

if action is not None:
next_rp = self.nets['dynamics'](rp, action)
outputs['hidden'] = {'representation': next_rp}
return outputs

def inference(self, x, hidden=None, num_simulations=30):
tree = MonteCarloTree(self.nets, self.planning_args)
p, v = tree.think(x, num_simulations)
return {'policy': p, 'value': v}



class Environment(BaseEnvironment):
X, Y = 'ABC', '123'
BLACK, WHITE = 1, -1
Expand All @@ -85,10 +202,12 @@ def reset(self, args=None):
self.record = []

def action2str(self, a, _=None):
return self.X[a // 3] + self.Y[a % 3]
pos = a % 9
return self.X[pos // 3] + self.Y[pos % 3]

def str2action(self, s, _=None):
return self.X.find(s[0]) * 3 + self.Y.find(s[1])
def str2action(self, s, player):
pos = self.X.find(s[0]) * 3 + self.Y.find(s[1])
return pos + 9 * player

def record_string(self):
return ' '.join([self.action2str(a) for a in self.record])
Expand All @@ -103,7 +222,8 @@ def __str__(self):
def play(self, action, _=None):
# state transition function
# action is integer (0 ~ 8)
x, y = action // 3, action % 3
pos = action % 9
x, y = pos // 3, pos % 3
self.board[x, y] = self.color

# check winning condition
Expand All @@ -127,7 +247,7 @@ def update(self, info, reset):
if reset:
self.reset()
else:
action = self.str2action(info)
action = self.str2action(info, self.turn())
self.play(action)

def turn(self):
Expand All @@ -148,22 +268,22 @@ def outcome(self):

def legal_actions(self, _=None):
# legal action list
return [a for a in range(3 * 3) if self.board[a // 3, a % 3] == 0]
player = self.turn()
return [pos + 9 * player for pos in range(3 * 3) if self.board[pos // 3, pos % 3] == 0]

def players(self):
return [0, 1]

def net(self):
return SimpleConv2dModel()
obs = self.observation(self.players()[0])
return MuZero(self, obs, 9)

def observation(self, player=None):
# input feature for neural nets
turn_view = player is None or player == self.turn()
color = self.color if turn_view else -self.color
a = np.stack([
np.ones_like(self.board) if turn_view else np.zeros_like(self.board),
self.board == color,
self.board == -color
np.ones_like(self.board) if self.turn() == 0 else np.zeros_like(self.board),
self.board == self.BLACK,
self.board == self.WHITE,
]).astype(np.float32)
return a

Expand Down
6 changes: 4 additions & 2 deletions handyrl/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def generate(self, models, args):
return None

while not self.env.terminal():
moment_keys = ['observation', 'selected_prob', 'action_mask', 'action', 'value', 'reward', 'return']
moment_keys = ['observation', 'policy', 'selected_prob', 'action_mask', 'action', 'value', 'reward', 'return']
moment = {key: {p: None for p in self.env.players()} for key in moment_keys}
temperature = self.args['policy_decay'] ** len(moments)

turn_players = self.env.turns()
observers = self.env.observers()
Expand All @@ -54,9 +55,10 @@ def generate(self, models, args):
legal_actions = self.env.legal_actions(player)
action_mask = np.ones_like(p_) * 1e32
action_mask[legal_actions] = 0
p = softmax(p_ - action_mask)
p = softmax(p_ / temperature - action_mask)
action = random.choices(legal_actions, weights=p[legal_actions])[0]

moment['policy'][player] = p
moment['selected_prob'][player] = p[action]
moment['action_mask'][player] = action_mask
moment['action'][player] = action
Expand Down
1 change: 1 addition & 0 deletions handyrl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch.nn.functional as F

from .util import map_r
from .search import MonteCarloTree


def to_torch(x):
Expand Down
111 changes: 111 additions & 0 deletions handyrl/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2020 DeNA Co., Ltd.
# Licensed under The MIT License [see LICENSE for details]

# tree search

import copy
import time

import numpy as np


def softmax(x):
x = np.exp(x - np.max(x, axis=-1))
return x / x.sum(axis=-1)


class Node:
'''Search result of one abstract (or root) state'''
def __init__(self, p, v):
self.p, self.v = p, v
self.n = np.zeros_like(p)
self.q_sum = np.zeros((*p.shape, v.shape[-1]), dtype=np.float32)
self.n_all, self.q_sum_all = 1, v / 2 # prior

def update(self, action, q_new):
# Update
self.n[action] += 1
self.q_sum[action] += q_new

# Update overall stats
self.n_all += 1
self.q_sum_all += q_new


class MonteCarloTree:
'''Monte Carlo Tree Search'''
def __init__(self, model, args):
self.model = model
self.args = args
self.nodes = {}

def search(self, rp, path):
# Return predicted value from new state
key = '|' + ' '.join(map(str, path))
if key not in self.nodes:
p, v = self.model['prediction'].inference(rp)
p, v = softmax(p), v
self.nodes[key] = Node(p, v)
return v

# Choose action with bandit
node = self.nodes[key]
p = node.p
if len(path) == 0:
# Add noise to policy on the root node
noise = np.random.dirichlet([self.args['root_noise_alpha']] * np.prod(p.shape)).reshape(*p.shape)
p = (1 - self.args['root_noise_coef']) * p + self.args['root_noise_coef'] * noise
# On the root node, we choose action only from legal actions
p /= p.sum() + 1e-16

q_mean_all = node.q_sum_all.reshape(1, -1) / node.n_all
n, q_sum = 1 + node.n, q_mean_all + node.q_sum
adv = (q_sum / n.reshape(-1, 1) - q_mean_all).reshape(q_sum.shape[-1], -1, q_sum.shape[-1])
adv = np.concatenate([adv[0, :, 0], adv[1, :, 1]])
ucb = adv + 2.0 * np.sqrt(node.n_all) * p / n # PUCB formula
selected_action = np.argmax(ucb)

# Search next state by recursively calling this function
next_rp = self.model['dynamics'].inference(rp, np.array([selected_action]))
path.append(selected_action)
q_new = self.search(next_rp, path)
node.update(selected_action, q_new)

return q_new

def think(self, root_obs, num_simulations, env=None, show=False):
# End point of MCTS
start, prev_time = time.time(), 0
for _ in range(num_simulations):
self.search(self.model['representation'].inference(root_obs), [])

# Display search result on every second
if show:
tmp_time = time.time() - start
if int(tmp_time) > int(prev_time):
prev_time = tmp_time
root, pv = self.nodes['|'], self.pv(env)
print('%.2f sec. best %s. q = %.4f. n = %d / %d. pv = %s'
% (tmp_time, env.action2str(pv[0][0], pv[0][1]), root.q_sum[pv[0][0]] / root.n[pv[0][0]],
root.n[pv[0][0]], root.n_all, ' '.join([env.action2str(a, p) for a, p in pv])))

# Return probability distribution weighted by the number of simulations
root = self.nodes['|']
n = root.n + 1e-4
p = n / n.sum()
v = ((root.q_sum / (root.n.reshape(-1, 1) + 1e-4)) * p.reshape(-1, 1)).sum(0)
return np.log(p), v

def pv(self, env_):
# Return principal variation (action sequence which is considered as the best)
env = copy.deepcopy(env_)
pv_seq = []
while True:
path = list(zip(*pv_seq))[0]
key = '|' + ' '.join(map(str, path))
if key not in self.nodes or self.nodes[key].n.sum() == 0:
break
best_action = sorted([(a, self.nodes[key].n[a]) for a in env.legal_actions()], key=lambda x: -x[1])[0][0]
pv_seq.append((best_action, env.turn()))
env.play(best_action)
return pv_seq
Loading