
Enzyme-JAX

Custom Bindings for Enzyme Automatic Differentiation Tool and Interfacing with JAX.

From EnzymeAD · Updated June 1, 2026

The project is written primarily in MLIR, is distributed under a license that GitHub classifies as "Other", and was first published in 2023.

Latest release: v0.0.12 (June 1, 2026)


Enzyme-JAX is a C++ project whose original aim was to integrate the Enzyme automatic differentiation tool [1] with JAX, enabling automatic differentiation of external C++ code within JAX. It has since expanded to incorporate Polygeist's [2] high-performance raising, parallelization, and cross-compilation workflow, as well as numerous tensor, linear algebra, and communication optimizations. The project uses LLVM's MLIR framework for intermediate representation and transformation of code. Because Enzyme is language-agnostic, this approach can be extended to arbitrary programming languages (Julia, Swift, Fortran, Rust, and even Python)!

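For example, you can use cpp_call to embed a C++ function in a jitted JAX function and differentiate through it in forward or reverse mode: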

```python
import jax
import jax.numpy as jnp
from enzyme_ad.jax import cpp_call

# Forward-mode C++ AD example
@jax.jit
def something(inp):
    y = cpp_call(
        inp,
        out_shapes=[jax.core.ShapedArray([2, 3], jnp.float32)],
        source="""
        template<std::size_t N, std::size_t M>
        void myfn(enzyme::tensor<float, N, M>& out0, const enzyme::tensor<float, N, M>& in0) {
            out0 = 56.0f + in0(0, 0);
        }
        """,
        fn="myfn",
    )
    return y

ones = jnp.ones((2, 3), jnp.float32)
primals, tangents = jax.jvp(something, (ones,), (ones,))

# Reverse-mode C++ AD example
primals, f_vjp = jax.vjp(something, ones)
# Cotangent for the single output; all ones is an arbitrary illustrative choice.
x = jnp.ones((2, 3), jnp.float32)
(grads,) = f_vjp((x,))
```
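Because myfn sets every output element to 56 + in0(0, 0), its derivatives are easy to predict, which makes a pure-JAX analogue handy as a reference when checking the C++ results. The sketch below is such an analogue (myfn_reference is a hypothetical name, not part of Enzyme-JAX): with all-ones tangents every output tangent is 1, and with an all-ones cotangent the gradient is 6 at position (0, 0) and 0 elsewhere. Unlike the cpp_call version, whose cotangent is passed as a one-element tuple (x,), this analogue returns a single array, so its cotangent is passed directly.

```python
import jax
import jax.numpy as jnp


def myfn_reference(inp):
    """Pure-JAX analogue of the C++ myfn (hypothetical helper for comparison).

    Every output element is 56 + inp[0, 0], so only inp[0, 0] influences the result.
    """
    return jnp.full_like(inp, 56.0) + inp[0, 0]


ones = jnp.ones((2, 3), jnp.float32)

# Forward mode: each output tangent equals the input tangent at (0, 0).
primals, tangents = jax.jvp(myfn_reference, (ones,), (ones,))

# Reverse mode: the cotangents of all six outputs accumulate into position (0, 0).
primals_r, f_vjp = jax.vjp(myfn_reference, ones)
(grads,) = f_vjp(jnp.ones((2, 3), jnp.float32))
```

If the C++ version is wired up correctly, its tangents and grads should match these values.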

Installation

The easiest way to install is using pip.

```bash
# The project is available on PyPI and installable like
# a usual Python package (https://pypi.org/project/enzyme-ad/)
pip install enzyme-ad
```
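After installing, you can confirm from Python that the package is visible. This is a minimal sketch using only the standard library; it assumes the distribution name enzyme-ad (taken from the PyPI URL above) and the import name enzyme_ad (taken from the example's import line). The helper name enzyme_ad_version is hypothetical.

```python
import importlib.metadata
import importlib.util


def enzyme_ad_version():
    """Return the installed enzyme-ad version string, or None if it is not installed.

    Hypothetical helper; assumes the PyPI distribution name "enzyme-ad"
    and the top-level import name "enzyme_ad".
    """
    if importlib.util.find_spec("enzyme_ad") is None:
        return None
    try:
        return importlib.metadata.version("enzyme-ad")
    except importlib.metadata.PackageNotFoundError:
        return None


print(enzyme_ad_version())
```

A result of None means the package could not be found in the current environment.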

Building from source

Requirements: bazel-6.5, clang++, python, python-virtualenv, python3-dev.

Build our extension with:

```sh
# Will create a whl in bazel-bin/enzyme_ad-VERSION-SYSTEM.whl
bazel build :wheel
```

Finally, install the built library with:

```sh
pip install bazel-bin/enzyme_ad-VERSION-SYSTEM.whl
```
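The wheel's file name embeds the version and platform (the VERSION-SYSTEM placeholder above), so it can be convenient to locate it with a glob rather than typing it out. A minimal sketch, assuming the wheel lands directly in bazel-bin and matches enzyme_ad-*.whl; find_built_wheels, pip_install_command, and install_built_wheel are hypothetical helpers.

```python
import glob
import os
import subprocess
import sys


def find_built_wheels(bazel_bin="bazel-bin"):
    """Return sorted paths of enzyme_ad-*.whl files in bazel_bin (hypothetical helper)."""
    return sorted(glob.glob(os.path.join(bazel_bin, "enzyme_ad-*.whl")))


def pip_install_command(wheel_path):
    """Build the argv for installing a wheel with the current interpreter's pip."""
    return [sys.executable, "-m", "pip", "install", wheel_path]


def install_built_wheel(bazel_bin="bazel-bin"):
    """Install the single wheel found in bazel_bin; raise if none or several exist."""
    wheels = find_built_wheels(bazel_bin)
    if len(wheels) != 1:
        raise RuntimeError(
            f"expected exactly one enzyme_ad wheel in {bazel_bin!r}, found {wheels}"
        )
    subprocess.run(pip_install_command(wheels[0]), check=True)
```

Calling install_built_wheel() after bazel build :wheel installs into whichever interpreter runs the script, which avoids mixing up pip and python from different environments.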

Note that you cannot run code from the root of the git repository. For instance, before running test.py (see the manual test instructions below), you must first run cd test.

Running the tests

To run the tests, execute the following Bazel command (this does not require building or installing the wheel):

```sh
bazel test //test/...
```

Alternatively, if you have installed the wheel, you can invoke the tests manually as follows:

```sh
cd test && python test.py
```
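Since the tests must be run from inside the test directory rather than the repository root, a script that drives them can set the working directory explicitly instead of relying on cd. A minimal sketch, assuming repo_root points at an Enzyme-JAX checkout and the wheel is already installed; manual_test_command and run_manual_tests are hypothetical helpers, and sys.executable stands in for the python on your PATH.

```python
import os
import subprocess
import sys


def manual_test_command(repo_root):
    """Return (argv, cwd) equivalent to `cd test && python test.py` (hypothetical helper)."""
    return [sys.executable, "test.py"], os.path.join(repo_root, "test")


def run_manual_tests(repo_root="."):
    """Run test.py from inside the test directory and return its exit code."""
    argv, cwd = manual_test_command(repo_root)
    return subprocess.run(argv, cwd=cwd).returncode
```

Passing cwd to subprocess.run changes only the child's working directory, so the calling script's own directory is left untouched.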

LSP Support

Enzyme-JAX exposes many tensor rewrites as MLIR passes in src/enzyme_ad/jax/Passes. To enable LSP support when working on this code, we recommend generating a compile_commands.json file by running:

```bash
bazel run :refresh_compile_commands
```
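The generated compile_commands.json follows the Clang JSON compilation database format: a list of entries, each with a directory, a file, and either a command or an arguments list. As a quick check that the pass sources are covered, the sketch below lists the entries under src/enzyme_ad/jax/Passes. The helper names are hypothetical, and the file is assumed to be written to the repository root.

```python
import json
import os


def load_compile_commands(path="compile_commands.json"):
    """Load a JSON compilation database (a list of dicts)."""
    with open(path) as f:
        return json.load(f)


def entries_under(entries, prefix="src/enzyme_ad/jax/Passes"):
    """Return entries whose source file lies under `prefix` (hypothetical helper).

    Relative "file" paths are resolved against each entry's "directory",
    as the compilation database format specifies.
    """
    needle = "/" + prefix.strip("/") + "/"
    matches = []
    for entry in entries:
        path = os.path.join(entry["directory"], entry["file"])
        # Normalize separators and anchor with a leading "/" so the
        # directory-boundary test works for relative and absolute paths.
        normalized = "/" + os.path.normpath(path).replace(os.sep, "/").lstrip("/")
        if needle in normalized:
            matches.append(entry)
    return matches
```

For example, len(entries_under(load_compile_commands())) gives the number of pass sources your language server will know how to compile.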

References

[1] Moses, William, and Valentin Churavy. "Instead of rewriting foreign code for machine learning, automatically synthesize fast gradients." Advances in Neural Information Processing Systems 33 (2020): 12472-12485.

[2] Moses, William S., et al. "Polygeist: Raising C to polyhedral MLIR." 2021 30th International Conference on Parallel Architectures and Compilation Techniques (PACT). IEEE, 2021.


This article is auto-generated from EnzymeAD/Enzyme-JAX via the GitHub API. Last fetched: 6/1/2026.