Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
28f960c
feature: burn-in steps
YuriCat Mar 12, 2021
a52e9dd
feature: set model.eval() in burn_in steps
YuriCat Mar 12, 2021
26ecce0
Merge branch 'develop' into feature/burn_in_steps
YuriCat Mar 17, 2021
e01b8ff
Merge branch 'develop' into feature/burn_in_steps
YuriCat May 7, 2021
2bfcef8
Merge develop
YuriCat Oct 18, 2021
4b71206
feature: add observers() method in environments
YuriCat Oct 20, 2021
e519e10
Merge branch 'develop' into feature/burn_in_steps
YuriCat Oct 20, 2021
23ff9f9
feature: steps = burn_in_steps + forward_steps
YuriCat Oct 20, 2021
6151ea7
feature: unnecessary hidden argument
YuriCat Nov 9, 2021
084bda4
Merge branch 'develop' into feature/burn_in_steps
YuriCat Nov 18, 2021
d83b33f
Merge branch 'develop' into feature/burn_in_steps
YuriCat Nov 26, 2021
3f9e0af
Merge branch 'feature/burn_in_steps' into feautre/burn_in_step_naive_…
YuriCat Nov 26, 2021
216734e
chore: check whether burn_in_steps > 0 before slicing batch and outputs
YuriCat Nov 26, 2021
858dd0b
Merge branch 'develop' into feature/observer_determined_in_env
YuriCat Jan 6, 2022
22ac16a
Merge branch 'develop' into feature/hidden_hidden
YuriCat Jan 8, 2022
126399a
feature: use load_model() function
YuriCat Jan 9, 2022
cfc1a39
feature: store selected probability only (IMPORTANT)
YuriCat Jan 9, 2022
e4a74c1
feature: add OnnxModel class to infer with ONNX models
YuriCat Jan 11, 2022
5900c30
fix: load ONNX model when selecting ONNX model path
YuriCat Jan 11, 2022
1cb0176
feature: support hidden state inputs in OnnxModel
YuriCat Jan 12, 2022
6bc68be
feature: link the 1st place solution in Hungry Geese competition
YuriCat Jan 12, 2022
f60a1ab
Merge branch 'develop' into feature/burn_in_steps
YuriCat Jan 12, 2022
b68c737
chore: set default burn_in_steps=0
YuriCat Jan 12, 2022
57b6a54
chore: update redundunt code when receiving hidden outputs
YuriCat Jan 12, 2022
74ff0f2
Merge pull request #235 from YuriCat/feature/update_use_cases_hg
ikki407 Jan 13, 2022
e8dcb8e
feature: true DRC reffered to PyTorch official LSTM implementation
YuriCat Jan 14, 2022
12cd56f
feature: add an explanation of DRC net and a link to the paper
YuriCat Jan 15, 2022
3ccdc94
chore: remove unused import
YuriCat Jan 15, 2022
ec5b317
Merge branch 'develop' into feature/behavior_policy_selected_only
YuriCat Jan 16, 2022
b697ea1
chore: update legal action check in Geister
YuriCat Jan 17, 2022
3b0e82b
Merge pull request #238 from YuriCat/feature/explain_drc
ikki407 Jan 17, 2022
60e602d
fix: return model even if using non-trainable models
YuriCat Jan 21, 2022
bcb4b37
feature: (idea) create observation mask from observation!=None
YuriCat Jan 21, 2022
99c0275
Merge branch 'develop' into feature/observer_determined_in_env
YuriCat Jan 21, 2022
c1d0235
fix: observers() explanation
YuriCat Jan 22, 2022
2446802
chore: update generation player loop
YuriCat Jan 22, 2022
1f3f3e2
feature: there is cases if we have no value outputs nor policy outputs
YuriCat Jan 22, 2022
5b840db
Merge branch 'develop' into feature/load_model
YuriCat Jan 22, 2022
4e34681
feature: load_model(model_path, env.net()) looks cool and easy
YuriCat Jan 22, 2022
46560e8
chore: fix typo chainge
YuriCat Jan 22, 2022
2aeb7e2
fix: eval-client mode
YuriCat Jan 22, 2022
951d363
feature: define only torch-based initialized hidden state
YuriCat Jan 23, 2022
2bf4b1a
Merge pull request #240 from YuriCat/fix/create_observation_mask
ikki407 Jan 24, 2022
5dd5302
Merge pull request #242 from YuriCat/fix/if_there_is_no_v_or_prob
ikki407 Jan 24, 2022
e4adeaf
Merge pull request #237 from YuriCat/feature/true_drc
ikki407 Jan 24, 2022
57a1967
Merge pull request #231 from YuriCat/feature/load_model
ikki407 Jan 24, 2022
2b361ff
Merge develop
YuriCat Jan 24, 2022
fffc42d
chore: change the position of debug output
YuriCat Jan 24, 2022
897ea32
feature: assing cumulative worker index
YuriCat Jan 24, 2022
8dbf79a
chore: remove ugly comment
YuriCat Jan 24, 2022
dc6410f
Merge pull request #239 from YuriCat/fix/trainer_for_non_trainable_model
ikki407 Jan 25, 2022
fd04681
Merge pull request #243 from YuriCat/feature/init_hidden_torch
ikki407 Jan 25, 2022
303bfdb
Merge pull request #236 from YuriCat/chore/202201
ikki407 Jan 25, 2022
a7f13b7
Merge pull request #216 from YuriCat/feature/observer_determined_in_env
ikki407 Jan 25, 2022
3db8376
chore: update description of observers()
YuriCat Jan 25, 2022
bf30631
Merge branch 'develop' into feature/burn_in_steps
YuriCat Jan 25, 2022
d2b0dc4
feature: add make_onnx_model.py
YuriCat Jan 25, 2022
154d283
feature: add scripts/aux_swa.py
YuriCat Jan 25, 2022
35406b1
Merge develop
YuriCat Jan 25, 2022
604e385
Merge branch 'develop' into feature/hidden_hidden
YuriCat Jan 25, 2022
acd77f7
feature: accept smaller number of arguments by checking argument count
YuriCat Jan 25, 2022
4fab026
chore: change variable names
YuriCat Jan 25, 2022
28bdc68
Merge pull request #250 from YuriCat/chore/observers_description
ikki407 Jan 26, 2022
261b4d6
Merge develop
YuriCat Jan 26, 2022
3b29cdf
Merge pull request #233 from YuriCat/feature/onnx_model
ikki407 Jan 26, 2022
2580165
Merge pull request #253 from YuriCat/feature/hidden_hidden_argument_c…
ikki407 Jan 26, 2022
190852a
feature: remove duplicate loading and unnecessary ModelWrapper
YuriCat Jan 26, 2022
09f74c7
fix: use turns()[0] instead of turn()
YuriCat Jan 26, 2022
76a3b00
Merge pull request #223 from YuriCat/feature/burn_in_steps
ikki407 Jan 26, 2022
050cc69
Merge pull request #251 from YuriCat/feature/add_onnx_script
ikki407 Jan 26, 2022
f94a0c2
Merge develop
YuriCat Jan 26, 2022
c804ce7
chore: fix typo
YuriCat Jan 26, 2022
0ad9707
Merge pull request #232 from YuriCat/feature/behavior_policy_selected…
ikki407 Jan 26, 2022
ddcb6f7
fix: set skip setting from argument
YuriCat Jan 26, 2022
b7397bc
Revert "feature: hidden hidden argument count"
ikki407 Jan 27, 2022
8d7e653
chore: style fix in scripts/make_onnx_model.py
YuriCat Jan 27, 2022
23b76f3
Merge pull request #257 from DeNA/revert-253-feature/hidden_hidden_ar…
ikki407 Jan 27, 2022
2a8f795
Revert "chore: fix typo"
YuriCat Jan 27, 2022
24bbfae
feature: add scripts/win_rate_plot.py
YuriCat Jan 28, 2022
48d8899
Merge branch 'develop' into chore/202201_2
YuriCat Jan 30, 2022
e023aa0
feature: use Tensor.dim() instead of len(Tensor.size())
YuriCat Jan 30, 2022
1f15ba9
chore: style fix in train.py
YuriCat Jan 30, 2022
8a28df7
feature: set batch_shape, and use flatten and unflatten
YuriCat Jan 30, 2022
927d7eb
feature: fix dimension description
YuriCat Jan 30, 2022
f1b57e3
chore: use flatten in feed forward computation
YuriCat Jan 30, 2022
037ae15
chore: use flatten in RNN computation
YuriCat Jan 30, 2022
62e910a
feature: save_model.pth.onnx -> saved_model.onnx
YuriCat Jan 31, 2022
163dc58
Merge pull request #263 from YuriCat/feature/simple_save_path_onnx
ikki407 Jan 31, 2022
233f11c
Merge pull request #246 from YuriCat/feature/cumulative_worker_id
ikki407 Jan 31, 2022
20d059c
chore: paper -> academic paper
YuriCat Jan 31, 2022
412df5d
Merge pull request #258 from YuriCat/feature/add_win_rate_plot
ikki407 Jan 31, 2022
0123390
Merge pull request #252 from YuriCat/feature/add_swa_script
ikki407 Jan 31, 2022
6740725
feature: add batch-shape description for feedforward net
YuriCat Jan 31, 2022
be49d4b
Merge branch 'develop' into feature/batch_shape_from_action
YuriCat Jan 31, 2022
4a7aee1
Merge pull request #260 from YuriCat/feature/batch_shape_from_action
ikki407 Jan 31, 2022
4ffcd73
Merge pull request #256 from YuriCat/chore/202201_2
ikki407 Jan 31, 2022
bfb56ba
fix: consider the case when training policies of two players
YuriCat Jan 31, 2022
c5dd1c4
chore: change position of comment
YuriCat Jan 31, 2022
698662a
Merge pull request #264 from YuriCat/fix/policy_sum
ikki407 Jan 31, 2022
e43df89
chore: since critic is not necessary, it should be None at default
YuriCat Feb 1, 2022
39cfe4c
chore: remove unused imports from scripts/win_rate_plot.py
YuriCat Feb 5, 2022
af94d9b
chore: remove unused import from connection.py
YuriCat Feb 5, 2022
35603bd
Merge pull request #266 from YuriCat/chore/202202
ikki407 Feb 8, 2022
9aa67db
Merge pull request #271 from YuriCat/chore/critic_default_none
ikki407 Feb 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,5 @@ NOTE: Default opponent AI is random agent implemented in `evaluation.py`. You ca

