autopdex.implicit_diff.custom_root

autopdex.implicit_diff.custom_root(residual_fun: Callable, mat_fun: Callable, solve: Callable, free_dofs=None, has_aux: bool = False, mode='reverse', reference_signature: Callable | None = None)

Decorator for adding implicit differentiation to a root solver.
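
As background, the relation that implicit differentiation of a root relies on can be written out as follows. The symbols R (residual), u (dofs), p (a differentiable argument), K (tangent matrix) and λ (adjoint vector) are notation introduced here for illustration and are not taken from the library; how custom_root organizes these steps internally is not shown.

```latex
% Root condition, valid for every admissible p:
R\bigl(u^*(p),\, p\bigr) = 0
% Differentiating with respect to p, with K = \partial R / \partial u at the root:
K \,\frac{\mathrm{d}u^*}{\mathrm{d}p} + \frac{\partial R}{\partial p} = 0
\quad\Longrightarrow\quad
\frac{\mathrm{d}u^*}{\mathrm{d}p} = -K^{-1}\,\frac{\partial R}{\partial p}
% Forward mode (JVP) for a tangent \dot p: one linear solve with K
K\,\dot u = -\frac{\partial R}{\partial p}\,\dot p
% Reverse mode (VJP) for a cotangent \bar u: one linear solve with K^{\mathsf T}
K^{\mathsf T}\lambda = \bar u,
\qquad
\bar p = -\Bigl(\frac{\partial R}{\partial p}\Bigr)^{\mathsf T}\lambda
```

In both modes, the derivative costs one linear solve with the tangent matrix at the converged solution, independent of how many iterations the root solver took.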

Parameters:
  • residual_fun – A callable that returns the (possibly nonlinear) residual whose root is sought, called as residual_fun(dofs, *args). The invariant is residual_fun(sol, *args) == 0 at the root sol.

  • mat_fun – A callable that returns the sparse tangent matrix as a jax.experimental.sparse.BCOO, taking dofs and args as arguments. Can also be a pure callback.

  • solve – A linear solver of the form solve(mat, b), where mat is a jax.experimental.sparse.BCOO matrix and b is a jnp.ndarray.

  • free_dofs – Used to constrain certain degrees of freedom. If free_dofs is not None, the second argument of the solver has to be a dictionary with the keys ‘dirichlet dofs’ and ‘dirichlet conditions’. The value under ‘dirichlet dofs’ is a dictionary of jnp.ndarrays with the same structure as dofs, whose entries are boolean masks marking the constrained dofs. The value under ‘dirichlet conditions’ is a dictionary of jnp.ndarrays with the same structure as dofs, whose entries are the values of the constrained dofs. See the source code of _root_jvp and _root_vjp for details, or solver.adaptive_load_stepping for an example of use.

  • has_aux – Whether the decorated root solver function returns auxiliary data.

  • mode – The differentiation mode (‘forward’ or ‘reverse’/‘backward’).

  • reference_signature – Optional function signature (i.e. arguments and keyword arguments) with which the solver and optimality functions are expected to agree. Defaults to residual_fun. The solver and optimality functions must share the same input signature, but both might be defined in such a way that the correspondence between their signatures is ambiguous (e.g. if both accept catch-all **kwargs). To satisfy custom_root’s requirement, any function with an unambiguous signature can be provided here.
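
The calling conventions of residual_fun, mat_fun and solve described above can be illustrated with a minimal sketch. The toy residual, the dense fallback inside solve and the plain Newton loop are assumptions made for this illustration. They are not taken from autopdex, and custom_root itself is not called here.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical toy problem (not from autopdex): R(u, p) = p*u + u**3 - 1, elementwise.
def residual_fun(dofs, p):
    return p * dofs + dofs**3 - 1.0

def mat_fun(dofs, p):
    # The tangent dR/du is diagonal for this toy residual; return it as a BCOO matrix.
    dense = jnp.diag(p + 3.0 * dofs**2)
    return sparse.BCOO.fromdense(dense)

def solve(mat, b):
    # Stand-in linear solver: densify the BCOO matrix and use a dense solve.
    return jnp.linalg.solve(mat.todense(), b)

def newton(dofs, p, n_iter=20):
    # Plain Newton iteration built from the three callables above.
    for _ in range(n_iter):
        dofs = dofs - solve(mat_fun(dofs, p), residual_fun(dofs, p))
    return dofs

p = jnp.array([1.0, 2.0, 3.0])
sol = newton(jnp.ones(3), p)

# The invariant residual_fun(sol, *args) == 0 holds up to floating-point tolerance.
max_residual = jnp.max(jnp.abs(residual_fun(sol, p)))
```

A production setup would typically replace the dense fallback in solve with a sparse direct or iterative solver.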

Returns:

The decorated root solver function, equipped with a custom VJP or JVP rule.
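
To illustrate what a custom VJP rule on a root solver amounts to, the sketch below attaches one by hand with jax.custom_vjp. It is a simplified stand-in under stated assumptions (dense tangent via jax.jacfwd, a toy residual, hypothetical names residual, newton and root) and does not reproduce the autopdex implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy residual (not from autopdex): R(u, p) = p*u + u**3 - 1.
def residual(u, p):
    return p * u + u**3 - 1.0

def newton(p, u0, n_iter=30):
    # Plain Newton iteration with a dense tangent matrix.
    u = u0
    for _ in range(n_iter):
        K = jax.jacfwd(residual, argnums=0)(u, p)
        u = u - jnp.linalg.solve(K, residual(u, p))
    return u

@jax.custom_vjp
def root(p, u0):
    return newton(p, u0)

def root_fwd(p, u0):
    u = newton(p, u0)
    return u, (u, p, u0)

def root_bwd(saved, u_bar):
    u, p, u0 = saved
    # Adjoint solve with the transposed tangent at the converged solution.
    K = jax.jacfwd(residual, argnums=0)(u, p)
    lam = jnp.linalg.solve(K.T, u_bar)
    # Pull -lam back through the residual's dependence on p.
    _, vjp_p = jax.vjp(lambda q: residual(u, q), p)
    (p_bar,) = vjp_p(-lam)
    # The initial guess does not influence the converged root.
    return p_bar, jnp.zeros_like(u0)

root.defvjp(root_fwd, root_bwd)

p = jnp.array([1.0, 2.0, 3.0])
u0 = jnp.ones(3)
grad_p = jax.grad(lambda q: jnp.sum(root(q, u0)))(p)
```

For this toy residual the result can be checked against the closed form -u / (p + 3 u**2), evaluated at the converged u. The backward pass performs a single linear solve and never differentiates through the Newton iterations.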

Example

See, for example, the implementation of autopdex.solver.adaptive_load_stepping.