Optimistix
Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Optimistix is a [JAX](https://github.com/google/jax) library for nonlinear solvers: root finding, minimisation, fixed points, and least squares. The project is written primarily in Python, is distributed under the Apache License 2.0, and was first published in 2023. Key topics include: deep-learning, equinox, jax, neural-networks, optimisation.
Features include:
- interoperable solvers: e.g. autoconvert root find problems to least squares problems, then solve using a minimisation algorithm.
- modular optimisers: e.g. use a BFGS quadratic bowl with a dogleg descent path with a trust region update.
- using a PyTree as the state.
- fast compilation and runtimes.
- interoperability with Optax.
- all the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support etc.
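The first feature, turning a root find into a least-squares problem and then solving it with a minimiser, can be illustrated without any library. A root-finding problem f(y) = 0 becomes the problem of minimising ½‖f(y)‖², whose gradient is Jᵀf, and that objective can be handed to any minimiser. The sketch below is plain Python (no JAX, no Optimistix) and uses fixed-step gradient descent as the minimiser. The example system, the helper names `residuals`, `jacobian`, `objective_grad` and `minimise_gd`, the step size and the iteration limit are all hypothetical choices for illustration, not Optimistix's API or algorithms.

```python
import math

# Hypothetical example system: the circle x^2 + y^2 = 4 intersected with the line x = y.
# Its roots are (sqrt(2), sqrt(2)) and (-sqrt(2), -sqrt(2)).
def residuals(x, y):
    return [x * x + y * y - 4.0, x - y]

def jacobian(x, y):
    # Row i holds the partial derivatives of residual i with respect to (x, y).
    return [[2.0 * x, 2.0 * y], [1.0, -1.0]]

def objective_grad(x, y):
    # The least-squares objective 0.5 * ||r||^2 has gradient J^T r.
    r = residuals(x, y)
    jac = jacobian(x, y)
    gx = jac[0][0] * r[0] + jac[1][0] * r[1]
    gy = jac[0][1] * r[0] + jac[1][1] * r[1]
    return gx, gy

def minimise_gd(x, y, step=0.05, max_steps=5000, tol=1e-12):
    # A deliberately simple minimiser: fixed-step gradient descent.
    for _ in range(max_steps):
        gx, gy = objective_grad(x, y)
        if math.hypot(gx, gy) < tol:
            break
        x, y = x - step * gx, y - step * gy
    return x, y

# Root find -> least squares -> minimisation.
x, y = minimise_gd(1.0, 0.5)
print(x, y, residuals(x, y))
```

Gradient descent is used only because it fits in a few lines; the point is the chain of reformulations. Optimistix performs this conversion for you, so consult its documentation for which solver types each entry point accepts.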
Installation
```bash
pip install optimistix
```
Requires Python 3.11+.
Documentation
Available at https://docs.kidger.site/optimistix.
Quick example
```python
import jax.numpy as jnp
import optimistix as optx

# Let's solve the ODE dy/dt=tanh(y(t)) with the implicit Euler method.
# We need to find y1 s.t. y1 = y0 + tanh(y1)dt.
y0 = jnp.array(1.)
dt = jnp.array(0.1)

def fn(y, args):
    return y0 + jnp.tanh(y) * dt

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.fixed_point(fn, solver, y0)
y1 = sol.value  # satisfies y1 == fn(y1)
```
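To make the quick example less opaque, here is a hand-written sketch of the same computation in plain Python. The fixed-point condition y1 = fn(y1) is rewritten as the root find g(y) = fn(y) - y = 0 and solved with Newton's method. This is a hedged illustration, not how Optimistix implements `optx.Newton` or `optx.fixed_point`: the derivative is written out by hand rather than obtained by autodiff, the `args` parameter is dropped for brevity, the stopping rule (update below atol + rtol·|y|) is a simple choice made for this sketch, and `newton_fixed_point` and `dfn` are hypothetical helper names.

```python
import math

# Same problem as the quick example: implicit Euler for dy/dt = tanh(y).
y0 = 1.0
dt = 0.1

def fn(y):
    return y0 + math.tanh(y) * dt

def dfn(y):
    # d/dy [y0 + tanh(y) * dt] = dt * (1 - tanh(y)**2)
    return dt * (1.0 - math.tanh(y) ** 2)

def newton_fixed_point(fn, dfn, y, rtol=1e-5, atol=1e-5, max_steps=50):
    # Hypothetical helper: solve y = fn(y) via Newton's method on g(y) = fn(y) - y.
    for _ in range(max_steps):
        g = fn(y) - y
        dg = dfn(y) - 1.0  # g'(y); here it lies in [-1, -0.9], so never zero
        step = g / dg
        y = y - step
        # Stop once the update is small relative to the scale of y.
        if abs(step) <= atol + rtol * abs(y):
            break
    return y

y1 = newton_fixed_point(fn, dfn, y0)
print(y1, fn(y1) - y1)  # the residual should be close to zero
```

In Optimistix the same problem is expressed by passing `fn` and a root finder such as `optx.Newton` to `optx.fixed_point`, as in the quick example, and the derivative is computed with JAX rather than by hand.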
Citation
If you found this library to be useful in academic work, then please cite: ([arXiv:2402.09983](https://arxiv.org/abs/2402.09983))
```bibtex
@article{optimistix2024,
    title={Optimistix: modular optimisation in JAX and Equinox},
    author={Jason Rader and Terry Lyons and Patrick Kidger},
    journal={arXiv:2402.09983},
    year={2024},
}
```
See also: other libraries in the JAX ecosystem
Always useful
- Equinox: neural networks and everything not already in core JAX!
- jaxtyping: type annotations for shape/dtype of arrays.
Deep learning
- Optax: first-order gradient (SGD, Adam, ...) optimisers.
- Orbax: checkpointing (async/multi-host/multi-device).
- Levanter: scalable+reliable training of foundation models (e.g. LLMs).
- paramax: parameterizations and constraints for PyTrees.
Scientific computing
- Diffrax: numerical differential equation solvers.
- Lineax: linear solvers.
- BlackJAX: probabilistic+Bayesian sampling.
- sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
- PySR: symbolic regression. (Non-JAX honourable mention!)
Awesome JAX
- Awesome JAX: a longer list of other JAX projects.
Credit
Optimistix was primarily built by Jason Rader (@packquickly). It is co-maintained by Johanna Haffner (@johannahaffner).
