Utility

This module contains some useful functions, including:

  • Wrapper around a nearest-neighbor search (e.g. for moving least squares)

  • Degree of freedom (DOF) selection for boundary conditions

  • Computation of the condition number and symmetry check of the tangent matrix

  • Functions for manipulating arrays

jit_with_docstring([static_argnames, ...])

JIT wrapper that preserves the original docstring of the function and additionally treats arguments from possibly_static_argnames as static if their value is callable.
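A minimal sketch of how such a wrapper could behave. It assumes that `static_argnames` and `possibly_static_argnames` are keyword parameters of a decorator factory (the real signature is elided above) and that static status is decided per call. `jit_with_docstring_sketch` and `apply_twice` are hypothetical names, not the module's API. `functools.wraps` copies the docstring, and one jitted version is cached per set of static argument names.

```python
import functools
import inspect

import jax
import jax.numpy as jnp


def jit_with_docstring_sketch(static_argnames=(), possibly_static_argnames=()):
    """Hypothetical decorator factory mimicking jit_with_docstring."""

    def decorator(fun):
        signature = inspect.signature(fun)
        compiled = {}  # one jitted version per tuple of static argument names

        @functools.wraps(fun)  # copies __doc__, __name__, etc. from fun
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            # "Possibly static" arguments become static only if their value is callable
            callable_static = tuple(
                name
                for name in possibly_static_argnames
                if name in bound.arguments and callable(bound.arguments[name])
            )
            names = tuple(static_argnames) + callable_static
            if names not in compiled:
                compiled[names] = jax.jit(fun, static_argnames=names)
            return compiled[names](*args, **kwargs)

        return wrapper

    return decorator


@jit_with_docstring_sketch(possibly_static_argnames=("f",))
def apply_twice(x, f):
    """Apply f two times to x."""
    return f(f(x))
```

Treating callables as static matters because `jax.jit` cannot trace a Python function passed as a regular argument, while arrays passed in the same slot should stay traced to avoid recompilation.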

dict_zeros_like(arr, **keyargs)

Wrapper around zeros_like that also works for dicts with jnp.ndarray entries.
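A minimal sketch of the idea, assuming keyword arguments such as `dtype` are simply forwarded to `jnp.zeros_like`. `dict_zeros_like_sketch` is a hypothetical name. It recurses into (possibly nested) dicts and keeps every shape.

```python
import jax.numpy as jnp


def dict_zeros_like_sketch(arr, **keyargs):
    """Zeros with the same structure and shapes as arr (sketch)."""
    if isinstance(arr, dict):
        # Recurse so that nested dicts are handled as well
        return {key: dict_zeros_like_sketch(value, **keyargs) for key, value in arr.items()}
    return jnp.zeros_like(arr, **keyargs)


fields = {"u": jnp.ones((4, 2)), "p": jnp.arange(4.0)}
zeros = dict_zeros_like_sketch(fields)
```

For arbitrary pytrees, `jax.tree_util.tree_map(jnp.zeros_like, fields)` gives the same structure without a hand-written recursion.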

dict_flatten(arr)

Recursively flattens a nested dict of arrays (np.ndarray or jnp.ndarray) to one flat array.

reshape_as(flat_array, signature_array)

Reshapes a flat array (np.ndarray or jnp.ndarray) into an array or dict of arrays matching the structure of signature_array.
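The flatten/reshape pair above can be sketched as a round trip. This is an illustration under the assumption that dict entries are concatenated in insertion order. `dict_flatten_sketch` and `reshape_as_sketch` are hypothetical names, and the module's actual ordering convention may differ.

```python
import jax.numpy as jnp


def dict_flatten_sketch(arr):
    """Flatten a (nested) dict of arrays into one 1D array, in insertion order."""
    if isinstance(arr, dict):
        # Assumes a non-empty dict: jnp.concatenate needs at least one array
        return jnp.concatenate([dict_flatten_sketch(v) for v in arr.values()])
    return jnp.ravel(jnp.asarray(arr))


def reshape_as_sketch(flat_array, signature_array):
    """Inverse of dict_flatten_sketch: rebuild the structure of signature_array."""
    if isinstance(signature_array, dict):
        result, offset = {}, 0
        for key, signature in signature_array.items():
            size = dict_flatten_sketch(signature).size
            result[key] = reshape_as_sketch(flat_array[offset:offset + size], signature)
            offset += size
        return result
    return jnp.reshape(flat_array, jnp.shape(signature_array))
```

JAX also ships `jax.flatten_util.ravel_pytree`, which returns a flat vector together with an unravel function for any pytree.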

mask_set(array, selection, values)

Changes values in a JAX array or dict of arrays based on a boolean mask.
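Because JAX arrays are immutable, "changing" values means returning an updated copy. A minimal sketch using `jnp.where`, assuming that for dicts the mask (and `values`, if it is a dict) has the same keys as `array`. `mask_set_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def mask_set_sketch(array, selection, values):
    """Copy of array with the entries where selection is True replaced by values."""
    if isinstance(array, dict):
        # Assumption: selection (and values, if a dict) share the keys of array
        return {
            key: mask_set_sketch(
                array[key],
                selection[key],
                values[key] if isinstance(values, dict) else values,
            )
            for key in array
        }
    # jnp.where keeps all shapes static, so this also works under jax.jit
    return jnp.where(selection, values, array)


a = jnp.arange(5.0)
out = jax.jit(mask_set_sketch)(a, a > 2.0, 0.0)
```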

mask_select(array, selection)

Selects elements from a JAX array or dict of arrays based on a boolean mask.
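A minimal sketch using boolean indexing, assuming dict masks share the keys of the data dict. `mask_select_sketch` is a hypothetical name. The result length depends on the mask values, so this form works eagerly but not inside `jax.jit`.

```python
import jax.numpy as jnp


def mask_select_sketch(array, selection):
    """Entries of array (or of each dict entry) where selection is True."""
    if isinstance(array, dict):
        # Assumption: selection is a dict with the same keys as array
        return {key: mask_select_sketch(array[key], selection[key]) for key in array}
    # Boolean indexing: output shape is data-dependent, so no jax.jit here
    return array[selection]


a = jnp.arange(5.0)
selected = mask_select_sketch(a, a > 2.0)
```

Inside a jitted function, `jnp.nonzero(selection, size=..., fill_value=...)` is a common way to obtain fixed-size index arrays instead.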

mask_op(array, selection[, values, mode, ufunc])

Performs an operation on a JAX array or dict of arrays based on a boolean mask.
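A minimal sketch of a masked operation. The mode names `"set"`, `"add"`, `"multiply"` and `"apply"` are illustrative assumptions, not the documented values of `mode`, and `mask_op_sketch` is a hypothetical name. Each mode computes a candidate array and `jnp.where` keeps it only at the masked entries.

```python
import jax.numpy as jnp


def mask_op_sketch(array, selection, values=None, mode="set", ufunc=None):
    """Apply a masked update; the mode names are illustrative assumptions."""
    if isinstance(array, dict):
        # Assumption: selection (and values, if a dict) share the keys of array
        return {
            key: mask_op_sketch(
                array[key],
                selection[key],
                values[key] if isinstance(values, dict) else values,
                mode,
                ufunc,
            )
            for key in array
        }
    if mode == "set":
        updated = values
    elif mode == "add":
        updated = array + values
    elif mode == "multiply":
        updated = array * values
    elif mode == "apply":
        updated = ufunc(array)  # e.g. jnp.square or jnp.abs
    else:
        raise ValueError(f"unknown mode: {mode}")
    # Only the masked entries take the updated value
    return jnp.where(selection, updated, array)
```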

search_neighborhood(x_nodes, x_query, ...)

Neighbor search within a radius based on SciPy's KDTree.
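A minimal sketch of a radius search with `scipy.spatial.KDTree.query_ball_point`. It assumes the caller wants a rectangular index array (padded with `-1`) plus a neighbor count per query point, which is convenient for fixed-shape JAX code such as moving least squares. `search_neighborhood_sketch` and the padding convention are hypothetical.

```python
import numpy as np
from scipy.spatial import KDTree


def search_neighborhood_sketch(x_nodes, x_query, radius):
    """Indices of all nodes within radius of each query point, padded with -1."""
    tree = KDTree(np.asarray(x_nodes))
    # One list of node indices per query point
    neighbors = tree.query_ball_point(np.asarray(x_query), r=radius)
    max_count = max((len(idx) for idx in neighbors), default=0)
    # Pad to a rectangular array so the result has a fixed shape
    padded = np.full((len(neighbors), max_count), -1, dtype=np.int64)
    counts = np.zeros(len(neighbors), dtype=np.int64)
    for i, idx in enumerate(neighbors):
        padded[i, : len(idx)] = sorted(idx)
        counts[i] = len(idx)
    return padded, counts


nodes = np.array([[0.0], [1.0], [2.0], [3.0]])
indices, counts = search_neighborhood_sketch(nodes, np.array([[0.0], [2.9]]), 1.1)
```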

get_condition_number(dofs, settings, ...)

Computes the condition number of the assembled tangent matrix.
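A minimal dense sketch, assuming the tangent matrix is the Jacobian of a residual with respect to the DOFs. `residual_fn` stands in for the assembly driven by `dofs`, `settings` and `static_settings`; it and `condition_number_sketch` are hypothetical. A dense Jacobian is only practical for small systems.

```python
import jax
import jax.numpy as jnp


def condition_number_sketch(residual_fn, dofs):
    """2-norm condition number of the tangent matrix d(residual)/d(dofs)."""
    # Dense Jacobian of the residual, i.e. the tangent (stiffness) matrix
    tangent = jax.jacfwd(residual_fn)(dofs)
    return jnp.linalg.cond(tangent)


# Hypothetical linear residual r(u) = K u, whose tangent matrix is K itself
K = jnp.array([[1.0, 0.0], [0.0, 10.0]])
kappa = condition_number_sketch(lambda u: K @ u, jnp.zeros(2))
```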

symmetry_check(dofs, settings, static_settings)

Checks the symmetry of the assembled tangent matrix.
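A minimal sketch, again assuming a hypothetical `residual_fn` whose Jacobian is the tangent matrix K. It reports the relative asymmetry ||K - K^T||_F / ||K||_F, which is zero for a symmetric matrix (e.g. when the residual is the gradient of a potential). `symmetry_check_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def symmetry_check_sketch(residual_fn, dofs):
    """Relative asymmetry ||K - K^T||_F / ||K||_F of the tangent matrix K."""
    tangent = jax.jacfwd(residual_fn)(dofs)
    # jnp.linalg.norm of a 2D array defaults to the Frobenius norm
    return jnp.linalg.norm(tangent - tangent.T) / jnp.linalg.norm(tangent)
```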

dof_select(dirichlet_nodes, selected_fields)

DOF selection for the nodal imposition of boundary conditions (jitted).
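A minimal jitted sketch under an assumed data layout: `dirichlet_nodes` is a boolean vector over nodes, `selected_fields` is a boolean vector over fields, and the DOFs are arranged as an `(n_nodes, n_fields)` array. The actual argument formats of `dof_select` are not stated above, and `dof_select_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def dof_select_sketch(dirichlet_nodes, selected_fields):
    """Boolean DOF mask of shape (n_nodes, n_fields) under the assumed layout."""
    # A DOF is constrained if its node is a Dirichlet node AND its field is selected
    return jnp.logical_and(dirichlet_nodes[:, None], selected_fields[None, :])


mask = dof_select_sketch(jnp.array([True, False, True]), jnp.array([True, False]))
```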

jnp_to_tuple(jnp_array)

Converts a JAX array to a tuple.
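A minimal sketch of why such a conversion is useful: JAX arrays are not hashable, so they cannot be used as static arguments of `jax.jit`, while nested tuples of Python scalars can. `jnp_to_tuple_sketch` and `zeros_of` are hypothetical names.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def jnp_to_tuple_sketch(jnp_array):
    """Convert an array to a (nested) tuple of Python scalars."""
    arr = np.asarray(jnp_array)  # copy to host memory
    if arr.ndim == 0:
        return arr.item()
    return tuple(jnp_to_tuple_sketch(row) for row in arr)


@partial(jax.jit, static_argnames="shape")
def zeros_of(shape):
    return jnp.zeros(shape)


shape = jnp_to_tuple_sketch(jnp.array([2, 3]))  # hashable, unlike the JAX array
z = zeros_of(shape)
```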

to_jax_function(subexpr, reduced_expr)

Generates a pure JAX function from a list of subexpressions and the reduced return value.
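A minimal sketch, assuming `subexpr` and `reduced_expr` have the format returned by `sympy.cse` (a list of `(symbol, expression)` pairs and a list of reduced expressions) and that only the elementary functions in the hypothetical `JAX_FUNCTIONS` mapping occur. Each step is compiled with `sympy.lambdify` and evaluated in order, so the result can be jitted and differentiated. `to_jax_function_sketch` and the explicit `args` parameter are hypothetical.

```python
import jax
import jax.numpy as jnp
import sympy as sp

# Assumption: only these elementary functions occur in the expressions
JAX_FUNCTIONS = {"sin": jnp.sin, "cos": jnp.cos, "exp": jnp.exp}


def to_jax_function_sketch(args, subexpr, reduced_expr):
    """Build a pure JAX function from the output of sympy.cse (sketch)."""
    modules = [JAX_FUNCTIONS, "numpy"]  # the dict entries take precedence
    known = list(args)
    steps = []
    # Each subexpression may use the inputs and all earlier subexpressions
    for symbol, expr in subexpr:
        steps.append(sp.lambdify(known, expr, modules=modules))
        known.append(symbol)
    final = sp.lambdify(known, reduced_expr[0], modules=modules)

    def fun(*values):
        values = list(values)
        for step in steps:
            values.append(step(*values))
        return final(*values)

    return fun


x, y = sp.symbols("x y")
expr = (sp.sin(x) + y) ** 2 + sp.sin(x) + y
subexpr, reduced_expr = sp.cse(expr)  # common subexpressions and reduced expression
f = jax.jit(to_jax_function_sketch((x, y), subexpr, reduced_expr))
```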

jacfwd_upto_n_scalar_args(fun, args, ...)

Computes all derivatives of a function fun up to the nth order using a single pass.
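A minimal sketch of one way to obtain all orders together, restricted to a single scalar argument (the function above accepts several). It nests forward-mode `jax.jvp` calls so that each level reuses the values of the lower orders and appends only the new highest derivative. `derivatives_upto_n_sketch` is a hypothetical name; Taylor-mode differentiation in `jax.experimental.jet` is another option.

```python
import jax
import jax.numpy as jnp


def derivatives_upto_n_sketch(fun, n):
    """Return g with g(x) = (f(x), f'(x), ..., f^(n)(x)) for a scalar x."""
    if n == 0:
        return lambda x: (fun(x),)
    lower = derivatives_upto_n_sketch(fun, n - 1)

    def g(x):
        # One JVP of the order-(n-1) tuple yields its values (primals) and
        # their derivatives (tangents); only the new highest order is appended
        primals, tangents = jax.jvp(lower, (x,), (jnp.ones_like(x),))
        return primals + (tangents[-1],)

    return g


derivs = jax.jit(derivatives_upto_n_sketch(jnp.sin, 3))
values = derivs(jnp.array(0.5))  # (sin, cos, -sin, -cos) evaluated at 0.5
```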

jacfwd_upto_n_one_vector_arg(fun, x, n)

Computes all derivatives of a function fun up to the nth order with respect to a vector argument x using a single pass.
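A minimal sketch for a vector argument, returning the value, gradient, Hessian and higher derivative tensors. Each level applies `jax.jacfwd` only to the current highest-order tensor and passes the lower orders through as auxiliary output (`has_aux=True`), so they are computed once. `derivatives_upto_n_vector_sketch` is a hypothetical name, and this nested forward-mode approach is an assumption about what "single pass" means here.

```python
import jax
import jax.numpy as jnp


def derivatives_upto_n_vector_sketch(fun, n):
    """Return g with g(x) = (f, grad f, Hessian f, ...) up to order n for a vector x."""
    if n == 0:
        return lambda x: (fun(x),)
    lower = derivatives_upto_n_vector_sketch(fun, n - 1)

    def g(x):
        def highest_and_all(y):
            out = lower(y)
            # Differentiate only the highest order; return all orders as aux data
            return out[-1], out

        jac, out = jax.jacfwd(highest_and_all, has_aux=True)(x)
        return out + (jac,)

    return g


derivs = jax.jit(derivatives_upto_n_vector_sketch(lambda x: jnp.sum(x**3), 3))
f, grad, hess, third = derivs(jnp.array([1.0, 2.0, 3.0]))
```

For f(x) = sum(x_i^3), the gradient is 3x_i^2, the Hessian is diag(6x_i), and the third derivative tensor equals 6 on its diagonal (i = j = k) and 0 elsewhere.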