# Source code for autopdex.implicit_diff

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (C) 2024 Tobias Bode
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.


"""
These functions are a modification and extension of the functions in jaxopt._src.implicit_diff for
external solve functions with a jax.experimental.sparse.BCOO matrix as an argument instead of a matvec function.
The root_vjp and root_jvp functions were modified in a way that external solvers 
can be used via a pure_callback and constraints can be taken into account.
With the wrapper custom_root, root solvers can be made differentiable in either forward or reverse mode, to arbitrary order.
Mixing the differentiation mode is currently not possible.
"""

import inspect
from typing import Any
from typing import Callable
from typing import Optional
from typing import Tuple

import jax
import jax.numpy as jnp

from autopdex import utility


def tree_scalar_mul(scalar, tree_x):
    """Scale every leaf of the pytree ``tree_x`` by ``scalar``.

    The result has the same tree structure as ``tree_x``.
    """

    def _scale_leaf(leaf):
        return scalar * leaf

    return jax.tree_util.tree_map(_scale_leaf, tree_x)

def _extract_kwargs(kwarg_keys, flat_args):
    """Split a flat argument sequence back into positional and keyword parts.

    The last ``len(kwarg_keys)`` entries of ``flat_args`` are taken to be the
    keyword values, in the order given by ``kwarg_keys``; everything before
    them is positional.

    Returns:
      ``(args, kwargs)`` where ``args`` is the leading slice of ``flat_args``
      and ``kwargs`` maps each key to its value.
    """
    split_at = len(flat_args) - len(kwarg_keys)
    positional = flat_args[:split_at]
    keyword = {key: val for key, val in zip(kwarg_keys, flat_args[split_at:])}
    return positional, keyword

def _signature_bind(signature, *args, **kwargs):
    """Bind the call arguments to ``signature`` and fill in its defaults.

    Returns:
      ``(args, kwargs)`` of the resulting binding: the positional tuple and
      the keyword dict as reported by ``inspect.BoundArguments``.
    """
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    positional, keyword = bound.args, bound.kwargs
    return positional, keyword

def _signature_bind_and_match(signature, *args, **kwargs):
    """Bind arguments to ``signature`` while remembering where each came from.

    Every argument is wrapped in a tag ``(was_kwarg, ref, value)`` before
    binding. ``ref`` is the index in ``*args`` for positional inputs and the
    key in ``**kwargs`` for keyword inputs. After binding, the tags tell us
    which original argument ended up at which resolved positional slot.

    Returns:
      ``(out_args, out_kwargs, map_back)`` where ``out_args`` / ``out_kwargs``
      are the bound values with the tags stripped, and ``map_back(seq)``
      distributes a sequence aligned with ``out_args`` back into
      ``(src_args, src_kwargs)`` following the original call layout.
    """
    n_positional = len(args)
    tagged_args = [(False, pos, val) for pos, val in enumerate(args)]
    tagged_kwargs = {key: (True, key, val) for key, val in kwargs.items()}
    bound = signature.bind(*tagged_args, **tagged_kwargs)

    # Origin of every resolved positional slot, in bound order.
    origins = [(from_kwarg, ref) for from_kwarg, ref, _ in bound.args]

    def map_back(out_args):
        src_args = [None] * n_positional
        src_kwargs = {}
        for (from_kwarg, ref), out_arg in zip(origins, out_args):
            destination = src_kwargs if from_kwarg else src_args
            destination[ref] = out_arg
        return src_args, src_kwargs

    out_kwargs = {key: val for key, (_, _, val) in bound.kwargs.items()}
    out_args = tuple(val for _, _, val in bound.args)
    return out_args, out_kwargs, map_back

def _jvp_args(residual_fun, sol, args, tangents):
    """JVP of ``residual_fun`` w.r.t. everything after its first argument.

    ``sol`` is held fixed (closed over); ``args`` are the primals and
    ``tangents`` the matching tangent directions. Only the tangent output of
    ``jax.jvp`` is returned.
    """

    def residual_at_sol(*y):
        return residual_fun(sol, *y)

    _, tangent_out = jax.jvp(residual_at_sol, args, tangents)
    return tangent_out

