autopdex.utility.dict_flatten

autopdex.utility.dict_flatten(arr)

Recursively flattens a nested dict of arrays (np.ndarray or jnp.ndarray) into a single flat array. If a single array is provided instead of a dict, it simply returns arr.flatten().
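The recursion described above can be illustrated with a short sketch. The function `dict_flatten_sketch` below is hypothetical and is not autopdex's implementation. It assumes that values are visited in dict insertion order (depth-first) and joined end to end. It uses NumPy only, whereas the real function also accepts JAX arrays.

```python
import numpy as np


def dict_flatten_sketch(arr):
    """Hypothetical illustration of recursive dict flattening (not autopdex's code).

    Assumption: values are concatenated in dict insertion order, depth-first.
    """
    if isinstance(arr, dict):
        # Flatten each value recursively, then join the 1-D pieces end to end.
        parts = [dict_flatten_sketch(v) for v in arr.values()]
        return np.concatenate(parts) if parts else np.empty(0)
    # Leaf case: a single array is simply flattened to 1-D.
    return np.asarray(arr).flatten()


nested = {
    "u": np.array([[1.0, 2.0], [3.0, 4.0]]),
    "aux": {"p": np.array([5.0]), "q": np.array([6.0, 7.0])},
}
flat = dict_flatten_sketch(nested)
print(flat)  # [1. 2. 3. 4. 5. 6. 7.]
```

For JAX inputs, the same structure would work with `jax.numpy` in place of `numpy`. This substitution is an assumption made here for illustration, not a description of how autopdex handles array types.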

Parameters:

arr (dict or array) – A nested dictionary of arrays (NumPy or JAX) or a single array.

Returns:

A single flat (one-dimensional) array containing all elements, of the same array type (NumPy or JAX) as the input arrays.

Return type:

array