autopdex.utility.jacfwd_upto_n_one_vector_arg

autopdex.utility.jacfwd_upto_n_one_vector_arg(fun, x, n)

Computes the value of a function fun and its derivatives up to order n with respect to a vector argument x in a single pass.

This function is tailored for sensitivity analysis, where the primal evaluation can be very expensive. By using recursive Jacobian-vector products (JVPs), this function computes all derivatives in one pass, avoiding redundant computations.
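The single-pass idea can be pictured with the minimal sketch below. It is not autopdex's implementation. It nests jax.jacfwd with has_aux=True, so each differentiation level passes the lower-order results it has already computed through as auxiliary outputs. The name derivs_upto_n_sketch is hypothetical, and the final jnp.moveaxis step is an assumption made only to mimic the layout in the documented example, where derivative axes come first.

```python
import jax
import jax.numpy as jnp


def derivs_upto_n_sketch(fun, x, n):
    """Hypothetical helper: returns (fun(x), first derivative, ..., nth derivative)."""

    def level0(x):
        y = fun(x)
        # First output is differentiated by the next level; the tuple is carried along.
        return y, (y,)

    def lift(g):
        def next_level(x):
            # jacfwd differentiates g's first output and returns its aux tuple unchanged.
            jac, collected = jax.jacfwd(g, has_aux=True)(x)
            return jac, collected + (jac,)
        return next_level

    g = level0
    for _ in range(n):
        g = lift(g)
    _, collected = g(x)

    # jax.jacfwd appends derivative axes last. Assumption: move them to the
    # front so the k-th entry has shape (len(x),) * k + output_shape.
    reordered = []
    for k, d in enumerate(collected):
        if k > 0:
            d = jnp.moveaxis(d, list(range(d.ndim - k, d.ndim)), list(range(k)))
        reordered.append(d)
    return tuple(reordered)


def fun(x):  # input shape (2,), output shape (3,)
    return jnp.stack([x[0] ** 3, x[0] * x[1], x[1]])


derivs = derivs_upto_n_sketch(fun, jnp.asarray([1.0, 2.0]), 3)
shapes = [d.shape for d in derivs]  # (3,), (2, 3), (2, 2, 3), (2, 2, 2, 3)
```

Carrying the lower orders as auxiliary outputs means fun is traced once per call rather than once per derivative order, which is the property the description above attributes to the library function.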

The function supports computing derivatives of any order with respect to a single jnp.ndarray argument x. The derivatives are returned as a tuple, where each entry corresponds to the function’s value or a derivative of increasing order.

Parameters:
  • fun (Callable) – The function for which derivatives are to be computed. This function should accept a single jnp.ndarray argument x and return a scalar or a jnp.ndarray.

  • x (jnp.ndarray) – The vector input at which the derivatives are evaluated.

  • n (int) – The highest derivative order to compute. For instance, n=3 computes the value and the first, second, and third derivatives.

Returns:

A tuple containing the function value and its derivatives up to the nth order. The structure of the tuple is as follows:

  • The first entry is the value of the function, fun(x).

  • Subsequent entries are the first, second, and higher-order derivatives.

  • For a function returning a multi-dimensional array, the array dimensions are appended to the derivative dimensions, so the derivative axes come first and the output axes last (see example below).

Return type:

tuple

Example

>>> import jax.numpy as jnp
>>> from autopdex.utility import jacfwd_upto_n_one_vector_arg
>>>
>>> def example_fun(x):  # input shape (2,)
...     return jnp.asarray([
...         [x[0], x[1], x[0] + x[1], x[0] * x[1]],
...         [x[1] - x[0], x[0] / (x[1] + 1e-5), x[0]**2, x[1]**2],
...         [x[0] + 1, x[1] + 1, x[0] * x[0], x[1] * x[1]]
...     ])  # output shape (3, 4)
>>>
>>> x = jnp.asarray([1.0, 2.0])  # Example vector input
>>> n = 3  # Up to the third derivative
>>> derivatives = jacfwd_upto_n_one_vector_arg(example_fun, x, n)
>>>
>>> print("Function value:", derivatives[0].shape) # shape (3, 4)
>>> print("First derivative:", derivatives[1].shape)  # shape (2, 3, 4)
>>> print("Second derivative:", derivatives[2].shape) # shape (2, 2, 3, 4)
>>> print("Third derivative:", derivatives[3].shape) # shape (2, 2, 2, 3, 4)
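Since the function is described as tailored for sensitivity analysis, one plausible use of the returned tuple is a Taylor estimate of fun(x + dx) that needs no further evaluation of fun. The sketch below assumes the layout shown in the example, with derivative axes first. taylor_from_derivatives is a hypothetical helper, and the derivatives tuple is built with nested jax.jacfwd calls as a stand-in for the library's return value so that the block runs without autopdex.

```python
import math

import jax
import jax.numpy as jnp


def taylor_from_derivatives(derivatives, dx):
    """Hypothetical helper: sum over k of D^k f(x)[dx, ..., dx] / k!."""
    total = jnp.zeros_like(derivatives[0])
    for k, d in enumerate(derivatives):
        term = d
        for _ in range(k):
            # Contract the leading derivative axis with the perturbation dx.
            term = jnp.tensordot(dx, term, axes=1)
        total = total + term / math.factorial(k)
    return total


def fun(x):  # cubic polynomial, so a third-order expansion is exact
    return jnp.stack([x[0] ** 3, x[0] * x[1], x[1]])


x = jnp.asarray([1.0, 2.0])
dx = jnp.asarray([0.1, -0.2])

# Stand-in for the documented return value (assumption): nested jacfwd,
# with the derivative axes moved from the back to the front.
d1 = jnp.moveaxis(jax.jacfwd(fun)(x), -1, 0)
d2 = jnp.moveaxis(jax.jacfwd(jax.jacfwd(fun))(x), (-2, -1), (0, 1))
d3 = jnp.moveaxis(jax.jacfwd(jax.jacfwd(jax.jacfwd(fun)))(x), (-3, -2, -1), (0, 1, 2))
derivatives = (fun(x), d1, d2, d3)

approx = taylor_from_derivatives(derivatives, dx)
exact = fun(x + dx)
```

Because the derivative axes lead, each contraction with dx removes the first axis, and the remaining array always ends in the output shape of fun.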