def _root_vjp(
    residual_fun: Callable,
    mat_fun: Callable,
    sol: Any,
    args: Tuple,
    cotangent: Any,
    solve_fun: Callable,
    free_dofs: Any,
) -> Any:
    """Vector-Jacobian product of a root.

    The invariant is ``residual_fun(sol, *args) == 0``.

    Args:
      residual_fun: the optimality function to use.
      mat_fun: a function that has to compute the sparse tangent matrix with
        sol and args as arguments; it is called as ``mat_fun(sol, *args)``.
      sol: solution / root (pytree).
      args: tuple containing the arguments with respect to which we wish to
        differentiate ``sol`` against. If ``free_dofs`` is not None,
        ``args[0]`` must be a dict providing the keys 'dirichlet dofs' and
        'dirichlet conditions'.
      cotangent: vector to left-multiply the Jacobian with
        (pytree, same structure as ``sol``).
      solve_fun: a linear solver that is called as
        ``x = solve_fun(mat, b, free_dofs_flat)``, where ``mat`` is a
        jax.experimental.sparse.BCOO matrix and ``free_dofs_flat`` is None
        when ``free_dofs`` is None.
      free_dofs: only compared against None in this function to switch the
        constrained code path on; the mask that is actually used is derived
        from ``args[0]['dirichlet dofs']``.
    Returns:
      tuple of the same length as ``len(args)`` containing the vjps w.r.t.
      each argument. Each ``vjps[i]`` has the same pytree structure as
      ``args[i]``.
    """

    free_dofs_flat = None
    if free_dofs is not None:
        # NOTE(review): the assert checks for 'dirichlet conditions' while the
        # next line reads 'dirichlet dofs'; both keys are needed below.
        assert "dirichlet conditions" in args[0], "'dirichlet conditions' \
        have to be defined in a dict as the second argument of the root solver function."
        dirichlet_dofs_flat = utility.dict_flatten(args[0]["dirichlet dofs"])
        # Free DOFs are the complement of the flattened Dirichlet mask.
        free_dofs_flat = ~dirichlet_dofs_flat

    # Wrap the (possibly external) linear solver in a custom VJP rule so that
    # the solve itself can be differentiated again in reverse mode.
    @jax.custom_vjp
    def linear_solver_fun_vjp(A, b, free_dofs_flat):
        return solve_fun(A, b, free_dofs_flat)  # Here can be an external callback

    def linear_solver_fun_vjp_fwd(A, b, free_dofs_flat):
        result = linear_solver_fun_vjp(A, b, free_dofs_flat)
        # Keep the inputs and the solution as residuals for the backward pass.
        return result, (A, b, free_dofs_flat, result)

    def linear_solver_fun_vjp_bwd(res, g):
        # res: residuals saved by linear_solver_fun_vjp_fwd.
        # g: cotangent of the solver output.
        A, b, free_dofs_flat, result = res

        # The branch is selected by the enclosing function's ``free_dofs``.
        if free_dofs is not None:

            # Adjoint solve with the transposed matrix.
            result_dot = linear_solver_fun_vjp(A.T, g, free_dofs_flat)

            # Place the solver outputs into zero vectors of length A.shape[0]
            # at the free positions (presumably what mask_op(..., "set") does
            # -- confirm against utility.mask_op).
            size = A.shape[0]
            empty = jnp.zeros((size,), dtype=float)
            result_tmp = utility.mask_op(empty, free_dofs_flat, result, "set")
            result_dot_tmp = utility.mask_op(empty, free_dofs_flat, result_dot, "set")

            # Use the sparsity structure of A to build Fx_dot: the entry at a
            # stored index (i, j) is -result_dot[i] * result[j].
            indices = jnp.asarray(A.indices, dtype=int)
            data = -result_dot_tmp[indices[:, 0]] * result_tmp[indices[:, 1]]

            # Construct Fx_dot with the same sparsity pattern
            Fx_dot = jax.experimental.sparse.BCOO((data, indices), shape=A.shape)
            Fy_dot = result_dot
        else:
            # Adjoint solve with the transposed matrix.
            result_dot = linear_solver_fun_vjp(A.T, g, None)

            # Use the sparsity structure of A to build Fx_dot: the entry at a
            # stored index (i, j) is -result_dot[i] * result[j].
            indices = jnp.asarray(A.indices, dtype=int)
            data = -result_dot[indices[:, 0]] * result[indices[:, 1]]

            # Construct Fx_dot with the same sparsity pattern
            Fx_dot = jax.experimental.sparse.BCOO((data, indices), shape=A.shape)
            Fy_dot = result_dot

        # One cotangent per primal input (A, b, free_dofs_flat); the mask gets None.
        return (Fx_dot, Fy_dot, None)

    linear_solver_fun_vjp.defvjp(linear_solver_fun_vjp_fwd, linear_solver_fun_vjp_bwd)
    diffable_solve_fun = linear_solver_fun_vjp

    mat = mat_fun(sol, *args)

    # The solution of A^T u = v, where
    # A = jacobian(residual_fun, argnums=0)
    # v = -cotangent.
    v = tree_scalar_mul(-1, cotangent)

    if free_dofs is not None:
        dirichlet_dofs = utility.reshape_as(dirichlet_dofs_flat, sol)

        v_flat = utility.dict_flatten(v)
        u_f = diffable_solve_fun(mat.T, v_flat, free_dofs_flat)
        # Write the solver output into a flat zero vector shaped like v at the
        # free positions.
        u_flat = utility.mask_op(
            utility.dict_flatten(utility.dict_zeros_like(v)), free_dofs_flat, u_f, "set"
        )

        # Residual as a function of the arguments only. The Dirichlet values
        # from args[0] are written into the solution before the residual is
        # evaluated, so the residual depends on them explicitly.
        def fun_args(*args):
            def residual_fun_tmp(sol, *args):
                dirichlet_values = args[0]["dirichlet conditions"]
                sol_with_bc = utility.mask_op(
                    sol, dirichlet_dofs, dirichlet_values, "set"
                )
                return residual_fun(sol_with_bc, *args)

            return residual_fun_tmp(sol, *args)

        u = utility.reshape_as(u_flat, v)
    else:

        # Residual as a function of the arguments only; sol is closed over.
        def fun_args(*args):
            return residual_fun(sol, *args)

        v_flat = utility.dict_flatten(v)
        u = utility.reshape_as(diffable_solve_fun(mat.T, v_flat, None), v)

    # Pull u back through the residual's dependence on the arguments.
    _, vjp_fun_args = jax.vjp(fun_args, *args)
    args_vjp = vjp_fun_args(u)

    if free_dofs is not None:
        # Add the incoming cotangent to the 'dirichlet conditions' entry of the
        # first argument's vjp at the Dirichlet positions (mask_op mode "add").
        # The dict returned by the vjp is updated in place.
        updated_args0 = args_vjp[0]
        tmpl = updated_args0["dirichlet conditions"]
        updated_args0["dirichlet conditions"] = utility.mask_op(
            updated_args0["dirichlet conditions"], utility.reshape_as(dirichlet_dofs, tmpl), utility.reshape_as(cotangent, tmpl), "add"
        )
        args_vjp = (updated_args0,) + args_vjp[1:]

    return args_vjp

