-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
97 lines (71 loc) · 2.66 KB
/
Copy pathutils.py
File metadata and controls
97 lines (71 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Useful functions."""
import importlib
from functools import partial
from typing import Tuple, TypeVar
import pickle
import os
import chex
import jax
import jax.numpy as jnp
import pax
from games.env import Enviroment as E
T = TypeVar("T")
@pax.pure
def batched_policy(agent, states):
    """Run the agent's policy on a whole batch of states.

    The agent is invoked as ``agent(states, batched=True)``.

    Returns:
        A pair ``(agent, outputs)``: the agent itself (handed back so the
        caller keeps whatever the call updated on it) followed by the
        value the agent returned for the batch.
    """
    outputs = agent(states, batched=True)
    return agent, outputs
def replicate(value: T, repeat: int) -> T:
    """Stack ``repeat`` copies of every leaf of ``value`` on a new leading axis.

    ``value`` may be any pytree; each leaf ``x`` is replaced by
    ``jnp.stack`` of ``repeat`` copies of ``x``, so the result has the same
    tree structure with an extra first dimension of size ``repeat``.
    """

    def _stack_copies(leaf):
        copies = [leaf for _ in range(repeat)]
        return jnp.stack(copies)

    return jax.tree_util.tree_map(_stack_copies, value)
@pax.pure
def reset_env(env: E) -> E:
    """Reset ``env`` and hand it back.

    ``env.reset()`` is called only for its effect on ``env``; whatever it
    returns is ignored and the environment object itself is returned.
    """
    env.reset()
    return env
@jax.jit
def env_step(env: E, action: chex.Array) -> Tuple[E, chex.Array]:
    """Advance the environment by one step (jit-compiled).

    Calls ``env.step(action)``, which must yield exactly two values.

    Returns:
        A pair ``(next_env, reward)`` as produced by ``env.step``.
    """
    next_env, reward = env.step(action)
    return next_env, reward
def import_class(path: str) -> type:
    """Import a class (or any module attribute) from its dotted path.

    For example:
        >> Game = import_class("connect_two_game.Connect2Game")
    Game is the Connect2Game class from `connect_two_game.py`.

    Args:
        path: Dotted path of the form ``"<module path>.<attribute name>"``.
            Everything before the last dot is imported as a module; the
            part after it is looked up on that module.

    Returns:
        The attribute named by the last component of ``path``.

    Raises:
        ValueError: If ``path`` has no module part (no dot, or it starts
            with its only dot).
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_name, _, class_name = path.rpartition(".")
    if not module_name:
        # Without this check importlib fails with an opaque
        # "Empty module name" error that does not mention the bad path.
        raise ValueError(
            f"expected a dotted path like 'module.ClassName', got {path!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def select_tree(pred: jnp.ndarray, a, b):
    """Pick one of two pytrees, leaf by leaf, according to a boolean scalar.

    Every pair of corresponding leaves from ``a`` and ``b`` is passed to
    ``jax.lax.select`` with ``pred`` as the predicate, so the two trees
    must have matching structure.

    ``pred`` must be a 0-dimensional array of dtype ``bool``; anything else
    fails the assertion below.
    """
    assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar"

    def _choose(on_true, on_false):
        return jax.lax.select(pred, on_true, on_false)

    return jax.tree_util.tree_map(_choose, a, b)
def save_model(agent, save_path: str, iteration: int):
    """Save only the agent's state dict.

    The state returned by ``agent.state_dict()`` is passed through
    ``jax.device_get`` and pickled to
    ``<save_path>/model_iteration_<iteration>.pkl``. ``save_path`` is
    created if it does not exist yet.

    Args:
        agent: Object exposing a ``state_dict()`` method.
        save_path: Directory that receives the checkpoint file.
        iteration: Iteration number embedded in the file name.
    """
    # Fetch the state *before* touching the filesystem: if state_dict() or
    # device_get fails, no empty or truncated checkpoint is left behind.
    state_dict = jax.device_get(agent.state_dict())

    os.makedirs(save_path, exist_ok=True)
    model_file = os.path.join(save_path, f"model_iteration_{iteration}.pkl")
    with open(model_file, "wb") as f:
        pickle.dump(state_dict, f)
    print(f"Model saved at {model_file}")
def load_model(game_class: str, agent_class: str, load_path: str, iteration: int):
    """Rebuild an agent from a state dict saved on disk.

    Reads ``<load_path>/model_iteration_<iteration>.pkl``, instantiates the
    game named by ``game_class`` to obtain the observation shape and number
    of actions, builds a fresh agent of ``agent_class`` from those, and
    loads the unpickled state dict into it.

    Args:
        game_class: Dotted path of the game class, resolved by ``import_class``.
        agent_class: Dotted path of the agent class, resolved by ``import_class``.
        load_path: Directory containing the checkpoint file.
        iteration: Iteration number embedded in the file name.

    Returns:
        Whatever ``load_state_dict`` returns for the freshly built agent.
    """
    checkpoint = os.path.join(load_path, f"model_iteration_{iteration}.pkl")
    # NOTE(review): unpickling can execute arbitrary code; only load
    # checkpoints that come from a trusted source.
    with open(checkpoint, "rb") as handle:
        saved_state = pickle.load(handle)

    game_cls = import_class(game_class)
    env = game_cls()
    agent_cls = import_class(agent_class)
    fresh_agent = agent_cls(
        input_dims=env.observation().shape,
        num_actions=env.num_actions(),
    )
    restored = fresh_agent.load_state_dict(saved_state)
    print(f"Model loaded from {checkpoint}")
    return restored
def make_directories(base_path, variant):
    """Ensure the storage directory for one model variant exists.

    The directory is ``base_path`` joined with ``variant``; it is created
    together with any missing parents, and an already existing directory
    is left as is.

    Returns:
        The joined directory path.
    """
    variant_dir = os.path.join(base_path, variant)
    os.makedirs(variant_dir, exist_ok=True)
    return variant_dir