diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py index 71114e27..b8dc0e75 100755 --- a/handyrl/envs/geister.py +++ b/handyrl/envs/geister.py @@ -134,11 +134,11 @@ def forward(self, x): class GeisterNet(nn.Module): - def __init__(self): + def __init__(self, obs): super().__init__() layers, filters, p_filters = 3, 32, 8 - input_channels = 7 + 18 # board channels + scalar inputs + input_channels = obs['scalar'].shape[0] + obs['board'].shape[0] self.input_size = (input_channels, 6, 6) self.conv1 = nn.Conv2d(input_channels, filters, kernel_size=3, stride=1, padding=1, bias=False) @@ -539,7 +539,7 @@ def observation(self, player=None): return {'scalar': s, 'board': b} def net(self): - return GeisterNet() + return GeisterNet(self.observation()) if __name__ == '__main__': diff --git a/handyrl/envs/kaggle/hungry_geese.py b/handyrl/envs/kaggle/hungry_geese.py index 0a663adc..f7cd973a 100644 --- a/handyrl/envs/kaggle/hungry_geese.py +++ b/handyrl/envs/kaggle/hungry_geese.py @@ -36,11 +36,11 @@ def forward(self, x): class GeeseNet(nn.Module): - def __init__(self): + def __init__(self, obs): super().__init__() layers, filters = 12, 32 - self.conv0 = TorusConv2d(17, filters, (3, 3), True) + self.conv0 = TorusConv2d(obs.shape[0], filters, (3, 3), True) self.blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)]) self.head_p = nn.Linear(filters, 4, bias=False) self.head_v = nn.Linear(filters * 2, 1, bias=False) @@ -197,7 +197,7 @@ def rule_based_action(self, player): return self.ACTION.index(action) def net(self): - return GeeseNet() + return GeeseNet(self.observation()) def observation(self, player=None): if player is None: diff --git a/handyrl/train.py b/handyrl/train.py index b933ebaa..00fcea0b 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -18,7 +18,6 @@ import torch.nn as nn import torch.nn.functional as F import torch.distributions as dist -import torch.optim as optim import psutil from .environment import prepare_env, make_env @@ -298,16 +297,21 @@ def shutdown(self): class Trainer: - def __init__(self, args, model): + def __init__(self, args, model, optim): self.episodes = deque() self.args = args self.gpu = torch.cuda.device_count() self.model = model - self.default_lr = 3e-8 - self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps'] - self.params = list(self.model.parameters()) - lr = self.default_lr * self.data_cnt_ema - self.optimizer = optim.Adam(self.params, lr=lr, weight_decay=1e-5) if len(self.params) > 0 else None + if optim is not None: + self.optim_selected = True + self.optim = optim + else: + self.optim_selected = False + self.default_lr = 3e-8 + self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps'] + lr = self.default_lr * self.data_cnt_ema + params = list(self.model.parameters()) + self.optim = torch.optim.Adam(params, lr=lr, weight_decay=1e-5) if len(params) > 0 else None self.steps = 0 self.lock = threading.Lock() self.batcher = Batcher(self.args, self.episodes) @@ -348,7 +352,7 @@ def shutdown(self): self.batcher.shutdown() def train(self): - if self.optimizer is None: # non-parametric model + if self.optim is None: # non-parametric model print() return @@ -368,10 +372,10 @@ def train(self): losses, dcnt = compute_loss(batch, self.trained_model, hidden, self.args) - self.optimizer.zero_grad() + self.optim.zero_grad() losses['total'].backward() - nn.utils.clip_grad_norm_(self.params, 4.0) - self.optimizer.step() + nn.utils.clip_grad_norm_(self.model.parameters(), 4.0) + self.optim.step() batch_cnt += 1 data_cnt += dcnt @@ -382,9 +386,10 @@ def train(self): print('loss = %s' % ' '.join([k + ':' + '%.3f' % (l / data_cnt) for k, l in loss_sum.items()])) - self.data_cnt_ema = self.data_cnt_ema * 0.8 + data_cnt / (1e-2 + batch_cnt) * 0.2 - for param_group in self.optimizer.param_groups: - param_group['lr'] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5) + if not self.optim_selected: + self.data_cnt_ema = self.data_cnt_ema * 0.8 + data_cnt / (1e-2 + batch_cnt) * 0.2 + for param_group in self.optim.param_groups: + param_group['lr'] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5) self.model.cpu() self.model.eval() return copy.deepcopy(self.model) @@ -404,7 +409,7 @@ def run(self): class Learner: - def __init__(self, args, net=None, remote=False): + def __init__(self, args, net=None, remote=False, optim=None): train_args = args['train_args'] env_args = args['env_args'] train_args['env'] = env_args @@ -438,7 +443,7 @@ def __init__(self, args, net=None, remote=False): self.worker = WorkerServer(args) if remote else WorkerCluster(args) # thread connection - self.trainer = Trainer(args, self.model) + self.trainer = Trainer(args, self.model, optim) def shutdown(self): self.shutdown_flag = True