def _root_jvp(
    residual_fun: Callable,
    mat_fun: Callable,
    sol: Any,
    args: Tuple,
    tangents: Tuple,
    solve_fun: Callable,
    free_dofs: Any,
) -> Any:
    """Jacobian-vector product of a root.

    The invariant is ``residual_fun(sol, *args) == 0``.

    Args:
      residual_fun: the optimality function to use.
      mat_fun: a function that has to compute the sparse tangent matrix with
        sol and args as arguments; it is called as ``mat_fun(sol, *args)``.
      sol: solution / root (pytree).
      args: tuple containing the arguments with respect to which to
        differentiate. If ``free_dofs`` is not None, ``args[0]`` must be a dict
        providing the keys 'dirichlet dofs' and 'dirichlet conditions'.
      tangents: a tuple of the same size as ``len(args)``. Each ``tangents[i]``
        has the same pytree structure as ``args[i]``.
      solve_fun: a linear solver that is called as
        ``x = solve_fun(mat, b, free_dofs_flat)``, where ``mat`` is a
        jax.experimental.sparse.BCOO matrix and ``free_dofs_flat`` is None
        when ``free_dofs`` is None.
      free_dofs: only compared against None in this function to switch the
        constrained code path on; the mask that is actually used is derived
        from ``args[0]['dirichlet dofs']``.
    Returns:
      a pytree with the same structure as ``sol``.
    """
    free_dofs_flat = None
    if free_dofs is not None:
        # NOTE(review): the assert checks for 'dirichlet conditions' while the
        # next line reads 'dirichlet dofs'; both keys are needed below.
        assert "dirichlet conditions" in args[0], "'dirichlet conditions' \
        have to be defined in a dict as the second argument of the root solver function."
        dirichlet_dofs_flat = utility.dict_flatten(args[0]["dirichlet dofs"])
        # Free DOFs are the complement of the flattened Dirichlet mask.
        free_dofs_flat = ~dirichlet_dofs_flat

    # Compute tangent matrix
    A = mat_fun(sol, *args)
    mat_shape = A.shape

    # Forward differentiable sparse linear solver. The matrix is passed as
    # separate (data, indices) arrays and reassembled inside, so that the rule
    # below receives a tangent for the matrix data.
    # TODO: register as primitive in order to allow mixed jacfwd/jacrev
    @jax.custom_jvp
    def linear_solver_fun_jvp(data, indices, b, free_dofs_flat):
        A = jax.experimental.sparse.BCOO((data, indices), shape=mat_shape)
        return solve_fun(A, b, free_dofs_flat)

    @linear_solver_fun_jvp.defjvp
    def linear_solver_fun_jvp_rule(primals, tangents):
        data, indices, b, free_dofs_flat = primals
        # Only the tangents of the matrix data and of the right-hand side are used.
        data_dot, _, b_dot, _ = tangents

        # Compute the primal result using the linear solver function
        primal_result = linear_solver_fun_jvp(data, indices, b, free_dofs_flat)

        # Tangent of the matrix, assembled on the sparsity pattern of A.
        # ToDo: is it somehow possible without A_dot via jvps?
        A_dot = jax.experimental.sparse.BCOO((data_dot, indices), shape=mat_shape)

        # Tangent of the solution: solve the same system with the right-hand
        # side b_dot - A_dot @ x.
        if free_dofs is not None:
            # NOTE(review): no mode argument is passed to mask_op here, unlike
            # the other mask_op calls in this function, which pass "set";
            # presumably this relies on mask_op's default -- confirm.
            primal_result_tmp = utility.mask_op(
                jnp.zeros((mat_shape[0],), dtype=float),
                free_dofs_flat,
                primal_result,
            )
            rhs = b_dot - (A_dot @ primal_result_tmp)
            result_dot = linear_solver_fun_jvp(data, indices, rhs, free_dofs_flat)
        else:
            result_dot = linear_solver_fun_jvp(
                data, indices, b_dot - A_dot @ primal_result, None
            )

        return primal_result, result_dot

    # Assign the jvp-enabled solver function
    solve_func = linear_solver_fun_jvp

    if free_dofs is not None:
        dirichlet_dofs = utility.reshape_as(dirichlet_dofs_flat, sol)

        # Explicit imposition of DOFs in order to be able to take derivatives w.r.t. nodally imposed DOFs
        def residual_fun_tmp(sol, *args):
            sol_with_bc = utility.mask_op(
                sol, dirichlet_dofs, args[0]["dirichlet conditions"], "set"
            )
            return residual_fun(sol_with_bc, *args)

        # Bv: tangent of the residual w.r.t. the arguments at fixed sol.
        Bv = _jvp_args(residual_fun_tmp, sol, args, tangents)
        Bv_free = utility.dict_flatten(Bv)
        Jv_free = solve_func(A.data, A.indices, -Bv_free, free_dofs_flat)

        # Write the solver output into a zero vector shaped like sol at the
        # free positions.
        empty_flat = utility.dict_flatten(utility.dict_zeros_like(sol))
        Jv = utility.reshape_as(
            utility.mask_op(empty_flat, free_dofs_flat, Jv_free, "set"), sol
        )

        # At the Dirichlet positions the tangent of the solution is set to the
        # tangent of the prescribed values.
        Jv = utility.mask_op(
            Jv, dirichlet_dofs, tangents[0]["dirichlet conditions"], "set"
        )

    else:
        # Bv: tangent of the residual w.r.t. the arguments at fixed sol.
        Bv = _jvp_args(residual_fun, sol, args, tangents)
        Jv = utility.reshape_as(
            solve_func(A.data, A.indices, -utility.dict_flatten(Bv), None), Bv
        )

    return Jv

