The Wire · Showcase
JAX TIGHTENS GRADIENT CHECKPOINTING, FIXES SCIPY PARITY BUGS
By RepoJournal · Filed · About Google
Behind the custom_vjp3=1 flag, JAX's custom_vjp3 primitive now supports checkpoint_name in forward passes, so intermediates computed inside a custom VJP can take part in memory-saving rematerialization policies.
The big move is giving custom_vjp3 a remat rule that applies JAX's remat_transform to primal functions [1]. In practice, you can now tag intermediate values in a custom_vjp forward pass with checkpoint_name and have a rematerialization policy save or recompute them without breaking the AD chain. That matters when training large models where memory is the bottleneck. Alongside this, JAX made primal_left_tangent_right a hijax primitive so that it lowers away in lojax, which lets the compiler dead-code-eliminate gradients that are never used [1].
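A minimal sketch of the pattern, using only public JAX APIs (jax.custom_vjp, jax.ad_checkpoint.checkpoint_name, jax.checkpoint and jax.checkpoint_policies.save_only_these_names). The function scaled_tanh and the name "tanh_out" are hypothetical. The sketch does not enable custom_vjp3=1, because the source does not say how that flag is set; per [1], it is with the flag enabled that the new remat rule should let the policy see the tagged residual. The gradient is the same either way; only what gets saved versus recomputed changes.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


@jax.custom_vjp
def scaled_tanh(x):
    return 2.0 * jnp.tanh(x)


def scaled_tanh_fwd(x):
    # Tag the intermediate so a name-based remat policy can decide whether to keep it.
    t = checkpoint_name(jnp.tanh(x), "tanh_out")
    return 2.0 * t, t  # (primal output, residual for the backward pass)


def scaled_tanh_bwd(t, g):
    # d/dx [2 * tanh(x)] = 2 * (1 - tanh(x)**2), reusing the residual t.
    return (g * 2.0 * (1.0 - t**2),)


scaled_tanh.defvjp(scaled_tanh_fwd, scaled_tanh_bwd)

# Save only values tagged "tanh_out"; other intermediates are recomputed on the backward pass.
policy = jax.checkpoint_policies.save_only_these_names("tanh_out")


@partial(jax.checkpoint, policy=policy)
def loss(x):
    return jnp.sum(scaled_tanh(3.0 * x))


x = jnp.array([0.0, 0.5])
grads = jax.grad(loss)(x)
# Analytic check: d/dx sum(2 * tanh(3x)) = 6 * (1 - tanh(3x)**2)
expected = 6.0 * (1.0 - jnp.tanh(3.0 * x) ** 2)
```

Wrapping the whole loss in jax.checkpoint is what makes the policy apply; without it, JAX keeps whatever residuals the custom forward pass returns.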
The team also shipped fixes for long-standing scipy.special parity bugs [2] [3]. lax.special.gammainc and lax.special.gammaincc now return 0.0 and 1.0, respectively, when a = infinity, instead of NaN, matching scipy's behavior [2]. The same fix went into lax.special.igamma and lax.special.igammac for a = inf with finite x [3]. These are edge cases, but they matter if an infinite shape parameter can reach these functions in your code, or if you are porting scipy code to JAX and checking outputs against scipy.
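If you are pinned to a JAX release that predates [2] and [3], a small guard can enforce the scipy convention yourself. The limits follow from P(a, x) = γ(a, x)/Γ(a), which tends to 0 as a grows with x fixed, and Q(a, x) = 1 - P(a, x). The helpers gammainc_compat and gammaincc_compat below are hypothetical, not part of JAX; they wrap the public jax.scipy.special.gammainc and gammaincc. On a release that includes the fixes, the guard is redundant but harmless.

```python
import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaincc


def _inf_shape_mask(a, x):
    # The scipy convention covers a = +inf with finite, non-negative x.
    return jnp.isposinf(a) & jnp.isfinite(x) & (x >= 0)


def gammainc_compat(a, x):
    """P(a, x), returning 0 for a = inf and finite x, as scipy.special.gammainc does."""
    a = jnp.asarray(a, dtype=jnp.float32)
    x = jnp.asarray(x, dtype=jnp.float32)
    mask = _inf_shape_mask(a, x)
    # Pass a harmless placeholder (a = 1) to the library call in masked lanes so a NaN
    # from an unpatched release cannot leak into gradients via the unselected jnp.where branch.
    safe_a = jnp.where(mask, 1.0, a)
    return jnp.where(mask, 0.0, gammainc(safe_a, x))


def gammaincc_compat(a, x):
    """Q(a, x) = 1 - P(a, x), returning 1 for a = inf and finite x."""
    a = jnp.asarray(a, dtype=jnp.float32)
    x = jnp.asarray(x, dtype=jnp.float32)
    mask = _inf_shape_mask(a, x)
    safe_a = jnp.where(mask, 1.0, a)
    return jnp.where(mask, 1.0, gammaincc(safe_a, x))
```

For finite a the wrappers simply defer to JAX, so gammainc_compat(1.0, 2.0) equals 1 - exp(-2) up to float32 precision.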
Test coverage for remat3 was also updated to close remaining gaps [4], which should make regressions in its rematerialization rules easier to catch before they reach production. Custom AD primitives and scipy compatibility are stability plays rather than headline features, but they are the kind of fixes that prevent silent correctness bugs.
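One way to catch silent remat bugs in your own models, in the spirit of (but not copied from) the remat3 tests in [4], is to check that every checkpoint policy reproduces the un-checkpointed gradient. A minimal sketch: mlp and check_remat_matches are hypothetical names, while the policies are real entries in jax.checkpoint_policies.

```python
import jax
import jax.numpy as jnp
import numpy as np


def mlp(params, x):
    # Two dense layers; their activations are what remat trades for memory.
    w1, w2 = params
    h = jnp.tanh(x @ w1)
    return jnp.sum(jnp.tanh(h @ w2) ** 2)


POLICIES = {
    "nothing_saveable": jax.checkpoint_policies.nothing_saveable,
    "everything_saveable": jax.checkpoint_policies.everything_saveable,
    "dots_saveable": jax.checkpoint_policies.dots_saveable,
}


def check_remat_matches(f, *args, rtol=1e-5, atol=1e-6):
    """Assert that each checkpoint policy gives the same gradient as plain jax.grad."""
    reference = jax.grad(f)(*args)
    for name, policy in POLICIES.items():
        remat_grad = jax.grad(jax.checkpoint(f, policy=policy))(*args)
        for ref_leaf, got_leaf in zip(jax.tree_util.tree_leaves(reference),
                                      jax.tree_util.tree_leaves(remat_grad)):
            np.testing.assert_allclose(ref_leaf, got_leaf, rtol=rtol, atol=atol,
                                       err_msg=f"policy {name} changed the gradient")
    return True


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (4, 8)), jax.random.normal(k2, (8, 2)))
x = jax.random.normal(k3, (3, 4))
check_remat_matches(mlp, params, x)
```

Remat should change only memory use and recomputation, never the math, so any mismatch here points to a bug rather than a tuning issue.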
References
- [1] support checkpoint_name in custom_vjp fwd if custom_vjp3=1 · google/jax
- [2] Fix `lax.special.gammainc` and `lax.special.gammaincc` for a = infinity · google/jax
- [3] Fix `lax.special.igamma` and `lax.special.igammac` for `a = inf` & `x` finite. · google/jax
- [4] [remat3] update tests to finish coverage of remat3 · google/jax
FAQ
- What changed in Google's repositories on June 28, 2026?
- Behind the custom_vjp3=1 flag, JAX's custom_vjp3 primitive now supports checkpoint_name in forward passes, so intermediates computed inside a custom VJP can take part in memory-saving rematerialization policies.
- What should Google teams do about it?
- Review custom_vjp forward passes for intermediates worth tagging with checkpoint_name if you're memory-constrained on large models • Test the scipy.special edge cases (gammainc, gammaincc, igamma, igammac) if your code can pass an infinite shape parameter a
- Which Google repositories shipped on June 28, 2026?
- google/jax