Implicit_diff
These functions modify and extend the functions in jaxopt._src.implicit_diff so that external solve functions receive a jax.experimental.sparse.BCOO matrix as an argument instead of a matvec function. The root_vjp and root_jvp functions were modified so that external solvers can be called via a pure_callback and constraints can be taken into account. With the custom_root wrapper, root solvers can be made differentiable to arbitrary order in either forward or reverse mode. Mixing the two differentiation modes is currently not possible.
For an example, see uncertainty_estimation_hyperelastic.py.
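The underlying idea can be illustrated with a minimal, self-contained sketch. It does not use this module's API: the names residual, solve_root, _host_newton and _host_linear_solve are hypothetical, the Jacobian is dense rather than a BCOO matrix, constraints are ignored, and only first-order reverse mode is covered. A NumPy Newton iteration stands in for the external solver, and both it and the linear solve of the adjoint system are called through jax.pure_callback.

```python
import jax
import jax.numpy as jnp
import numpy as np


def residual(x, theta):
    # Hypothetical residual F(x, theta); its root defines x*(theta).
    return x**3 + x - theta


def _host_newton(theta):
    # Plain NumPy Newton iteration, standing in for an external root solver.
    theta = np.asarray(theta)
    x = np.zeros_like(theta)
    for _ in range(50):
        x = x - (x**3 + x - theta) / (3.0 * x**2 + 1.0)
    return x


def _host_linear_solve(A, b):
    # External (host-side) linear solver; dense here for simplicity.
    return np.linalg.solve(np.asarray(A), np.asarray(b))


@jax.custom_vjp
def solve_root(theta):
    # The solver runs outside of JAX, so JAX cannot trace through it.
    out = jax.ShapeDtypeStruct(theta.shape, theta.dtype)
    return jax.pure_callback(_host_newton, out, theta)


def _solve_root_fwd(theta):
    x = solve_root(theta)
    return x, (x, theta)


def _solve_root_bwd(res, x_bar):
    x, theta = res
    # Implicit function theorem: dx/dtheta = -(dF/dx)^{-1} dF/dtheta.
    # Reverse mode therefore needs one adjoint solve: (dF/dx)^T lam = x_bar.
    J_x = jax.jacobian(residual, argnums=0)(x, theta)
    out = jax.ShapeDtypeStruct(x_bar.shape, x_bar.dtype)
    lam = jax.pure_callback(_host_linear_solve, out, J_x.T, x_bar)
    # theta_bar = -(dF/dtheta)^T lam, evaluated as a VJP of the residual.
    _, vjp_theta = jax.vjp(lambda t: residual(x, t), theta)
    (theta_bar,) = vjp_theta(-lam)
    return (theta_bar,)


solve_root.defvjp(_solve_root_fwd, _solve_root_bwd)

theta = jnp.array([2.0, 10.0])
x_star = solve_root(theta)  # roots of x**3 + x = theta
grad = jax.grad(lambda t: jnp.sum(solve_root(t)))(theta)
```

For this residual the roots are x = 1 and x = 2, and the gradient should agree with the analytic value 1 / (3 x**2 + 1) at each root. Because the backward rule itself calls a non-differentiable callback, this sketch cannot be differentiated a second time; supporting higher orders requires the linear solve to be differentiable as well.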
Implicit differentiation
Decorator for adding implicit differentiation to a root solver.