This repo features a high-speed JAX implementation of the Proximal Policy Optimisation (PPO) algorithm. The algorithm is provided as a single-file implementation so that all the design choices are clear.
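For orientation, the core of PPO is the clipped surrogate objective. The snippet below is a minimal JAX sketch of that loss, not the code in `ppx/systems/ppo.py`. The function name `clipped_surrogate_loss` and the default `clip_eps=0.2` are illustrative assumptions.

```python
import jax.numpy as jnp


def clipped_surrogate_loss(log_prob_new, log_prob_old, advantages, clip_eps=0.2):
    """Negated PPO clipped surrogate objective (hypothetical helper, a sketch).

    All inputs are arrays of the same shape, one entry per sampled transition.
    """
    # Probability ratio r = pi_new(a|s) / pi_old(a|s), computed in log space.
    ratio = jnp.exp(log_prob_new - log_prob_old)
    # Unclipped and clipped surrogate terms.
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO maximises the elementwise minimum, so the loss is its negated mean.
    return -jnp.mean(jnp.minimum(unclipped, clipped))


# Ratios of 1.6 and 0.4 are clipped to 1.2 and 0.8 before taking the minimum.
loss = clipped_surrogate_loss(
    jnp.log(jnp.array([0.8, 0.2])),
    jnp.log(jnp.array([0.5, 0.5])),
    jnp.array([1.0, -1.0]),
)
```

Taking the minimum of the clipped and unclipped terms makes the objective pessimistic: the policy gains nothing from moving the ratio further outside `[1 - clip_eps, 1 + clip_eps]`.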
*[Figure: two panels, each showing training reward]*
To run the algorithm, execute the Python script from the repository's root directory. A custom config file can be given as follows:

```bash
python3 ppx/systems/ppo.py --config-name=ppo_MinAtar.yaml
```

Since Hydra is used for managing configurations, override parameters can be passed as arguments to this command. The default parameters can be changed in the relevant config file.
The `notebooks/` directory contains simple `.ipynb` notebooks that provide basic plotting functions.
We recommend managing dependencies in a virtual environment, which can be created and activated with the following commands:
```bash
python3.9 -m venv venv
source venv/bin/activate
```
Install dependencies using the requirements.txt file:
```bash
pip install -r requirements.txt
```
The codebase is installed as an editable pip package with the following command:

```bash
pip install -e .
```
Note that JAX must be installed separately for the specific device you are using. For straightforward CPU usage, run:

```bash
pip install -U "jax[cpu]"
```
To use JAX on your accelerators (GPU or TPU), see the installation instructions in the JAX documentation.
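After installing JAX, a quick check like the one below confirms which backend and devices JAX can see and that a jitted computation runs. It uses only standard JAX calls; the helper `add_one` is purely illustrative.

```python
import jax
import jax.numpy as jnp

# Platform JAX selected by default, e.g. "cpu", "gpu" or "tpu".
print("Default backend:", jax.default_backend())
# All devices visible to this process.
print("Devices:", jax.devices())


# A small jitted computation to confirm the installation works end to end.
@jax.jit
def add_one(x):
    return x + 1.0


result = add_one(jnp.arange(3.0))
print(result)
```

If an accelerator build of JAX was installed but the backend still reports `cpu`, the accelerator-specific install most likely did not take effect.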
PPO-EWMA: a batch-size-invariant algorithm that uses exponentially weighted moving averages (EWMAs) to remove the dependence on the batch-size hyperparameter.
- The next steps are to test the learning-rate adjustment and the advantage-normalisation adjustment.
- The figure below shows the variance in performance across batch sizes: the left panel shows standard PPO and the right panel shows the current results using PPO-EWMA. Performance is slightly higher with EWMA, but the variance is greater.
*[Figure: results across batch sizes; left panel PPO, right panel PPO-EWMA]*
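In the usual formulation of PPO-EWMA, an exponentially weighted moving average of the policy parameters serves as the proximal policy in the clipped objective instead of the old policy. A minimal sketch of the averaging step over a JAX parameter pytree follows. The function name `ewma_update`, the toy parameter shapes and `beta=0.9` are illustrative assumptions, not necessarily how this repository implements it.

```python
import jax
import jax.numpy as jnp


def ewma_update(ewma_params, params, beta):
    """One EWMA step over matching pytrees: ewma <- beta * ewma + (1 - beta) * params."""
    return jax.tree_util.tree_map(
        lambda e, p: beta * e + (1.0 - beta) * p, ewma_params, params
    )


# Toy parameter pytree standing in for network weights (hypothetical shapes).
params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}

# In practice the average would start as a copy of the initial parameters;
# zeros are used here only to make the decay visible.
ewma_params = jax.tree_util.tree_map(jnp.zeros_like, params)

# After n steps towards fixed params, each entry equals (1 - beta**n) * target.
for _ in range(10):
    ewma_params = ewma_update(ewma_params, params, beta=0.9)
```

A larger `beta` gives the average a longer memory, so the proximal policy changes more slowly than the online parameters.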
- Add an environment wrapper that supports the Jumanji-style step method, which returns a `state` and a `TimeStep`.
- Add a KL-divergence PPO algorithm.
- Add tests with different learning rates, including the effect of learning-rate annealing.
The code is based on the format of Mava and is inspired by PureJaxRL and CleanRL.