def _custom_root(
    solver_fun,
    residual_fun,
    mat_fun,
    free_dofs,
    solve,
    has_aux,
    mode="reverse",
    reference_signature=None,
):
    """Attach an implicit-differentiation rule to ``solver_fun``.

    Args:
      solver_fun: the root solver to wrap. Its first argument is treated as
        the initial guess: it is stripped before the remaining arguments are
        handed to ``_root_jvp`` / ``_root_vjp`` and receives ``None`` as its
        vjp in reverse mode.
      residual_fun: residual with the invariant
        ``residual_fun(sol, *args) == 0``.
      mat_fun: callable returning the sparse tangent matrix; forwarded to
        ``_root_jvp`` / ``_root_vjp``.
      free_dofs: forwarded unchanged to ``_root_jvp`` / ``_root_vjp``.
      solve: linear solver, forwarded as ``solve_fun`` to ``_root_jvp`` /
        ``_root_vjp``.
      has_aux: if True, ``solver_fun`` returns a sequence whose first entry is
        the solution and whose remaining entries are auxiliary data.
      mode: 'forward' installs a ``jax.custom_jvp`` rule; 'reverse' or
        'backward' installs a ``jax.custom_vjp`` rule.
      reference_signature: an ``inspect.Signature``, a callable, or an object
        exposing the callable as ``subfun``. Used in reverse mode to resolve
        keyword arguments to positions. Defaults to the signature of
        ``residual_fun``.

    Returns:
      A function with the calling convention of ``solver_fun`` that carries
      the custom derivative rule.

    Raises:
      ValueError: when the returned function is called and ``mode`` is not
        one of 'forward', 'reverse' or 'backward'.
      TypeError: in the reverse-mode rule, when keyword arguments cannot be
        resolved to positional arguments of ``reference_signature``.
    """
    # When calling through `jax.custom_vjp`, jax attempts to resolve all
    # arguments passed by keyword to positions (this is in order to
    # match against a `nondiff_argnums` parameter that we do not use
    # here). It does so by resolving them according to the custom_vjp'ed
    # function's signature. It disallows functions defined with a
    # catch-all `**kwargs` expression, since their signature cannot
    # always resolve all keyword arguments to positions.
    #
    # We can loosen the constraint on the signature of `solver_fun` so
    # long as we resolve keywords to positions ourselves. We can do so
    # just in time, by flattening the `kwargs` dict (respecting its
    # iteration order) and supplying `custom_vjp` / `custom_jvp` with a
    # positional-argument-only function. We then explicitly coordinate
    # flattening and unflattening around the custom-rule boundary.
    #
    # Once we make it past that boundary, we do some more work to align
    # arguments with the reference signature (which is, by default, the
    # signature of `residual_fun`).

    solver_fun_signature = inspect.signature(solver_fun)

    if reference_signature is None:
        reference_signature = inspect.signature(residual_fun)
    elif not isinstance(reference_signature, inspect.Signature):
        # If is a CompositeLinearFunction, accesses subfun.
        # Otherwise, assumes a Callable.
        fun = getattr(reference_signature, "subfun", reference_signature)
        reference_signature = inspect.signature(fun)

    reverse_mode = mode == "reverse" or mode == "backward"

    def make_custom_solver_fun(solver_fun, kwarg_keys):

        def solver_fun_flat_tmp(*flat_args):
            # The trailing len(kwarg_keys) entries are the keyword values.
            args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
            return solver_fun(*args, **kwargs)

        if reverse_mode:
            solver_fun_flat = jax.custom_vjp(solver_fun_flat_tmp)
        elif mode == "forward":
            solver_fun_flat = jax.custom_jvp(solver_fun_flat_tmp)
        else:
            raise ValueError("Mode must be either 'forward' or 'reverse'.")

        # Forward-mode differentiation (JVP)
        def solver_fun_jvp(primals, tangents):
            args, _ = _extract_kwargs(kwarg_keys, primals)
            tangent_args, _ = _extract_kwargs(kwarg_keys, tangents)

            # Compute the primal solution using the root solver function.
            # The flat primals are passed through unchanged: solver_fun_flat
            # only takes positional arguments and splits the trailing keyword
            # values off itself. Re-expanding them as **kwargs cannot be
            # resolved against its ``*flat_args`` signature and would fail as
            # soon as any keyword value is present.
            primal_sol = solver_fun_flat(*primals)

            # Handle has_aux case
            if has_aux:
                sol = primal_sol[0]
                aux_data = primal_sol[1:]

                # TODO: allow integer and boolean auxiliary data
            else:
                sol = primal_sol

            # Compute JVP using root_jvp
            jvp_sol = _root_jvp(
                residual_fun=residual_fun,
                mat_fun=mat_fun,
                sol=sol,
                args=args[1:],  # Exclude the initial params from args
                tangents=tangent_args[1:],  # Exclude the initial params from tangents
                solve_fun=solve,
                free_dofs=free_dofs,
            )

            if has_aux:
                # Return primal and tangent for both solution and auxiliary
                # data; the auxiliary data gets a zero tangent.
                aux_tangent = jax.tree_util.tree_map(jnp.zeros_like, aux_data)
                return (sol,) + aux_data, (jvp_sol,) + aux_tangent
            return primal_sol, jvp_sol

        # Reverse-mode differentiation (VJP)
        def solver_fun_fwd(*flat_args):
            res = solver_fun_flat(*flat_args)
            # Save the result and the flat inputs for the backward pass.
            return res, (res, flat_args)

        def solver_fun_bwd(tup, cotangent):
            res, flat_args = tup
            args, kwargs = _extract_kwargs(kwarg_keys, flat_args)

            # solver_fun can return auxiliary data if has_aux = True.
            if has_aux:
                cotangent = cotangent[0]
                sol = res[0]
            else:
                sol = res

            ba_args, ba_kwargs, map_back = _signature_bind_and_match(
                reference_signature, *args, **kwargs
            )
            if ba_kwargs:
                raise TypeError(
                    "keyword arguments to solver_fun could not be resolved to "
                    "positional arguments based on the signature "
                    f"{reference_signature}. This can happen under custom_root if "
                    "residual_fun takes catch-all **kwargs, or under "
                    "custom_fixed_point if fixed_point_fun takes catch-all **kwargs, "
                    "both of which are currently unsupported."
                )

            # Compute VJPs w.r.t. args.
            vjps = _root_vjp(
                residual_fun=residual_fun,
                mat_fun=mat_fun,
                sol=sol,
                args=ba_args[1:],
                cotangent=cotangent,
                solve_fun=solve,
                free_dofs=free_dofs,
            )
            # Prepend None as the vjp for init_params.
            vjps = (None,) + vjps

            # Restore the flat (positional, then keyword-in-order) layout that
            # the custom_vjp'ed function was called with.
            arg_vjps, kws_vjps = map_back(vjps)
            ordered_vjps = tuple(arg_vjps) + tuple(kws_vjps[k] for k in kwargs.keys())
            return ordered_vjps

        if reverse_mode:
            solver_fun_flat.defvjp(solver_fun_fwd, solver_fun_bwd)
        else:
            solver_fun_flat.defjvp(solver_fun_jvp)
        return solver_fun_flat

    def wrapped_solver_fun(*args, **kwargs):
        # Resolve defaults, then flatten the keyword values behind the
        # positional ones so that the custom rule only sees positional inputs.
        args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
        keys, vals = list(kwargs.keys()), list(kwargs.values())
        return make_custom_solver_fun(solver_fun, keys)(*args, *vals)

    return wrapped_solver_fun