## Use Cases

* [Month 1 Winner in Hungry Geese (Kaggle)](https://www.kaggle.com/c/hungry-geese/discussion/222941)
* [The 5th solution in Google Research Football with Manchester City F.C. (Kaggle)](https://www.kaggle.com/c/google-football/discussion/203412)
* [The 1st place solution in Hungry Geese (Kaggle)](https://www.kaggle.com/c/hungry-geese/discussion/263279)
* [The 5th place solution in Google Research Football with Manchester City F.C. (Kaggle)](https://www.kaggle.com/c/google-football/discussion/203412)
1 change: 1 addition & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ train_args:
observation: False
gamma: 0.8
forward_steps: 16
burn_in_steps: 0 # for RNNs
compress_steps: 4
entropy_regularization: 1.0e-1
entropy_regularization_decay: 0.1
Expand Down
23 changes: 11 additions & 12 deletions handyrl/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,17 @@ def print_outputs(env, prob, v):
if hasattr(env, 'print_outputs'):
env.print_outputs(prob, v)
else:
print('v = %f' % v)
print('p = %s' % (prob * 1000).astype(int))
if v is not None:
print('v = %f' % v)
if prob is not None:
print('p = %s' % (prob * 1000).astype(int))


class Agent:
def __init__(self, model, observation=False, temperature=0.0):
def __init__(self, model, temperature=0.0):
# model might be a neural net, or some planning algorithm such as game tree search
self.model = model
self.hidden = None
self.observation = observation
self.temperature = temperature

def reset(self, env, show=False):
Expand Down Expand Up @@ -73,12 +74,10 @@ def action(self, env, player, show=False):
return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0]

def observe(self, env, player, show=False):
v = None
if self.observation:
outputs = self.plan(env.observation(player))
v = outputs.get('value', None)
if show:
print_outputs(env, None, v)
outputs = self.plan(env.observation(player))
v = outputs.get('value', None)
if show:
print_outputs(env, None, v)
return v if v is not None else [0.0]


Expand All @@ -101,5 +100,5 @@ def plan(self, obs):


class SoftAgent(Agent):
def __init__(self, model, observation=False):
super().__init__(model, observation=observation, temperature=1.0)
def __init__(self, model):
super().__init__(model, temperature=1.0)
1 change: 0 additions & 1 deletion handyrl/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under The MIT License [see LICENSE for details]

import io
import time
import struct
import socket
import pickle
Expand Down
7 changes: 7 additions & 0 deletions handyrl/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def turn(self):
def turns(self):
return [self.turn()]

#
# Should be defined if there are other players besides the turn player
# who should observe the environment (mainly with RNNs)
#
def observers(self):
return []

#
# Should be defined in all games
#
Expand Down
25 changes: 13 additions & 12 deletions handyrl/envs/geister.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,10 @@ def __init__(self, input_dim, hidden_dim, kernel_size, bias):
)

def init_hidden(self, input_size, batch_size):
if batch_size is None: # for inference
return tuple([
np.zeros((self.hidden_dim, *input_size), dtype=np.float32),
np.zeros((self.hidden_dim, *input_size), dtype=np.float32)
])
else: # for training
return tuple([
torch.zeros(*batch_size, self.hidden_dim, *input_size),
torch.zeros(*batch_size, self.hidden_dim, *input_size)
])
return tuple([
torch.zeros(*batch_size, self.hidden_dim, *input_size),
torch.zeros(*batch_size, self.hidden_dim, *input_size)
])

def forward(self, input_tensor, cur_state):
h_cur, c_cur = cur_state
Expand All @@ -63,6 +57,11 @@ def forward(self, input_tensor, cur_state):
return h_next, c_next


# Deep Repeated Conv-LSTM (https://arxiv.org/abs/1901.03559)
# increases expressive power with fewer parameters
# by repeatedly computing multi-layer convolutional LSTM.
# When num_repeats=1, it is simply a multi-layer Conv-LSTM.

class DRC(nn.Module):
def __init__(self, num_layers, input_dim, hidden_dim, kernel_size=3, bias=True):
super().__init__()
Expand Down Expand Up @@ -93,7 +92,7 @@ def forward(self, x, hidden, num_repeats):
hs, cs = hidden
for _ in range(num_repeats):
for i, block in enumerate(self.blocks):
hs[i], cs[i] = block(x, (hs[i], cs[i]))
hs[i], cs[i] = block(hs[i - 1] if i > 0 else x, (hs[i], cs[i]))

return hs[-1], (hs, cs)

Expand Down Expand Up @@ -145,7 +144,7 @@ def __init__(self):
self.head_v = ScalarHead((filters * 2, 6, 6), 1, 1)
self.head_r = ScalarHead((filters * 2, 6, 6), 1, 1)

def init_hidden(self, batch_size=None):
def init_hidden(self, batch_size=[]):
return self.body.init_hidden(self.input_size[1:], batch_size)

def forward(self, x, hidden):
Expand Down Expand Up @@ -448,6 +447,8 @@ def legal(self, action):
if self.turn_count < 0:
layout = action - 4 * 6 * 6
return 0 <= layout < 70
elif not 0 <= action < 4 * 6 * 6:
return False

pos_from = self.action2from(action, self.color)
pos_to = self.action2to(action, self.color)
Expand Down
91 changes: 81 additions & 10 deletions handyrl/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def observe(self, player):
return send_recv(self.conn, ('observe', [player]))


def exec_match(env, agents, critic, show=False, game_args={}):
def exec_match(env, agents, critic=None, show=False, game_args={}):
''' match with shared game environment '''
if env.reset(game_args):
return None
Expand All @@ -88,11 +88,12 @@ def exec_match(env, agents, critic, show=False, game_args={}):
if show and critic is not None:
print('cv = ', critic.observe(env, None, show=False)[0])
turn_players = env.turns()
observers = env.observers()
actions = {}
for p, agent in agents.items():
if p in turn_players:
actions[p] = agent.action(env, p, show=show)
else:
elif p in observers:
agent.observe(env, p, show=show)
if env.step(actions):
return None
Expand All @@ -104,7 +105,7 @@ def exec_match(env, agents, critic, show=False, game_args={}):
return outcome


def exec_network_match(env, network_agents, critic, show=False, game_args={}):
def exec_network_match(env, network_agents, critic=None, show=False, game_args={}):
''' match with divided game environment '''
if env.reset(game_args):
return None
Expand All @@ -117,12 +118,13 @@ def exec_network_match(env, network_agents, critic, show=False, game_args={}):
if show and critic is not None:
print('cv = ', critic.observe(env, None, show=False)[0])
turn_players = env.turns()
observers = env.observers()
actions = {}
for p, agent in network_agents.items():
if p in turn_players:
action = agent.action(p)
actions[p] = env.str2action(action, p)
else:
elif p in observers:
agent.observe(p)
if env.step(actions):
return None
Expand Down Expand Up @@ -161,9 +163,9 @@ def execute(self, models, args):
if model is None:
agents[p] = build_agent(opponent, self.env)
else:
agents[p] = Agent(model, self.args['observation'])
agents[p] = Agent(model)

outcome = exec_match(self.env, agents, None)
outcome = exec_match(self.env, agents)
if outcome is None:
print('None episode in evaluation!')
return None
Expand Down Expand Up @@ -277,10 +279,78 @@ def network_match_acception(n, env_args, num_agents, port):
return agents_list


def get_model(env, model_path):
class OnnxModel:
def __init__(self, model_path):
self.model_path = model_path
self.ort_session = None

def _open_session(self):
import os
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OMP_WAIT_POLICY'] = 'PASSIVE'

import onnxruntime
opts = onnxruntime.SessionOptions()
opts.intra_op_num_threads = 1
opts.inter_op_num_threads = 1
opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL

self.ort_session = onnxruntime.InferenceSession(self.model_path, sess_options=opts)

def init_hidden(self):
if self.ort_session is None:
self._open_session()
hidden_inputs = [y for y in self.ort_session.get_inputs() if y.name.startswith('hidden')]
if len(hidden_inputs) == 0:
return None
import numpy as np
type_map = {
'tensor(float)': np.float32,
'tensor(int64)': np.int64,
}
hidden_tensors = [np.zeros(y.shape[1:], dtype=type_map[y.type]) for y in hidden_inputs]
return hidden_tensors

def inference(self, x, hidden=None, batch_input=False):
# numpy array -> numpy array
if self.ort_session is None:
self._open_session()

ort_inputs = {}
ort_input_names = [y.name for y in self.ort_session.get_inputs()]

import numpy as np
def insert_input(y):
y = y if batch_input else np.expand_dims(y, 0)
ort_inputs[ort_input_names[len(ort_inputs)]] = y
from .util import map_r
map_r(x, lambda y: insert_input(y))
if hidden is not None:
map_r(hidden, lambda y: insert_input(y))
ort_outputs = self.ort_session.run(None, ort_inputs)
if not batch_input:
ort_outputs = [o.squeeze(0) for o in ort_outputs]

ort_output_names = [y.name for y in self.ort_session.get_outputs()]
outputs = {name: ort_outputs[i] for i, name in enumerate(ort_output_names)}

hidden_outputs = []
for k in list(outputs.keys()):
if k.startswith('hidden'):
hidden_outputs.append(outputs.pop(k))
if len(hidden_outputs) == 0:
hidden_outputs = None

outputs = {**outputs, 'hidden': hidden_outputs}
return outputs


def load_model(model_path, model):
if model_path.endswith('.onnx'):
model = OnnxModel(model_path)
return model
import torch
from .model import ModelWrapper
model = env.net()
model.load_state_dict(torch.load(model_path))
model.eval()
return ModelWrapper(model)
Expand All @@ -290,7 +360,7 @@ def client_mp_child(env_args, model_path, conn):
env = make_env(env_args)
agent = build_agent(model_path, env)
if agent is None:
model = get_model(env, model_path)
model = load_model(model_path, env.net())
agent = Agent(model)
NetworkAgentClient(agent, env, conn).run()

Expand All @@ -306,7 +376,8 @@ def eval_main(args, argv):

agent1 = build_agent(model_path, env)
if agent1 is None:
agent1 = Agent(get_model(env, model_path))
model = load_model(model_path, env.net())
agent1 = Agent(model)
critic = None

print('%d process, %d games' % (num_process, num_games))
Expand Down
47 changes: 25 additions & 22 deletions handyrl/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,35 @@ def generate(self, models, args):
return None

while not self.env.terminal():
moment_keys = ['observation', 'policy', 'action_mask', 'action', 'value', 'reward', 'return']
moment_keys = ['observation', 'selected_prob', 'action_mask', 'action', 'value', 'reward', 'return']
moment = {key: {p: None for p in self.env.players()} for key in moment_keys}

turn_players = self.env.turns()
observers = self.env.observers()
for player in self.env.players():
if player in turn_players or self.args['observation']:
obs = self.env.observation(player)
model = models[player]
outputs = model.inference(obs, hidden[player])
hidden[player] = outputs.get('hidden', None)
v = outputs.get('value', None)

moment['observation'][player] = obs
moment['value'][player] = v

if player in turn_players:
p_ = outputs['policy']
legal_actions = self.env.legal_actions(player)
action_mask = np.ones_like(p_) * 1e32
action_mask[legal_actions] = 0
p = p_ - action_mask
action = random.choices(legal_actions, weights=softmax(p[legal_actions]))[0]

moment['policy'][player] = p
moment['action_mask'][player] = action_mask
moment['action'][player] = action
if player not in turn_players + observers:
continue

obs = self.env.observation(player)
model = models[player]
outputs = model.inference(obs, hidden[player])
hidden[player] = outputs.get('hidden', None)
v = outputs.get('value', None)

moment['observation'][player] = obs
moment['value'][player] = v

if player in turn_players:
p_ = outputs['policy']
legal_actions = self.env.legal_actions(player)
action_mask = np.ones_like(p_) * 1e32
action_mask[legal_actions] = 0
p = softmax(p_ - action_mask)
action = random.choices(legal_actions, weights=p[legal_actions])[0]

moment['selected_prob'][player] = p[action]
moment['action_mask'][player] = action_mask
moment['action'][player] = action

err = self.env.step(moment['action'])
if err:
Expand Down
6 changes: 5 additions & 1 deletion handyrl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def __init__(self, model):

def init_hidden(self, batch_size=None):
if hasattr(self.model, 'init_hidden'):
return self.model.init_hidden(batch_size)
if batch_size is None: # for inference
hidden = self.model.init_hidden([])
return map_r(hidden, lambda h: h.detach().numpy() if isinstance(h, torch.Tensor) else h)
else: # for training
return self.model.init_hidden(batch_size)
return None

def forward(self, *args, **kwargs):
Expand Down
Loading