autopdex.implicit_diff.custom_root
- autopdex.implicit_diff.custom_root(residual_fun: Callable, mat_fun: Callable, solve: Callable, free_dofs=None, has_aux: bool = False, mode='reverse', reference_signature: Callable | None = None)
Decorator for adding implicit differentiation to a root solver.
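Such a decorator relies on the identities of the implicit function theorem. Write r for residual_fun, x^* for the root sol and \theta for the remaining arguments args, and assume the tangent matrix \partial r / \partial x (the matrix returned by mat_fun) is invertible at the root. The standard forward (JVP) and reverse (VJP) rules then look as follows. This is a sketch of the general technique; how autopdex additionally treats constrained dofs is not shown.

```latex
% Implicit function theorem at the root x^* = x^*(\theta)
r\bigl(x^*(\theta), \theta\bigr) = 0
\;\Longrightarrow\;
\frac{\partial r}{\partial x}\,\frac{\partial x^*}{\partial \theta}
  = -\frac{\partial r}{\partial \theta}

% Forward mode (JVP): one linear solve per input tangent \dot{\theta}
\frac{\partial r}{\partial x}\,\dot{x} = -\frac{\partial r}{\partial \theta}\,\dot{\theta}

% Reverse mode (VJP): one transposed linear solve per output cotangent \bar{x}
\Bigl(\frac{\partial r}{\partial x}\Bigr)^{\top} \lambda = \bar{x},
\qquad
\bar{\theta} = -\Bigl(\frac{\partial r}{\partial \theta}\Bigr)^{\top} \lambda
```

In both modes the only linear systems involve the tangent matrix at the root, which is why custom_root takes mat_fun and a linear solver solve instead of differentiating through the iterations of the root solver.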
- Parameters:
residual_fun – A callable that returns the possibly nonlinear residual whose root is to be found, called as residual_fun(dofs, *args). The invariant is residual_fun(sol, *args) == 0 at the solution / root sol.
mat_fun – A callable that returns the sparse tangent matrix as a jax.experimental.sparse.BCOO, taking dofs and args as arguments. Can also be a pure callback.
solve – A linear solver of the form solve(mat, b), where mat is a jax.experimental.sparse.BCOO matrix and b is a jnp.ndarray; it returns the solution x of mat · x = b.
free_dofs – For constraining certain degrees of freedom. If free_dofs is not None, the second argument of the solver must be a dictionary with the keys ‘dirichlet dofs’ and ‘dirichlet conditions’. The first is a dictionary of jnp.ndarrays with the same structure as dofs, whose entries are boolean masks marking the constrained dofs. The second is a dictionary of jnp.ndarrays with the same structure as dofs, whose entries are the values of the constrained dofs (see the source code of _root_jvp and _root_vjp for details, or solver.adaptive_load_stepping for an example of use).
has_aux – Whether the decorated root solver function returns auxiliary data.
mode – The differentiation mode (‘forward’ or ‘reverse’/‘backward’). Defaults to ‘reverse’.
reference_signature – Optional function signature (i.e. arguments and keyword arguments) with which the solver and optimality functions are expected to agree. Defaults to residual_fun. The solver and optimality functions must share the same input signature, but both may be defined in such a way that the signature correspondence is ambiguous (e.g. if both accept catch-all **kwargs). To satisfy custom_root’s requirement, any function with an unambiguous signature can be provided here.
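The constraint dictionary described under free_dofs can be sketched as follows, assuming dofs is a dictionary with a single hypothetical field 'u'. The key names ‘dirichlet dofs’ and ‘dirichlet conditions’ come from the parameter description above; the final step, which imposes the prescribed values on an initial guess with jnp.where, is only an illustration and not necessarily what autopdex does internally.

```python
import jax
import jax.numpy as jnp

# Hypothetical dofs: a dictionary with one field 'u' holding 5 nodal values.
dofs = {"u": jnp.zeros(5)}

# Boolean masks with the same structure as dofs: True marks a constrained dof.
dirichlet_dofs = {"u": jnp.array([True, False, False, False, True])}

# Prescribed values, again with the same structure as dofs
# (only the entries where the mask is True are relevant).
dirichlet_conditions = {"u": jnp.array([1.0, 0.0, 0.0, 0.0, 2.0])}

# Dictionary passed as the solver's second argument when free_dofs is not None.
settings = {
    "dirichlet dofs": dirichlet_dofs,
    "dirichlet conditions": dirichlet_conditions,
}

# Masks and values must mirror the pytree structure of dofs.
dofs_structure = jax.tree_util.tree_structure(dofs)
assert jax.tree_util.tree_structure(dirichlet_dofs) == dofs_structure
assert jax.tree_util.tree_structure(dirichlet_conditions) == dofs_structure

# Illustration only: overwrite the constrained entries of an initial guess.
initial_guess = jax.tree_util.tree_map(
    lambda x, mask, value: jnp.where(mask, value, x),
    dofs,
    settings["dirichlet dofs"],
    settings["dirichlet conditions"],
)
print(initial_guess["u"])
```

Keeping masks and values as pytrees with the same structure as dofs lets them be combined leaf by leaf with jax.tree_util.tree_map, whatever fields dofs contains.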
- Returns:
The decorated root solver function, equipped with a custom VJP or JVP rule (depending on mode).
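The kind of reverse-mode rule such a decorator attaches can be sketched with jax.custom_vjp. Everything here is a hypothetical stand-in rather than the autopdex implementation: residual_fun is a toy elementwise residual, mat_fun returns a dense jax.jacfwd Jacobian instead of a BCOO matrix, solve wraps jnp.linalg.solve, and the root solver is a plain Newton iteration. free_dofs, has_aux and reference_signature are not modelled.

```python
import jax
import jax.numpy as jnp


# Hypothetical residual r(x, theta) = x**3 + x - theta (elementwise).
def residual_fun(x, theta):
    return x**3 + x - theta


# Dense stand-in for mat_fun: the tangent matrix dr/dx.
def mat_fun(x, theta):
    return jax.jacfwd(residual_fun, argnums=0)(x, theta)


# Stand-in for the linear solver solve(mat, b).
def solve(mat, b):
    return jnp.linalg.solve(mat, b)


@jax.custom_vjp
def root_solver(x0, theta):
    # Plain Newton iteration; the custom rule means it is never differentiated.
    def newton_step(_, x):
        return x - solve(mat_fun(x, theta), residual_fun(x, theta))

    return jax.lax.fori_loop(0, 20, newton_step, x0)


def root_solver_fwd(x0, theta):
    sol = root_solver(x0, theta)
    return sol, (sol, theta)  # residuals saved for the backward pass


def root_solver_bwd(saved, x_bar):
    sol, theta = saved
    # Transposed tangent solve: (dr/dx)^T lambda = x_bar.
    lam = solve(mat_fun(sol, theta).T, x_bar)
    # theta_bar = -(dr/dtheta)^T lambda, obtained as a VJP of the residual.
    _, vjp_theta = jax.vjp(lambda t: residual_fun(sol, t), theta)
    (theta_bar,) = vjp_theta(-lam)
    # The root does not depend on the initial guess, so its cotangent is zero.
    return jnp.zeros_like(sol), theta_bar


root_solver.defvjp(root_solver_fwd, root_solver_bwd)

theta = jnp.array([2.0, 10.0])
x0 = jnp.zeros(2)
sol = root_solver(x0, theta)  # roots of x**3 + x = theta
dsol_dtheta = jax.grad(lambda t: root_solver(x0, t).sum())(theta)
print(sol, dsol_dtheta)
```

For this residual the implicit function theorem gives dx*/dθ = 1 / (3 x*² + 1), so the gradient can be checked by hand: the roots are 1 and 2, and the gradient entries are 1/4 and 1/13.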
Example
See e.g. the implementation of autopdex.solver.adaptive_load_stepping.
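For mode=‘forward’, the counterpart is a JVP rule that performs one tangent solve per input tangent. The following self-contained sketch uses jax.custom_jvp with the same hypothetical dense stand-ins (a toy elementwise residual, a jax.jacfwd tangent matrix and jnp.linalg.solve) instead of the autopdex internals, so it illustrates the technique only.

```python
import jax
import jax.numpy as jnp


# Hypothetical residual r(x, theta) = x**3 + x - theta (elementwise).
def residual_fun(x, theta):
    return x**3 + x - theta


# Dense stand-in for the tangent matrix dr/dx.
def mat_fun(x, theta):
    return jax.jacfwd(residual_fun, argnums=0)(x, theta)


# Stand-in for the linear solver solve(mat, b).
def solve(mat, b):
    return jnp.linalg.solve(mat, b)


@jax.custom_jvp
def root_solver(x0, theta):
    # Plain Newton iteration; the custom rule means it is never differentiated.
    def newton_step(_, x):
        return x - solve(mat_fun(x, theta), residual_fun(x, theta))

    return jax.lax.fori_loop(0, 20, newton_step, x0)


@root_solver.defjvp
def root_solver_jvp(primals, tangents):
    x0, theta = primals
    _, theta_dot = tangents  # the initial guess does not influence the root
    sol = root_solver(x0, theta)
    # r_dot = (dr/dtheta) theta_dot, obtained as a JVP of the residual.
    _, r_dot = jax.jvp(lambda t: residual_fun(sol, t), (theta,), (theta_dot,))
    # Tangent solve: (dr/dx) x_dot = -(dr/dtheta) theta_dot.
    x_dot = solve(mat_fun(sol, theta), -r_dot)
    return sol, x_dot


theta = jnp.array([2.0, 10.0])
x0 = jnp.zeros(2)
sol, sol_dot = jax.jvp(lambda t: root_solver(x0, t), (theta,), (jnp.ones(2),))
print(sol, sol_dot)
```

Forward mode is the cheaper choice when there are few input directions to push through the solver; reverse mode, as in the default mode=‘reverse’, is cheaper for scalar objectives of many parameters.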