def custom_root(
    residual_fun: Callable,
    mat_fun: Callable,
    solve: Callable,
    free_dofs=None,
    has_aux: bool = False,
    mode="reverse",
    reference_signature: Optional[Callable] = None,
):
    """Decorator for adding implicit differentiation to a root solver.

    Args:
      residual_fun: A callable that returns the possibly nonlinear residual of
        which to find the root, ``residual_fun(dofs, *args)``.
        The invariant is ``residual_fun(sol, *args) == 0`` at the
        solution / root ``sol``.
      mat_fun: A callable that returns the sparse tangent matrix as a
        jax.experimental.sparse.BCOO with dofs and args as arguments. Can also
        be a pure callback.
      solve: A linear solver. It is invoked by the derivative rules as
        ``solve(mat, b, free_dofs_flat)`` with ``mat`` a
        jax.experimental.sparse.BCOO and ``b`` a jnp.ndarray; the third
        argument is ``None`` when ``free_dofs`` is None.
      free_dofs: For constraining certain degrees of freedom. In case free_dofs
        is not None, the second argument of the solver has to be a dictionary
        having the keys 'dirichlet dofs' and 'dirichlet conditions'. The first
        one is a dictionary of jnp.ndarrays with the same structure as dofs,
        where the entries are boolean masks indicating the dofs that are
        constrained. The second one is a dictionary of jnp.ndarrays with the
        same structure as dofs, where the entries are the values of the
        constrained dofs (see source code of _root_jvp and _root_vjp for
        details or solver.adaptive_load_stepping for exemplary use).
      has_aux: whether the decorated root solver function returns auxiliary
        data.
      mode: The differentiation mode ('forward' or 'reverse'/'backward').
      reference_signature: optional function signature
        (i.e. arguments and keyword arguments), with which the
        solver and optimality functions are expected to agree. Defaults
        to ``residual_fun``. It is required that solver and optimality
        functions share the same input signature, but both might be
        defined in such a way that the signature correspondence is
        ambiguous (e.g. if both accept catch-all ``**kwargs``). To
        satisfy custom_root's requirement, any function with an
        unambiguous signature can be provided here.

    Returns:
      The decorated root solver function that is equipped with a custom vjp
      or jvp rule.

    Example:
      See e.g. the implementation of autopdex.solver.adaptive_load_stepping.
    """

    def wrapper(solver_fun):
        return _custom_root(
            solver_fun,
            residual_fun,
            mat_fun,
            free_dofs,
            solve,
            has_aux,
            mode,
            reference_signature,
        )

    return wrapper