Sac n jax
Single-file SAC-N implementation on jax with flax and equinox. 10x faster than pytorch
The project is written primarily in Python, distributed under the MIT License, and was first published in 2022. Key topics include: d4rl, equinox, flax, jax, offline-reinforcement-learning.
SAC with Q-Ensemble for Offline RL
Single-file SAC-N [1] implementation on jax with both flax and equinox. 10x faster than SAC-N on pytorch from CORL [2].
And still easy to use and understand! To run:
```bash
python sac_n_jax_flax.py --env_name="halfcheetah-medium-v2" --num_critics=10 --batch_size=256
python sac_n_jax_eqx.py --env_name="halfcheetah-medium-v2" --num_critics=10 --batch_size=256
```
Optionally, you can pass --config_path with the path to a YAML config file; for details, see the pyrallis docs.
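SAC-N [1] differs from vanilla SAC mainly in the critic: it trains N Q-functions and builds the TD target from the minimum over all N of them. The sketch below shows one common way to express this in JAX, stacking the ensemble parameters along a leading axis and evaluating every critic in a single call with jax.vmap. It is a minimal illustration under stated assumptions, not this repository's code: the tiny MLP critic and all function names are hypothetical and written in plain jax.numpy (not the flax or equinox modules used here), the next actions and their log-probabilities are assumed to come from the actor, and alpha is fixed for brevity.

```python
import jax
import jax.numpy as jnp


def init_critic(key, obs_dim, act_dim, hidden=32):
    """Hypothetical one-hidden-layer MLP critic Q(s, a); illustration only."""
    k1, k2 = jax.random.split(key)
    in_dim = obs_dim + act_dim
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) / jnp.sqrt(hidden),
        "b2": jnp.zeros(1),
    }


def critic_apply(params, obs, act):
    """Q-values for a batch: obs [B, obs_dim], act [B, act_dim] -> [B]."""
    x = jnp.concatenate([obs, act], axis=-1)
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"]).squeeze(-1)


def init_ensemble(key, num_critics, obs_dim, act_dim):
    """Stack N independently initialised critics along a leading axis."""
    keys = jax.random.split(key, num_critics)
    return jax.vmap(lambda k: init_critic(k, obs_dim, act_dim))(keys)


def ensemble_q(ens_params, obs, act):
    """Evaluate all N critics at once; obs/act are shared -> [N, B]."""
    return jax.vmap(critic_apply, in_axes=(0, None, None))(ens_params, obs, act)


def sac_n_target(target_params, reward, done, next_obs, next_act,
                 next_log_prob, gamma=0.99, alpha=0.2):
    """TD target using the minimum over all N target critics.

    next_act / next_log_prob are assumed to be sampled from the current actor.
    alpha is a fixed illustrative value; SAC usually tunes it automatically.
    """
    q_next = ensemble_q(target_params, next_obs, next_act)  # [N, B]
    q_min = q_next.min(axis=0)                              # pessimistic estimate
    return reward + gamma * (1.0 - done) * (q_min - alpha * next_log_prob)


def critic_loss(ens_params, obs, act, target):
    """Every critic regresses to the same shared target; per-critic MSEs are summed."""
    q = ensemble_q(ens_params, obs, act)                    # [N, B]
    return ((q - target[None, :]) ** 2).mean(axis=1).sum()


# Usage: 10 critics, a batch of 4 transitions.
key = jax.random.PRNGKey(0)
k_ens, k_obs = jax.random.split(key)
ens = init_ensemble(k_ens, 10, 3, 2)
obs = jax.random.normal(k_obs, (4, 3))
act = jnp.zeros((4, 2))
q_all = ensemble_q(ens, obs, act)  # shape (10, 4)
```

Because the ensemble is just an extra leading axis on every parameter array, increasing N changes array shapes rather than adding Python-level loops, which keeps the whole critic update inside one compiled function.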
Speed comparison
The main insight here is to jit the whole epoch loop with jax.lax.fori_loop or jax.lax.scan, not just a single network update, as is usually done (in jaxrl2, for instance). Jitting only the update gives a speedup of approximately 1.5x here.
Both runs were trained on the same V100 GPU.
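The difference between the two strategies can be sketched on a deliberately tiny problem. This is an illustration under assumptions, not the repository's training loop: a toy least-squares SGD step stands in for the real SAC-N actor and critic updates, and the names update_step, train_epoch and train_epoch_python_loop are hypothetical. Both versions sample minibatches from a dataset that already lives on the device; train_epoch compiles all num_updates steps into a single XLA program with jax.lax.scan, so Python dispatches once per epoch instead of once per update.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    """Toy least-squares loss standing in for the real SAC-N losses."""
    return jnp.mean((x @ w - y) ** 2)


def update_step(w, x, y, lr=0.1):
    """One plain SGD step (hypothetical stand-in for a network update)."""
    return w - lr * jax.grad(loss_fn)(w, x, y)


# Common pattern: jit only the single update and drive the loop from Python.
update_step_jit = jax.jit(update_step)


def train_epoch_python_loop(w, key, data_x, data_y, num_updates, batch_size):
    """One host-side dispatch per update; overhead paid on every step."""
    for _ in range(num_updates):
        key, sample_key = jax.random.split(key)
        idx = jax.random.randint(sample_key, (batch_size,), 0, data_x.shape[0])
        w = update_step_jit(w, data_x[idx], data_y[idx])
    return w, key


@functools.partial(jax.jit, static_argnames=("num_updates", "batch_size"))
def train_epoch(w, key, data_x, data_y, num_updates, batch_size):
    """The whole epoch (sampling + updates) compiled as one program via lax.scan."""

    def body(carry, _):
        w, key = carry
        key, sample_key = jax.random.split(key)
        idx = jax.random.randint(sample_key, (batch_size,), 0, data_x.shape[0])
        w = update_step(w, data_x[idx], data_y[idx])
        # Per-step metric: loss on the sampled batch after the update.
        return (w, key), loss_fn(w, data_x[idx], data_y[idx])

    (w, key), losses = jax.lax.scan(body, (w, key), None, length=num_updates)
    return w, key, losses


# Usage: recover true_w from noiseless data with 500 updates per epoch.
k_data, k_train = jax.random.split(jax.random.PRNGKey(0))
true_w = jnp.array([1.0, -2.0, 0.5])
data_x = jax.random.normal(k_data, (1024, 3))
data_y = data_x @ true_w
w, _, losses = train_epoch(jnp.zeros(3), k_train, data_x, data_y,
                           num_updates=500, batch_size=64)
```

jax.lax.fori_loop would work just as well for the loop itself; scan is convenient when you also want to collect per-step metrics, such as the losses returned here.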


References
