from acme.agents.jax import builders
from acme.agents.jax.ars import config as ars_config
from acme.agents.jax.ars import learning
from acme.agents.jax.ars import networks as ars_networks
from acme.jax import networks as networks_lib
from acme.jax import running_statistics
from acme.jax import utils


def make_policy(
    self,
    networks: networks_lib.FeedForwardNetwork,
    environment_spec: specs.EnvironmentSpec,
    evaluation: bool = False,
) -> Tuple[str, networks_lib.FeedForwardNetwork]:
  """Builds the policy by delegating to `ars_networks.make_policy_network`.

  Args:
    networks: Network object handed unchanged to the policy factory.
    environment_spec: Unused; it is discarded before the policy is built.
    evaluation: Forwarded to the factory as its `eval_mode` keyword.

  Returns:
    The value produced by `ars_networks.make_policy_network`; annotated here
    as a `(str, FeedForwardNetwork)` pair.
  """
  # The spec takes no part in policy construction, so drop it explicitly.
  del environment_spec
  policy = ars_networks.make_policy_network(networks, eval_mode=evaluation)
  return policy
"""Tests for the SAC agent."""

from acme import specs
from acme.agents.jax import sac
from acme.testing import fakes
import jax
import optax

from absl.testing import absltest
from absl.testing import parameterized


class SACTest(parameterized.TestCase):
  """Smoke test: a SAC learner takes a few steps on fake transition data."""

  @parameterized.named_parameters(
      ('adaptive_entropy', None, 0.),
      ('fixed_entropy', 0.2, 0.))
  def test_train(self, entropy_coefficient, target_entropy):
    """Builds a SAC learner on a fake environment and steps it five times.

    Args:
      entropy_coefficient: Passed to `sac.SACLearner` unchanged (None or 0.2
        in the parameterized cases).
      target_entropy: Passed to `sac.SACLearner` unchanged (0. in both cases).
    """
    rng_seed = 0
    total_steps = 5
    minibatch = 64

    # Bounded fake environment with 6-dimensional actions and short episodes.
    fake_env = fakes.ContinuousEnvironment(
        episode_length=10, bounded=True, action_dim=6)
    env_spec = specs.make_environment_spec(fake_env)

    sac_networks = sac.make_networks(env_spec, hidden_layer_sizes=(8, 8))
    make_iterator = fakes.transition_iterator(fake_env)
    prng_key = jax.random.PRNGKey(rng_seed)

    # Collect the constructor arguments first; values are evaluated in the
    # same order as they would be in a direct keyword call.
    learner_kwargs = dict(
        networks=sac_networks,
        rng=prng_key,
        iterator=make_iterator(minibatch),
        policy_optimizer=optax.adam(3e-4),
        q_optimizer=optax.adam(3e-4),
        entropy_coefficient=entropy_coefficient,
        target_entropy=target_entropy,
        num_sgd_steps_per_step=1)
    sac_learner = sac.SACLearner(**learner_kwargs)

    # No assertions on the outcome: the test passes if stepping does not raise.
    for _ in range(total_steps):
      sac_learner.step()


if __name__ == '__main__':
  absltest.main()
"""Tests for the TD3 agent."""

from acme import specs
from acme.agents.jax import td3
from acme.testing import fakes
import jax
import optax

from absl.testing import absltest
from absl.testing import parameterized


class TD3Test(parameterized.TestCase):
  """Smoke test: a TD3 learner takes a few steps on fake transition data."""

  @parameterized.named_parameters(
      ('standard', None),
      ('with_bc', 2.5))
  def test_train(self, bc_alpha):
    """Builds a TD3 learner on a fake environment and steps it five times.

    Args:
      bc_alpha: Passed to `td3.TD3Learner` unchanged (None or 2.5 in the
        parameterized cases).
    """
    rng_seed = 0
    total_steps = 5
    minibatch = 64

    # Bounded fake environment with 6-dimensional actions and short episodes.
    fake_env = fakes.ContinuousEnvironment(
        episode_length=10, bounded=True, action_dim=6)
    env_spec = specs.make_environment_spec(fake_env)

    td3_networks = td3.make_networks(env_spec, hidden_layer_sizes=(8, 8))
    make_iterator = fakes.transition_iterator(fake_env)
    prng_key = jax.random.PRNGKey(rng_seed)

    # Collect the constructor arguments first; values are evaluated in the
    # same order as they would be in a direct keyword call.
    learner_kwargs = dict(
        networks=td3_networks,
        random_key=prng_key,
        discount=0.99,
        iterator=make_iterator(minibatch),
        policy_optimizer=optax.adam(3e-4),
        critic_optimizer=optax.adam(3e-4),
        twin_critic_optimizer=optax.adam(3e-4),
        bc_alpha=bc_alpha,
        num_sgd_steps_per_step=1)
    td3_learner = td3.TD3Learner(**learner_kwargs)

    # No assertions on the outcome: the test passes if stepping does not raise.
    for _ in range(total_steps):
      td3_learner.step()


if __name__ == '__main__':
  absltest.main()