autopdex.utility.jacfwd_upto_n_scalar_args
- autopdex.utility.jacfwd_upto_n_scalar_args(fun, args, derivative_order, argnum)
Computes the value and all derivatives of a function up to order derivative_order in a single pass.
The function is designed for sensitivity analysis, where the primal evaluation can be very expensive. By using recursive jvps, it computes all derivatives in one pass and does not rely on XLA to remove duplicated code.
It supports derivatives of any order with respect to one or several scalar arguments. If argnum is an int, the derivatives are returned for that single argument; if argnum is a tuple, the results for each listed argument are stacked along a leading axis.
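The recursive-jvp idea can be sketched in plain JAX for a single scalar argument. This is a minimal illustration under that assumption, not the autopdex implementation, and the helper name derivs_upto_n is hypothetical. Each level wraps the previous one in jax.jvp with a unit tangent, so the tangent of the last entry supplies the next derivative order and the primal is traced only once.

```python
import jax
import jax.numpy as jnp


def derivs_upto_n(f, n):
    """Hypothetical helper: x -> stacked [f(x), f'(x), ..., f^(n)(x)].

    Assumes x is a scalar; f may return a scalar or a fixed-shape array.
    """
    def level0(x):
        # Order 0: just the function value, wrapped in a tuple.
        return (f(x),)

    h = level0
    for _ in range(n):
        def next_level(x, prev=h):  # default arg captures the current level
            # primals: (f, ..., f^(k-1)); tangents: (f', ..., f^(k))
            primals, tangents = jax.jvp(prev, (x,), (jnp.ones_like(x),))
            # Append only the highest new derivative to the running tuple.
            return primals + (tangents[-1],)
        h = next_level

    # Stack into shape (1 + n, *output_shape_of_f).
    return lambda x: jnp.stack(h(x))


d = derivs_upto_n(lambda x: x**3, 4)(2.0)  # value and derivatives 1..4 of x**3
```

Because nested jvps propagate nested tangents rather than truncated Taylor coefficients, the traced program carries some redundant tangent components and grows quickly with n; jax.experimental.jet offers Taylor-mode propagation if that becomes a bottleneck.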
- Parameters:
fun (Callable) – The function whose derivatives are computed. It should take scalar positional arguments and return a scalar or a jnp.ndarray.
args (tuple) – A tuple of input values at which the derivatives are evaluated. The length of args should match the number of arguments expected by fun.
derivative_order (int) – The highest derivative order to compute. For instance, derivative_order=2 returns the value together with the first and second derivatives.
argnum (Union[int, tuple]) – The index or indices of the argument(s) with respect to which the derivative is computed. If an integer is provided, the derivative is computed with respect to a single argument. If a tuple of integers is provided, the derivative is computed with respect to multiple arguments.
- Returns:
The result is a jnp.ndarray indexed as results[i][j][k], where:
i (optional) selects the argument in the argnum tuple. This axis is absent if argnum is an int.
j is the derivative order, from 0 (the function value) up to derivative_order.
k and any further indices follow the output shape of fun.
- Return type:
jnp.ndarray
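To make the results[i][j][k] layout concrete, the following hypothetical reference function (naive_upto_n, not part of autopdex) builds an array with the same layout from repeated jax.jacfwd calls. It evaluates a separately differentiated function for every order, which is exactly the duplicated work the single-pass approach avoids, but it is simple enough to serve as a cross-check, assuming fun returns a fixed-shape array.

```python
import jax
import jax.numpy as jnp


def naive_upto_n(fun, args, derivative_order, argnum):
    """Hypothetical reference with the layout described under Returns."""
    def per_arg(a):
        # Restrict fun to its a-th argument, keeping the others fixed.
        def g(xa):
            full = list(args)
            full[a] = xa
            return jnp.asarray(fun(*full))

        out = []
        h = g
        for _ in range(derivative_order + 1):
            out.append(h(args[a]))  # j-th derivative evaluated at args[a]
            h = jax.jacfwd(h)       # differentiate once more for the next order
        return jnp.stack(out)       # shape (1 + derivative_order, *output_shape)

    if isinstance(argnum, int):
        return per_arg(argnum)      # no leading argument axis
    return jnp.stack([per_arg(a) for a in argnum])  # (len(argnum), 1 + order, ...)
```

For the example below, comparing the two results with jnp.allclose should show agreement up to floating-point rounding.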
Example
>>> import jax.numpy as jnp
>>> from autopdex.utility import jacfwd_upto_n_scalar_args
>>> def example_function(x, y):
...     return jnp.asarray([x**3, 2*x*y**3, x*y])
>>> args = (2.0, 3.0)  # Example input values for x and y
>>> order = 4  # Up to the fourth derivative
>>> argnum = (0, 1)  # Differentiate with respect to both x and y
>>> result = jacfwd_upto_n_scalar_args(example_function, args, order, argnum)
>>> print("Derivatives:", result)
>>> print("Shape:", result.shape)  # (2, 5, 3), i.e. (len(argnum), 1 + order, output size of example_function)
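For reference, the entries of result in the example can be worked out by hand. Row j holds the j-th pure derivative of each output component with respect to the selected argument; mixed derivatives such as the one with respect to x and y together are not part of the (2, 5, 3) array. Assuming the layout described under Returns, the printed values should match these up to floating-point rounding.

```latex
\begin{aligned}
f(x,y) &= \begin{bmatrix} x^3 & 2xy^3 & xy \end{bmatrix}^{\top}, \qquad (x,y) = (2,3),\\[4pt]
\text{result}[0] = \left[\frac{\partial^j f}{\partial x^j}\right]_{j=0}^{4}
&= \begin{bmatrix} 8 & 108 & 6\\ 12 & 54 & 3\\ 12 & 0 & 0\\ 6 & 0 & 0\\ 0 & 0 & 0 \end{bmatrix},\\[4pt]
\text{result}[1] = \left[\frac{\partial^j f}{\partial y^j}\right]_{j=0}^{4}
&= \begin{bmatrix} 8 & 108 & 6\\ 0 & 108 & 2\\ 0 & 72 & 0\\ 0 & 24 & 0\\ 0 & 0 & 0 \end{bmatrix}.
\end{aligned}
```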