From 8d9aa0daa2ed2640ba6ed6608ae3ee132715e6d6 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 13 Mar 2021 05:40:48 +0900 Subject: [PATCH 1/3] feature: return network instance from Environment.net() --- handyrl/envs/geister.py | 6 +++--- handyrl/envs/kaggle/hungry_geese.py | 6 +++--- handyrl/envs/tictactoe.py | 2 +- handyrl/train.py | 5 ++--- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py index a0306d7d..c0b9e29b 100755 --- a/handyrl/envs/geister.py +++ b/handyrl/envs/geister.py @@ -129,11 +129,11 @@ def forward(self, x): class GeisterNet(nn.Module): - def __init__(self): + def __init__(self, obs): super().__init__() layers, filters, p_filters = 3, 32, 8 - input_channels = 7 + 18 # board channels + scalar inputs + input_channels = obs['scalar'].shape[0] + obs['board'].shape[0] self.input_size = (input_channels, 6, 6) self.conv1 = nn.Conv2d(input_channels, filters, kernel_size=3, stride=1, padding=1, bias=False) @@ -538,7 +538,7 @@ def observation(self, player=None): return {'scalar': s, 'board': b} def net(self): - return GeisterNet + return GeisterNet(self.observation()) if __name__ == '__main__': diff --git a/handyrl/envs/kaggle/hungry_geese.py b/handyrl/envs/kaggle/hungry_geese.py index 76d5f090..f895e811 100644 --- a/handyrl/envs/kaggle/hungry_geese.py +++ b/handyrl/envs/kaggle/hungry_geese.py @@ -36,11 +36,11 @@ def forward(self, x): class GeeseNet(nn.Module): - def __init__(self): + def __init__(self, obs): super().__init__() layers, filters = 12, 32 - self.conv0 = TorusConv2d(17, filters, (3, 3), True) + self.conv0 = TorusConv2d(obs.shape[0], filters, (3, 3), True) self.blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)]) self.head_p = nn.Linear(filters, 4, bias=False) self.head_v = nn.Linear(filters * 2, 1, bias=False) @@ -206,7 +206,7 @@ def rule_based_action(self, player): return self.ACTION.index(action) def net(self): - return GeeseNet + return GeeseNet(self.observation()) def observation(self, player=None): if player is None: diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py index c6403b7f..8ad950d2 100755 --- a/handyrl/envs/tictactoe.py +++ b/handyrl/envs/tictactoe.py @@ -158,7 +158,7 @@ def players(self): return [0, 1] def net(self): - return SimpleConv2dModel + return SimpleConv2dModel() def observation(self, player=None): # input feature for neural nets diff --git a/handyrl/train.py b/handyrl/train.py index c29057c7..985c10fb 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -424,8 +424,7 @@ def __init__(self, args): # trained datum self.model_era = self.args['restart_epoch'] - self.model_class = self.env.net() - train_model = self.model_class() + train_model = self.env.net() if self.model_era == 0: self.model = RandomModel(self.env) else: @@ -584,7 +583,7 @@ def server(self): model = self.model if model_id != self.model_era: try: - model = self.model_class() + model = self.env.net() model.load_state_dict(torch.load(self.model_path(model_id)), strict=False) except: # return latest model if failed to load specified model From 21e5650835db432752c1532e769dc1217a5d76a5 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 30 Apr 2021 03:59:09 +0900 Subject: [PATCH 2/3] chore: stop using net() in loading old model from file --- handyrl/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/train.py b/handyrl/train.py index ac5dd375..757e72d6 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -604,7 +604,7 @@ def server(self): model = self.model if model_id != self.model_era: try: - model = self.env.net() + model = copy.deepcopy(self.model) model.load_state_dict(torch.load(self.model_path(model_id)), strict=False) except: # return latest model if failed to load specified model From 4b4527a0a66c4a84930a345cadc37c75c4c77956 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 30 Apr 2021 04:41:39 +0900 Subject: [PATCH 3/3] feature: enable to select optimizer in importing mode --- handyrl/train.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 757e72d6..4fea96be 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -18,7 +18,6 @@ import torch.nn as nn import torch.nn.functional as F import torch.distributions as dist -import torch.optim as optim import psutil from .environment import prepare_env, make_env @@ -310,16 +309,21 @@ def shutdown(self): class Trainer: - def __init__(self, args, model): + def __init__(self, args, model, optim): self.episodes = deque() self.args = args self.gpu = torch.cuda.device_count() self.model = model - self.default_lr = 3e-8 - self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps'] - self.params = list(self.model.parameters()) - lr = self.default_lr * self.data_cnt_ema - self.optimizer = optim.Adam(self.params, lr=lr, weight_decay=1e-5) if len(self.params) > 0 else None + if optim is not None: + self.optim_selected = True + self.optim = optim + else: + self.optim_selected = False + self.default_lr = 3e-8 + self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps'] + lr = self.default_lr * self.data_cnt_ema + params = list(self.model.parameters()) + self.optim = torch.optim.Adam(params, lr=lr, weight_decay=1e-5) if len(params) > 0 else None self.steps = 0 self.lock = threading.Lock() self.batcher = Batcher(self.args, self.episodes) @@ -355,7 +359,7 @@ def shutdown(self): self.batcher.shutdown() def train(self): - if self.optimizer is None: # non-parametric model + if self.optim is None: # non-parametric model print() return @@ -379,10 +383,10 @@ def train(self): losses, dcnt = compute_loss(batch, train_model, hidden, self.args) - self.optimizer.zero_grad() + self.optim.zero_grad() losses['total'].backward() - nn.utils.clip_grad_norm_(self.params, 4.0) - self.optimizer.step() + nn.utils.clip_grad_norm_(model.parameters(), 4.0) + self.optim.step() batch_cnt += 1 data_cnt += dcnt @@ -393,9 +397,10 @@ def train(self): print('loss = %s' % ' '.join([k + ':' + '%.3f' % (l / data_cnt) for k, l in loss_sum.items()])) - self.data_cnt_ema = self.data_cnt_ema * 0.8 + data_cnt / (1e-2 + batch_cnt) * 0.2 - for param_group in self.optimizer.param_groups: - param_group['lr'] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5) + if not self.optim_selected: + self.data_cnt_ema = self.data_cnt_ema * 0.8 + data_cnt / (1e-2 + batch_cnt) * 0.2 + for param_group in self.optim.param_groups: + param_group['lr'] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5) self.model.cpu() self.model.eval() return copy.deepcopy(self.model) @@ -415,7 +420,7 @@ def run(self): class Learner: - def __init__(self, args, env=None, net=None, remote=False): + def __init__(self, args, env=None, net=None, remote=False, optim=None): train_args = args['train_args'] env_args = args['env_args'] train_args['env'] = env_args @@ -452,7 +457,7 @@ def __init__(self, args, env=None, net=None, remote=False): self.worker = WorkerCluster(args) # thread connection - self.trainer = Trainer(args, train_model) + self.trainer = Trainer(args, train_model, optim) def shutdown(self): self.shutdown_flag = True