autopdex.utility.jacfwd_upto_n_one_vector_arg
- autopdex.utility.jacfwd_upto_n_one_vector_arg(fun, x, n)
Computes the value of fun and all of its derivatives up to order n with respect to a vector argument x in a single pass.
This function is tailored to sensitivity analysis, where evaluating the primal function can be very expensive. It uses recursive Jacobian-vector products (JVPs), so the value and all derivatives up to order n are obtained from one nested evaluation. Computing each order with a separate call would re-evaluate the primal every time.
Derivatives of any order can be computed with respect to a single jnp.ndarray argument x. The results are returned as a tuple whose k-th entry (counting from 0) is the k-th derivative, so entry 0 is the function value itself.
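The recursion described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the autopdex implementation: the name derivs_upto_n is hypothetical, x is assumed to be a 1-D floating-point array of shape (m,), and each order is obtained by vmapping jax.jvp over the standard basis vectors (the pattern jax.jacfwd itself is built on), with the new derivative axis placed first.

```python
import jax
import jax.numpy as jnp


def derivs_upto_n(fun, x, n):
    """Hypothetical sketch: return (fun(x), D fun(x), ..., D^n fun(x)).

    Assumes x is a 1-D floating-point array of shape (m,). The k-th entry has
    shape (m,) * k + fun(x).shape, i.e. derivative axes come first.
    """

    def up_to(k):
        # Build a function y -> (derivatives of orders 0..k) evaluated at y.
        if k == 0:
            return lambda y: (fun(y),)
        lower = up_to(k - 1)

        def g(y):
            # One tangent direction per input component (standard basis).
            basis = jnp.eye(y.shape[0], dtype=y.dtype)

            def push(v):
                # jvp returns the lower-order values (primals) together with
                # their directional derivatives along v (tangents).
                return jax.jvp(lower, (y,), (v,))

            # Primals do not depend on v, so they stay unbatched (out_axes=None);
            # tangents get the new derivative axis in front (out_axes=0).
            primals, tangents = jax.vmap(push, out_axes=(None, 0))(basis)
            # Only the highest-order tangent is new; the others repeat primals[1:].
            return primals + (tangents[-1],)

        return g

    return up_to(n)(x)


x = jnp.asarray([1.0, 2.0])
values = derivs_upto_n(lambda x: x[0] ** 3 * x[1], x, 3)
print([v.shape for v in values])  # [(), (2,), (2, 2), (2, 2, 2)]
```

In this sketch the purely primal evaluation of fun happens once, inside the innermost jvp, and higher orders are propagated as tangents; the trade-off is that each level also re-propagates the lower-order tangents.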
- Parameters:
fun (Callable) – The function to differentiate. It must accept a single jnp.ndarray argument x and return a scalar or a jnp.ndarray.
x (jnp.ndarray) – The vector input at which the derivatives are evaluated.
n (int) – The highest derivative order to compute. For instance, n=3 returns the value and the first, second, and third derivatives.
- Returns:
A tuple of length n + 1 containing the function value and its derivatives up to order n:
  - The first entry is the function value fun(x).
  - Entry k (for k = 1, ..., n) is the k-th derivative.
  - For x of shape (m,), the k-th derivative has shape (m,)*k + fun(x).shape: the derivative axes come first and the output dimensions of fun are appended after them (see the example below).
- Return type:
tuple
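For comparison, the redundant approach that a single pass avoids, together with the axis convention of the returned arrays, can be illustrated with repeated jax.jacfwd calls. The helper name nested_jacfwd_derivs is hypothetical and x is assumed to have shape (m,). Because jax.jacfwd places the input axis after the output axes, the sketch moves the k derivative axes to the front to obtain shape (m,)*k + fun(x).shape.

```python
import jax
import jax.numpy as jnp


def nested_jacfwd_derivs(fun, x, n):
    """Hypothetical baseline: D^k fun(x) for k = 0..n via nested jax.jacfwd.

    Assumes x has shape (m,). jax.jacfwd appends the input axis after the
    output axes, so the k derivative axes are moved to the front to get
    shape (m,) * k + fun(x).shape.
    """
    results = [fun(x)]
    g = fun
    for k in range(1, n + 1):
        g = jax.jacfwd(g)  # one more level of forward-mode differentiation
        d = g(x)  # separate evaluation: fun's primal is recomputed here
        out_ndim = d.ndim - k
        # Move the trailing k derivative axes to the front.
        d = jnp.moveaxis(d, tuple(range(out_ndim, d.ndim)), tuple(range(k)))
        results.append(d)
    return tuple(results)


x = jnp.asarray([1.0, 2.0])
vals = nested_jacfwd_derivs(lambda x: jnp.stack([x, 2 * x, x ** 2]), x, 2)
print([v.shape for v in vals])  # [(3, 2), (2, 3, 2), (2, 2, 3, 2)]
```

Each loop iteration starts again from fun, so the order-k call repeats all the work of the lower orders.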
Example
>>> import jax.numpy as jnp
>>> from autopdex.utility import jacfwd_upto_n_one_vector_arg
>>> def example_fun(x):  # input shape (2,)
...     return jnp.asarray([
...         [x[0], x[1], x[0] + x[1], x[0] * x[1]],
...         [x[1] - x[0], x[0] / (x[1] + 1e-5), x[0]**2, x[1]**2],
...         [x[0] + 1, x[1] + 1, x[0] * x[0], x[1] * x[1]],
...     ])  # output shape (3, 4)
>>> x = jnp.asarray([1.0, 2.0])  # example vector input
>>> n = 3  # up to the third derivative
>>> derivatives = jacfwd_upto_n_one_vector_arg(example_fun, x, n)
>>> print("Function value:", derivatives[0].shape)  # shape (3, 4)
>>> print("First derivative:", derivatives[1].shape)  # shape (2, 3, 4)
>>> print("Second derivative:", derivatives[2].shape)  # shape (2, 2, 3, 4)
>>> print("Third derivative:", derivatives[3].shape)  # shape (2, 2, 2, 3, 4)