Skip to content
35 changes: 22 additions & 13 deletions handyrl/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import time
import multiprocessing as mp

import numpy as np

from .environment import prepare_env, make_env
from .connection import send_recv, accept_socket_connections, connect_socket_connection
from .agent import RandomAgent, RuleBasedAgent, Agent, EnsembleAgent, SoftAgent
Expand Down Expand Up @@ -78,10 +80,12 @@ def observe(self, player):

def exec_match(env, agents, critic=None, show=False, game_args={}):
''' match with shared game environment '''
total_rewards = {}
if env.reset(game_args):
return None
for agent in agents.values():
for p, agent in agents.items():
agent.reset(env, show=show)
total_rewards[p] = 0
while not env.terminal():
if show:
view(env)
Expand All @@ -99,19 +103,23 @@ def exec_match(env, agents, critic=None, show=False, game_args={}):
return None
if show:
view_transition(env)
for p, reward in env.reward().items():
total_rewards[p] += np.array(reward).reshape(-1)
outcome = env.outcome()
if show:
print('final outcome = %s' % outcome)
return outcome
return {'outcome': outcome, 'total_reward': total_rewards}


def exec_network_match(env, network_agents, critic=None, show=False, game_args={}):
''' match with divided game environment '''
total_rewards = {}
if env.reset(game_args):
return None
for p, agent in network_agents.items():
info = env.diff_info(p)
agent.update(info, True)
total_rewards[p] = 0
while not env.terminal():
if show:
view(env)
Expand All @@ -128,13 +136,15 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={
agent.observe(p)
if env.step(actions):
return None
for p, reward in env.reward().items():
total_rewards[p] += np.array(reward).reshape(-1)
for p, agent in network_agents.items():
info = env.diff_info(p)
agent.update(info, False)
outcome = env.outcome()
for p, agent in network_agents.items():
agent.outcome(outcome[p])
return outcome
return {'outcome': outcome, 'total_reward': total_rewards}


def build_agent(raw, env):
Expand Down Expand Up @@ -165,11 +175,11 @@ def execute(self, models, args):
else:
agents[p] = Agent(model)

outcome = exec_match(self.env, agents)
if outcome is None:
result = exec_match(self.env, agents)
if result is None:
print('None episode in evaluation!')
return None
return {'args': args, 'result': outcome, 'opponent': opponent}
return {'args': args, 'opponent': opponent, **result}


def wp_func(results):
Expand All @@ -191,10 +201,10 @@ def eval_process_mp_child(agents, critic, env_args, index, in_queue, out_queue,
print('*** Game %d ***' % g)
agent_map = {env.players()[p]: agents[ai] for p, ai in enumerate(agent_ids)}
if isinstance(list(agent_map.values())[0], NetworkAgent):
outcome = exec_network_match(env, agent_map, critic, show=show, game_args=game_args)
result = exec_network_match(env, agent_map, critic, show=show, game_args=game_args)
else:
outcome = exec_match(env, agent_map, critic, show=show, game_args=game_args)
out_queue.put((pat_idx, agent_ids, outcome))
result = exec_match(env, agent_map, critic, show=show, game_args=game_args)
out_queue.put((pat_idx, agent_ids, result))
out_queue.put(None)


Expand Down Expand Up @@ -241,8 +251,9 @@ def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_g
if ret is None:
finished_cnt += 1
continue
pat_idx, agent_ids, outcome = ret
if outcome is not None:
pat_idx, agent_ids, result = ret
if result is not None:
outcome = result['outcome']
for idx, p in enumerate(env.players()):
agent_id = agent_ids[idx]
oc = outcome[p]
Expand Down Expand Up @@ -303,7 +314,6 @@ def init_hidden(self):
hidden_inputs = [y for y in self.ort_session.get_inputs() if y.name.startswith('hidden')]
if len(hidden_inputs) == 0:
return None
import numpy as np
type_map = {
'tensor(float)': np.float32,
'tensor(int64)': np.int64,
Expand All @@ -319,7 +329,6 @@ def inference(self, x, hidden=None, batch_input=False):
ort_inputs = {}
ort_input_names = [y.name for y in self.ort_session.get_inputs()]

import numpy as np
def insert_input(y):
y = y if batch_input else np.expand_dims(y, 0)
ort_inputs[ort_input_names[len(ort_inputs)]] = y
Expand Down
6 changes: 5 additions & 1 deletion handyrl/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ def generate(self, models, args):
# episode generation
moments = []
hidden = {}
total_rewards = {}
for player in self.env.players():
hidden[player] = models[player].init_hidden()
total_rewards[player] = 0

err = self.env.reset()
if err:
Expand Down Expand Up @@ -66,6 +68,7 @@ def generate(self, models, args):
reward = self.env.reward()
for player in self.env.players():
moment['reward'][player] = reward.get(player, None)
total_rewards[player] += np.array(reward.get(player, 0)).reshape(-1)

moment['turn'] = turn_players
moments.append(moment)
Expand All @@ -76,12 +79,13 @@ def generate(self, models, args):
for player in self.env.players():
ret = 0
for i, m in reversed(list(enumerate(moments))):
ret = (m['reward'][player] or 0) + self.args['gamma'] * ret
ret = np.array(m['reward'][player] or 0) + np.array(self.args['gamma']) * ret
moments[i]['return'][player] = ret

episode = {
'args': args, 'steps': len(moments),
'outcome': self.env.outcome(),
'total_reward': total_rewards,
'moment': [
bz2.compress(pickle.dumps(moments[i:i+self.args['compress_steps']]))
for i in range(0, len(moments), self.args['compress_steps'])
Expand Down
2 changes: 1 addition & 1 deletion handyrl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def feed_results(self, results):
continue
for p in result['args']['player']:
model_id = result['args']['model_id'][p]
res = result['result'][p]
res = result['outcome'][p]
n, r, r2 = self.results.get(model_id, (0, 0, 0))
self.results[model_id] = n + 1, r + res, r2 + res ** 2

Expand Down