autopdex.utility.jnp_to_tuple

autopdex.utility.jnp_to_tuple(jnp_array)

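Converts a JAX array to a tuple. Also accepts a dict of JAX arrays, in which case each value is converted and the result is returned as a flax.core.FrozenDict of tuples.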

Parameters:

jnp_array (jnp.ndarray or dict) – JAX array, or dict of JAX arrays, to convert.

Returns:

The converted tuple, or a flax.core.FrozenDict of tuples if a dict was passed.

Return type:

tuple or flax.core.FrozenDict
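The kind of conversion described above can be sketched as follows. This is a minimal illustration under stated assumptions and not autopdex's actual implementation. The helper names to_nested_tuple and _lists_to_tuples are hypothetical. The sketch uses NumPy, because np.asarray also accepts JAX arrays. Unlike the documented function, for dict input it returns a plain dict rather than a flax.core.FrozenDict. One common reason to want nested tuples is that they are hashable, while arrays are not.

```python
import numpy as np


def _lists_to_tuples(obj):
    # Recursively replace every (nested) list with a tuple; scalars pass through.
    if isinstance(obj, list):
        return tuple(_lists_to_tuples(item) for item in obj)
    return obj


def to_nested_tuple(x):
    """Hypothetical sketch: convert an array, or a dict of arrays, to nested tuples.

    Assumption: this is NOT autopdex's implementation. For dict input it returns
    a plain dict, whereas jnp_to_tuple is documented to return a flax.core.FrozenDict.
    """
    if isinstance(x, dict):
        # Convert each value recursively; keys are left unchanged.
        return {key: to_nested_tuple(value) for key, value in x.items()}
    # np.asarray accepts NumPy and JAX arrays; tolist() yields nested Python
    # lists of Python scalars (or a bare scalar for a 0-d array).
    return _lists_to_tuples(np.asarray(x).tolist())


# Usage: a 2x2 array becomes a tuple of row tuples, which can be hashed.
print(to_nested_tuple(np.array([[1, 2], [3, 4]])))  # ((1, 2), (3, 4))
```

Converting through tolist() first keeps the recursion on plain Python lists. As a result, the leaf values are Python scalars rather than 0-d arrays.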