Source code for autopdex.assembler

# assembler.py
# Copyright (C) 2024 Tobias Bode
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

"""
Module for assembling and integrating functionals, residuals and tangents, supporting both 
'sparse' and 'dense' modes and 'user potentials', 'user residuals' and 'user elements'.

The assembly modes 'dense' and 'sparse' use JAX's automatic differentiation capabilities to 
determine the global residual and the tangent for the given values of the degrees of freedom 
and the models and variational schemes specified in static_settings.
In the 'dense' mode, the tangent matrix is returned as a jnp.ndarray, while in the 'sparse' 
mode, it is returned as a jax.experimental.BCOO matrix with duplicates. 
In the 'user potential', 'user residual' and 'user element' modes, the global residual and 
tangent can be assembled based on user-defined element-wise contributions. 
Depending on the execution location, the entries are calculated on CPU or GPU. 
For an efficient calculation, JAX's automatic vectorization transformation (vmap) is used.
The summation of duplicates is then carried out within the solver module. Currently, 
SciPy is used for this on the CPU.
"""

from functools import partial

import jax.numpy as jnp
import jax
from jax import vmap, jacrev, jacfwd, hessian, jvp, linearize, vjp, custom_jvp
from jax.tree import map as treemap
from jax.experimental import sparse

from autopdex import variational_schemes
from autopdex.utility import jit_with_docstring, dict_zeros_like, dict_flatten, reshape_as


# TODO: change vmaps to _batched_map in order to reduce memory consumption


## Helper functions
def _get_indices(connectivity, dofs):
    """
    Constructs the global indices for the assembly of the tangent matrix.

    Args:
        connectivity (array or dict): Connectivity array or dictionary of connectivity arrays.
        dofs (array or dict): DOFs array or dictionary of DOFs.

    Returns:
        indices (jnp.ndarray): Array of indices for the sparse matrix.
    """
    if callable(dofs):
        dofs = dofs(0.)

    if isinstance(dofs, dict):
        keys = dofs.keys()

        # Field-Offsets
        field_offsets = {}
        current_offset = 0
        for field in keys:
            field_offsets[field] = current_offset
            field_size = dofs[field].size
            current_offset += field_size

        indices_list = []

        num_elems = connectivity[next(iter(keys))].shape[0]
        elem_indices = jnp.arange(num_elems, dtype=int)

        for field_i in keys:
            for field_j in keys:

                def one_elem_indices(elem_idx):
                    # Global DOFs for field_i
                    conn_i = connectivity[field_i][elem_idx]
                    if dofs[field_i].ndim == 1:
                        dofs_per_node_i = 1
                    else:
                        dofs_per_node_i = dofs[field_i].shape[-1]
                    field_offset_i = field_offsets[field_i]
                    dof_local_i = jnp.arange(dofs_per_node_i, dtype=int)
                    dof_indices_i = (
                        field_offset_i + conn_i[:, None] * dofs_per_node_i + dof_local_i
                    )
                    global_dofs_i = jnp.asarray(dof_indices_i, dtype=int).flatten()

                    # Global DOFs for field_j
                    conn_j = connectivity[field_j][elem_idx]
                    if dofs[field_j].ndim == 1:
                        dofs_per_node_j = 1
                    else:
                        dofs_per_node_j = dofs[field_j].shape[-1]

                    field_offset_j = field_offsets[field_j]
                    dof_local_j = jnp.arange(dofs_per_node_j, dtype=int)
                    dof_indices_j = (
                        field_offset_j + conn_j[:, None] * dofs_per_node_j + dof_local_j
                    )
                    global_dofs_j = jnp.asarray(dof_indices_j, dtype=int).flatten()

                    # Generate indices
                    row_indices = jnp.repeat(global_dofs_i, global_dofs_j.size)
                    col_indices = jnp.tile(global_dofs_j, global_dofs_i.size)
                    indices = jnp.stack([row_indices, col_indices], axis=-1)
                    return indices

                # Vectorize over elements
                all_elem_indices = vmap(one_elem_indices)(elem_indices)
                indices = all_elem_indices.reshape(-1, 2)
                indices_list.append(indices)

        # Concatenate all indices
        indices = jnp.concatenate(indices_list, axis=0)
        return indices
    else:
        # dofs is array
        if dofs.ndim == 1:
            dofs_per_node = 1
        else:
            dofs_per_node = dofs.shape[-1]

        def one_elem_idx(neighb):
            global_dofs = neighb[:, None] * dofs_per_node + jnp.arange(dofs_per_node)
            global_dofs = global_dofs.flatten()
            n_dofs_element = global_dofs.size

            row_indices = jnp.repeat(global_dofs, n_dofs_element)
            col_indices = jnp.tile(global_dofs, n_dofs_element)
            indices = jnp.stack([row_indices, col_indices], axis=-1)
            return indices.astype(int)

        all_elem_indices = vmap(one_elem_idx)(connectivity)
        indices = all_elem_indices.reshape(-1, 2)
        return indices

def _get_element_quantities(dofs, settings, static_settings, set):
    """
    Extracts element-dependent quantities for the specified set.

    Args:
        dofs (jnp.ndarray or dict or callable): Degrees of freedom. Can be a function of time for transient problems.
        settings (dict): Settings dictionary.
        static_settings (dict or flax.core.FrozenDict): Static settings dictionary.
        set (int): The domain number.

    Returns:
        tuple: (model_fun, local_dofs, local_node_coor, elem_numbers, connectivity)
    """
    model_fun = static_settings["model"][set]
    x_nodes = settings["node coordinates"]
    dofs_is_fun = True if callable(dofs) else False
    if dofs_is_fun:
        dofs_is_dict = True if isinstance(dofs(0.), dict) else False
    else:
        dofs_is_dict = True if isinstance(dofs, dict) else False

    # Warning if it was defined in static_settings
    assert "connectivity" not in static_settings, \
        "'connectivity' has been moved to 'settings' in order to reduce compile time. \
        Further, you should not transform it to a tuple of tuples anymore."

    connectivity = settings["connectivity"][set]

    if dofs_is_dict:
        assert isinstance(
            x_nodes, dict
        ), "If 'dofs' is a dict, 'settings['node coordinates']' must also be a dict."
        assert isinstance(
            connectivity, dict
        ), "If 'dofs' is a dict, 'settings['connectivity'][set]' must also be a dict."

        elem_numbers = jnp.arange(connectivity[next(iter(connectivity))].shape[0])
    else:
        elem_numbers = jnp.arange(connectivity.shape[0])

    return model_fun, x_nodes, elem_numbers, connectivity

def _get_element_quantities_2(dofs, settings, static_settings, set):
    if callable(dofs):
        dofs = dofs(0.)
    assert isinstance(
        dofs, jnp.ndarray
    ), "Variational schemes do currently not support dofs as dicts."

    connectivity = settings["connectivity"][set]
    variational_scheme = static_settings["variational scheme"][set]
    x_int = settings["integration coordinates"][set]
    w_int = settings["integration weights"][set]
    int_point_numbers = jnp.arange(0, x_int.shape[0], 1)
    return connectivity, variational_scheme, x_int, w_int, int_point_numbers

def _extract_local_dofs_and_coor(dofs, node_list, x_nodes):
    # If DOFs are a function of time (for transient problems, forward them as a function of time)
    if callable(dofs):
        local_dofs = lambda t: treemap(lambda x, y: x.at[y].get(), dofs(t), node_list)
        # local_dofs = _make_elem_dofs_fun(dofs, node_list)
    else:
        local_dofs = treemap(lambda x, y: x.at[y].get(), dofs, node_list)
    local_node_coor = treemap(lambda x, y: x.at[y].get(), x_nodes, node_list)
    return local_dofs, local_node_coor

def _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity):
    x_i = x_int[int_point_number]
    w_i = w_int[int_point_number]
    if callable(dofs):
        local_dofs = lambda t: treemap(lambda x, y: x.at[y].get(), dofs(t), connectivity[int_point_number])
        # local_dofs = _make_elem_dofs_fun(dofs, connectivity[int_point_number])
    else:
        local_dofs = treemap(lambda x, y: x.at[y].get(), dofs, connectivity[int_point_number])
    return x_i, w_i, local_dofs

def _get_tangent_diagonal(tangent_contributions, connectivity, dofs):
    """
    Assembles the diagonal entries of the tangent matrix from the tangent contributions,
    connectivity, and degrees of freedom.

    Args:
        tangent_contributions (dict or jnp.ndarray): Tangent contributions from the model function.
        connectivity (dict or jnp.ndarray): Connectivity information for elements.
        dofs (dict or jnp.ndarray): Degrees of freedom.

    Returns:
        jnp.ndarray: The assembled diagonal of the tangent matrix.
    """
    if callable(dofs):
        dofs = dofs(0.)

    # Total number of DOFs
    num_dofs = (
        dofs.size if not isinstance(dofs, dict) else sum(v.size for v in dofs.values())
    )

    if isinstance(dofs, dict):
        keys = dofs.keys()
        n_elems = next(iter(connectivity.values())).shape[0]  # Number of elements

        # Compute field offsets for global DOF numbering
        field_offsets = {}
        current_offset = 0
        for key in keys:
            field_offsets[key] = current_offset
            field_size = dofs[key].size  # Total DOFs in the field
            current_offset += field_size

        # Initialize the global diagonal vector
        diag = jnp.zeros(num_dofs)

        # Iterate over fields to assemble diagonal contributions
        for key in keys:
            # Extract tangent contributions for field [key][key]
            tc = tangent_contributions[key][key]
            # tc has shape: (n_elems, nodes_per_element_i, dofs_per_node_i, nodes_per_element_j, dofs_per_node_j)

            # Reshape tc to (n_elems, element_dofs, element_dofs)
            conn = connectivity[key]  # Shape: (n_elems, nodes_per_element)
            nodes_per_element = conn.shape[1]
            if dofs[key].ndim == 1:
                dofs_per_node = 1
            else:
                dofs_per_node = dofs[key].shape[-1]
            element_dofs = nodes_per_element * dofs_per_node

            # Reshape tc
            tc = tc.reshape(n_elems, element_dofs, element_dofs)

            # Extract diagonal contributions
            diagonal_contributions = jnp.diagonal(
                tc, axis1=1, axis2=2
            )  # Shape: (n_elems, element_dofs)

            # Compute global DOF indices
            dof_local = jnp.arange(dofs_per_node)
            dof_indices = (
                field_offsets[key]
                + conn[:, :, None] * dofs_per_node
                + dof_local[None, None, :]
            )  # Shape: (n_elems, nodes_per_element, dofs_per_node)
            global_dofs = dof_indices.reshape(
                n_elems, element_dofs
            )  # Shape: (n_elems, element_dofs)

            # Flatten indices and values
            diag_indices = global_dofs.flatten().astype(int)
            diag_values = diagonal_contributions.flatten()

            # Sum into the global diagonal vector
            diag = diag.at[diag_indices].add(diag_values)

    else:
        # For the array case
        n_elems = connectivity.shape[0]
        conn = connectivity  # Shape: (n_elems, nodes_per_element)
        nodes_per_element = conn.shape[1]
        dofs_per_node = dofs.shape[-1]
        element_dofs = nodes_per_element * dofs_per_node

        # Reshape tangent_contributions
        tc = tangent_contributions.reshape(n_elems, element_dofs, element_dofs)

        # Extract diagonal contributions
        diagonal_contributions = jnp.diagonal(
            tc, axis1=1, axis2=2
        )  # Shape: (n_elems, element_dofs)

        # Compute global DOF indices
        dof_local = jnp.arange(dofs_per_node)
        dof_indices = (
            conn[:, :, None] * dofs_per_node + dof_local[None, None, :]
        )  # Shape: (n_elems, nodes_per_element, dofs_per_node)
        global_dofs = dof_indices.reshape(
            n_elems, element_dofs
        )  # Shape: (n_elems, element_dofs)

        # Flatten indices and values
        diag_indices = global_dofs.flatten().astype(int)
        diag_values = diagonal_contributions.flatten()

        # Sum into the global diagonal vector
        diag = jnp.zeros(num_dofs)
        diag = diag.at[diag_indices].add(diag_values)

    return diag

def _get_residual(residual_contributions, connectivity, dofs):
    """
    Assembles the global residual vector from the residual contributions,
    connectivity, and degrees of freedom, returning a residual with the same
    structure as dofs.

    Args:
        residual_contributions (dict or jnp.ndarray): Residual contributions from the model function.
        connectivity (dict or jnp.ndarray): Connectivity information for elements.
        dofs (dict or jnp.ndarray): Degrees of freedom.

    Returns:
        dict or jnp.ndarray: The assembled residual vector with the same structure as dofs.
    """
    if callable(dofs):
        dofs = dofs(0.)

    if isinstance(dofs, dict):
        keys = dofs.keys()
        n_elems = connectivity[next(iter(keys))].shape[0]
        
        # Initialize the residual dictionary
        residual = {}

        # Iterate over fields to assemble residual contributions
        for key in keys:
            # Extract residual contributions for field [key]
            rc = residual_contributions[
                key
            ]  # Shape: (n_elems, nodes_per_element, dofs_per_node)

            # Reshape rc to (n_elems, element_dofs)
            conn = connectivity[key]  # Shape: (n_elems, nodes_per_element)
            nodes_per_element = conn.shape[1]
            if dofs[key].ndim == 1:
                dofs_per_node = 1
            else:
                dofs_per_node = dofs[key].shape[-1]
            element_dofs = nodes_per_element * dofs_per_node

            # Reshape rc
            rc = rc.reshape(n_elems, element_dofs)  # Shape: (n_elems, element_dofs)

            # Compute global DOF indices
            dof_local = jnp.arange(dofs_per_node)
            dof_indices = (
                conn[:, :, None] * dofs_per_node + dof_local[None, None, :]
            )  # Shape: (n_elems, nodes_per_element, dofs_per_node)
            global_dofs = dof_indices.reshape(
                n_elems, element_dofs
            )  # Shape: (n_elems, element_dofs)

            # Flatten indices and values
            residual_indices = global_dofs.flatten().astype(int)
            residual_values = rc.flatten()

            # Initialize the residual array for this field
            field_residual = jnp.zeros_like(dofs[key]).flatten()

            # Sum into the field residual vector
            field_residual = field_residual.at[residual_indices].add(residual_values)

            # Reshape back to the original shape
            field_residual = field_residual.reshape(dofs[key].shape)

            # Assign to the residual dictionary
            residual[key] = field_residual

        return residual

    else:
        # For the array case
        n_elems = connectivity.shape[0]
        conn = connectivity  # Shape: (n_elems, nodes_per_element)
        nodes_per_element = conn.shape[1]
        dofs_per_node = dofs.shape[-1]
        element_dofs = nodes_per_element * dofs_per_node

        # Reshape residual_contributions
        rc = residual_contributions.reshape(
            n_elems, element_dofs
        )  # Shape: (n_elems, element_dofs)

        # Compute global DOF indices
        dof_local = jnp.arange(dofs_per_node)
        dof_indices = (
            conn[:, :, None] * dofs_per_node + dof_local[None, None, :]
        )  # Shape: (n_elems, nodes_per_element, dofs_per_node)
        global_dofs = dof_indices.reshape(
            n_elems, element_dofs
        )  # Shape: (n_elems, element_dofs)

        # Flatten indices and values
        residual_indices = global_dofs.flatten().astype(int)
        residual_values = rc.flatten()

        # Initialize the residual array
        residual = jnp.zeros_like(dofs).flatten()

        # Sum into the residual vector
        residual = residual.at[residual_indices].add(residual_values)

        # Reshape back to the original shape
        residual = residual.reshape(dofs.shape)

        return residual

def _make_elem_dofs_fun(dofs, elem):
    """This function takes the function `dofs(t)` and returns basically lambda t: dofs(t)[elem].
    
    The difference is, that when this function is used under jacrev, it will allocate less memory.
    Supports only as many derivatives as are defined via the custom_jvp decorators
    """
    # @custom_jvp
    # def elem_dofs_ttt_f(t):
    #     dofs_ttt_ = jacfwd(jacfwd(jacfwd(dofs)))(t)
    #     elem_dofs_ttt_ = treemap(lambda x, y: x.at[y].get(), dofs_ttt_, elem)
    #     return elem_dofs_ttt_
    # @elem_dofs_ttt_f.defjvp
    # def elem_dofs_ttt_jvp(primals, tangents):
    #     t, = primals
    #     t_dot, = tangents
    #     elem_dofs_tttt = jacfwd(jacfwd(jacfwd(jacfwd(dofs))))(t)[elem]
    #     return elem_dofs_ttt_f(t), treemap(lambda x: x * t_dot, elem_dofs_tttt)
    # @custom_jvp
    # def elem_dofs_tt_f(t):
    #     dofs_tt_ = jacfwd(jacfwd(dofs))(t)
    #     elem_dofs_tt_ = treemap(lambda x, y: x.at[y].get(), dofs_tt_, elem)
    #     return elem_dofs_tt_
    # @elem_dofs_tt_f.defjvp
    # def elem_dofs_tt_jvp(primals, tangents):
    #     t, = primals
    #     t_dot, = tangents
    #     elem_dofs_ttt = elem_dofs_ttt_f(t)
    #     return elem_dofs_tt_f(t), treemap(lambda x: x * t_dot, elem_dofs_ttt)
    # @custom_jvp
    # def elem_dofs_t_f(t):
    #     dofs_t_ = jacfwd(dofs)(t)
    #     elem_dofs_t_ = treemap(lambda x, y: x.at[y].get(), dofs_t_, elem)
    #     return elem_dofs_t_
    # @elem_dofs_t_f.defjvp
    # def elem_dofs_t_jvp(primals, tangents):
    #     t, = primals
    #     t_dot, = tangents
    #     elem_dofs_tt = elem_dofs_tt_f(t)
    #     return elem_dofs_t_f(t), treemap(lambda x: x * t_dot, elem_dofs_tt)
    # @custom_jvp
    # def elem_dofs_f(t):
    #     dofs_ = dofs(t)
    #     elem_dofs_ = treemap(lambda x, y: x.at[y].get(), dofs_, elem)
    #     return elem_dofs_
    # @elem_dofs_f.defjvp
    # def elem_dofs_jvp(primals, tangents):
    #     t, = primals
    #     t_dot, = tangents
    #     elem_dofs_t = elem_dofs_t_f(t)
    #     return elem_dofs_f(t), treemap(lambda x: x * t_dot, elem_dofs_t)
    # return elem_dofs_f


    # @partial(custom_jvp, nondiff_argnums=(1, 2))
    # def elem_dofs_ttt_f(t, dofs, elem):
    #     dofs_ttt_ = jacfwd(jacfwd(jacfwd(dofs)))(t)
    #     elem_dofs_ttt_ = treemap(lambda x, y: x.at[y].get(), dofs_ttt_, elem)
    #     return elem_dofs_ttt_
    # @elem_dofs_ttt_f.defjvp
    # def elem_dofs_ttt_jvp(dofs, elem, primals, tangents):
    #     t, = primals
    #     t_dot, = tangents
    #     elem_dofs_tttt = jacfwd(jacfwd(jacfwd(jacfwd(dofs))))(t)[elem]
    #     return elem_dofs_ttt_f(t, dofs, elem), treemap(lambda x: x * t_dot, elem_dofs_tttt)
    # @partial(custom_jvp, nondiff_argnums=(1, 2))
    # def elem_dofs_tt_f(t, dofs, elem):
    #     dofs_tt_ = jacfwd(jacfwd(dofs))(t)
    #     elem_dofs_tt_ = treemap(lambda x, y: x.at[y].get(), dofs_tt_, elem)
    #     return elem_dofs_tt_
    # @elem_dofs_tt_f.defjvp
    # def elem_dofs_tt_jvp(dofs, elem, primals, tangents):
    #     t, = primals
    #     t_dot, = tangents
    #     elem_dofs_ttt = elem_dofs_ttt_f(t, dofs, elem)
    #     return elem_dofs_tt_f(t, dofs, elem), treemap(lambda x: x * t_dot, elem_dofs_ttt)
    @partial(custom_jvp, nondiff_argnums=(1, 2))
    def elem_dofs_f(t, dofs, elem):
        dofs_ = dofs(t)
        elem_dofs_ = treemap(lambda x, y: x[y], dofs_, elem)
        return elem_dofs_
    @elem_dofs_f.defjvp
    def elem_dofs_jvp(dofs, elem, primals, tangents):
        t, = primals
        t_dot, = tangents

        # @partial(custom_jvp, nondiff_argnums=(1, 2))
        # def elem_dofs_t_f(t, dofs, elem):
        #     dofs_t_ = jacfwd(dofs)(t)
        #     elem_dofs_t_ = treemap(lambda x, y: x.at[y].get(), dofs_t_, elem)
        #     return elem_dofs_t_
        # @elem_dofs_t_f.defjvp
        # def elem_dofs_t_jvp(dofs, elem, primals, tangents):
        #     t, = primals
        #     t_dot, = tangents
        #     # elem_dofs_tt = elem_dofs_tt_f(t, dofs, elem)
        #     elem_dofs_tt = jacfwd(jacfwd(dofs))(t)[elem]
        #     return elem_dofs_t_f(t, dofs, elem), treemap(lambda x: x * t_dot, elem_dofs_tt)

        # elem_dofs_t = elem_dofs_t_f(t, dofs, elem)
        elem_dofs_t = treemap(lambda a, b: a[b], jacfwd(dofs)(t), elem)
        return elem_dofs_f(t, dofs, elem), treemap(lambda x: x * t_dot, elem_dofs_t)
    local_dofs_fun = lambda t: elem_dofs_f(t, dofs, elem)
    return local_dofs_fun

def _batched_map(fun, elem_numbers, connectivity):
    body_fun = lambda i: fun(elem_numbers[i], jax.tree.map(lambda x: x[i], connectivity))
    num_dofs_per_elem = jax.eval_shape(lambda i: dict_flatten(body_fun(i)), 0).shape[0]
    return jax.lax.map(body_fun, jnp.arange(elem_numbers.shape[0]), batch_size=int(64000/num_dofs_per_elem))


### General assembling functions
[docs]@jit_with_docstring(static_argnames=["static_settings"], possibly_static_argnames=['dofs']) def integrate_functional(dofs, settings, static_settings): """ Integrate functional as sum over set of domains. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. Returns: float: The integrated functional value for current dofs. """ # Loop over all sets of integration points/ domains num_sets = len(static_settings["assembling mode"]) integrated_functional = 0 for set in range(num_sets): assembling_mode = static_settings["assembling mode"][set] if assembling_mode == "dense": integrated_functional += dense_integrate_functional( dofs, settings, static_settings, set ) elif assembling_mode == "sparse": integrated_functional += sparse_integrate_functional( dofs, settings, static_settings, set ) elif assembling_mode == "user potential": integrated_functional += user_potential_integrate_functional( dofs, settings, static_settings, set ) else: assert ( False ), "Assembling mode can be either 'sparse' or 'dense' in integrate_functional" return integrated_functional
[docs]@jit_with_docstring(static_argnames=["static_settings"], possibly_static_argnames=['dofs']) def assemble_residual(dofs, settings, static_settings): """ Assemble residuals over set of domains. Args: dofs (jnp.ndarray or dict or callable): Degrees of freedom. Can be a function of time for transient problems in combination with user_residuals. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. Returns: jnp.ndarray: The assembled residual. """ # Loop over all sets of integration points/ domains num_sets = len(static_settings["assembling mode"]) if isinstance(dofs(0.), dict) if callable(dofs) else isinstance(dofs, dict): assert all([isinstance(settings['connectivity'][0], dict), isinstance(settings['node coordinates'], dict)]), \ "If the DOFs are a dict, the connectivity, node coordinates, dirichlet dofs, and dirichlet conditions must also be dicts." if callable(dofs): integrated_residual = dict_zeros_like(dofs(0.)) else: integrated_residual = dict_zeros_like(dofs) for set in range(num_sets): assembling_mode = static_settings["assembling mode"][set] if assembling_mode == "dense": add = dense_assemble_residual(dofs, settings, static_settings, set) elif assembling_mode == "sparse": add = sparse_assemble_residual(dofs, settings, static_settings, set) elif assembling_mode == "user potential": add = user_potential_assemble_residual(dofs, settings, static_settings, set) elif assembling_mode == "user residual": add = user_residual_assemble_residual(dofs, settings, static_settings, set) elif assembling_mode == "user element": add = user_element_assemble_residual(dofs, settings, static_settings, set) else: assert ( False ), "Assembling mode can be either 'sparse', 'dense' or 'user element'" # Handle both cases dict and jnp.ndarray if isinstance(add, dict): integrated_residual = treemap(lambda x, y: x + y, integrated_residual, add) else: integrated_residual += add return integrated_residual
[docs]@jit_with_docstring(static_argnames=["static_settings"], possibly_static_argnames=['dofs']) def assemble_tangent_diagonal(dofs, settings, static_settings): """ Assemble the diagonal of the tangent matrix. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. Returns: jnp.ndarray: The diagonal of the assembled tangent matrix. """ # Loop over all sets of integration points/ domains num_sets = len(static_settings["assembling mode"]) if callable(dofs): tangent_diagonal = jnp.zeros_like(dict_flatten(dofs(0.))) else: tangent_diagonal = jnp.zeros_like(dict_flatten(dofs)) for set in range(num_sets): assembling_mode = static_settings["assembling mode"][set] if assembling_mode == "sparse": tangent_diagonal += sparse_assemble_tangent_diagonal( dofs, settings, static_settings, set ) elif assembling_mode == "user potential": tangent_diagonal += user_potential_assemble_tangent_diagonal( dofs, settings, static_settings, set ) elif assembling_mode == "user residual": tangent_diagonal += user_residual_assemble_tangent_diagonal( dofs, settings, static_settings, set ) elif assembling_mode == "user element": tangent_diagonal += user_element_assemble_tangent_diagonal( dofs, settings, static_settings, set ) else: assert ( False ), "Assembling mode for assembling tangent diagonal supports currently only 'sparse' and 'user element'" return tangent_diagonal
[docs]@jit_with_docstring(static_argnames=["static_settings"], possibly_static_argnames=['dofs']) def assemble_tangent(dofs, settings, static_settings): """ Assemble the full (possibly sparse) tangent matrix. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. Returns: jnp.ndarray or sparse matrix: The assembled tangent matrix. """ # ToDo: add symmetric mode num_sets = len(static_settings["assembling mode"]) one_dense = "dense" in static_settings["assembling mode"] if isinstance(dofs, dict): num_dofs = sum(v.size for v in dofs.values()) else: num_dofs = dofs.size try: sparsity_pattern = static_settings["known sparsity pattern"] except KeyError: sparsity_pattern = "none" match sparsity_pattern: case "none": if one_dense: integrated_tangent = jnp.zeros((num_dofs, num_dofs)) else: integrated_tangent = sparse.empty( (num_dofs, num_dofs), dtype=float, index_dtype=jnp.int_ ) # Loop over all sets of integration points/ domains for set in range(num_sets): assembling_mode = static_settings["assembling mode"][set] if assembling_mode == "dense": integrated_tangent += dense_assemble_tangent( dofs, settings, static_settings, set ) else: if assembling_mode == "sparse": add = sparse_assemble_tangent( dofs, settings, static_settings, set ) elif assembling_mode == "user potential": add = user_potential_assemble_tangent( dofs, settings, static_settings, set ) elif assembling_mode == "user residual": add = user_residual_assemble_tangent( dofs, settings, static_settings, set ) elif assembling_mode == "user element": add = user_element_assemble_tangent( dofs, settings, static_settings, set ) else: assert ( False ), "Assembling mode can be either 'sparse', 'dense' or 'user element'" if one_dense: integrated_tangent += add.todense() else: integrated_tangent += add case "diagonal": # # Compute the diagonal tangent with sparsejac (not the diagonal of a tangent that is not diagonal) # residual_fun = lambda flat_dofs: assemble_residual(flat_dofs.reshape(dofs.shape), settings, static_settings).flatten() # with jax.ensure_compile_time_eval(): # data_and_indices = (jnp.ones((num_dofs,)), # vmap(lambda i: jnp.asarray([i, i]))(jnp.arange(0, num_dofs))) # mat_shape = (num_dofs,num_dofs) # sparsity = sparse.BCOO(data_and_indices, shape=mat_shape) # sparse_diag_fun = sparsejac.jacfwd(residual_fun, sparsity=sparsity) # diag = sparse_diag_fun(dofs.flatten()) # return diag diag = dict_flatten( assemble_tangent_diagonal(dofs, settings, static_settings) ) indices = vmap(lambda i: jnp.asarray([i, i]))(jnp.arange(0, num_dofs)) data_and_indices = (diag, indices) matrix_shape = (num_dofs, num_dofs) diag_mat = sparse.BCOO(data_and_indices, shape=matrix_shape) return diag_mat case _: assert False, "'known sparsity pattern' mode is not implemented." return integrated_tangent
### Dense assembling
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def dense_integrate_functional(dofs, settings, static_settings, set): """ Dense integration of functional of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: float: The integrated functional value. """ x_int = settings["integration coordinates"][set] w_int = settings["integration weights"][set] int_point_numbers = jnp.arange(0, x_int.shape[0], 1) def at_int_point(int_point_number): x_i = x_int[int_point_number] w_i = w_int[int_point_number] return variational_schemes.functional_at_int_point(x_i, w_i, int_point_number, dofs, settings, static_settings, set) functional_at_int_point_vj = vmap(at_int_point, (0,)) integrated_functional = functional_at_int_point_vj(int_point_numbers).sum() return integrated_functional
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def dense_assemble_residual(dofs, settings, static_settings, set): """ Dense assembly of residual of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The assembled residual. """ return jacrev(dense_integrate_functional)(dofs, settings, static_settings, set)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def dense_assemble_tangent(dofs, settings, static_settings, set): """ Dense assembly of tangent of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The assembled tangent matrix. """ assert isinstance( dofs, jnp.ndarray ), "Dense mode of tangent assembly does currently not support dofs as dicts." size = dict_flatten(dofs).size tangent = hessian(dense_integrate_functional)(dofs, settings, static_settings, set) return dict_flatten(tangent).reshape((size, size))
### Sparse assembling
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def sparse_integrate_functional(dofs, settings, static_settings, set): """ Sparse integration of functional of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: float: The integrated functional value. """ connectivity, variational_scheme, x_int, w_int, int_point_numbers = _get_element_quantities_2(dofs, settings, static_settings, set) def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return variational_schemes.functional_at_int_point( x_i, w_i, int_point_number, local_dofs, settings, static_settings, set ) functional_at_int_point_vj = vmap(func_at_int_pt, (0,)) return functional_at_int_point_vj(int_point_numbers).sum()
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def sparse_assemble_residual(dofs, settings, static_settings, set): """ Sparse assembly of residual of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The assembled residual. """ connectivity, variational_scheme, x_int, w_int, int_point_numbers = _get_element_quantities_2(dofs, settings, static_settings, set) if ( variational_scheme == "least square pde loss" or variational_scheme == "least square function approximation" ): def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return jacrev(variational_schemes.functional_at_int_point, argnums=3)( x_i, w_i, int_point_number, local_dofs, settings, static_settings, set ) residual_at_int_point_vj = vmap(func_at_int_pt, (0,)) residual_contributions = residual_at_int_point_vj(int_point_numbers) elif variational_scheme == "strong form galerkin": # Direct implementation of residual, e.g. for Galerkin method def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return variational_schemes.direct_residual_at_int_point( x_i, w_i, int_point_number, local_dofs, settings, static_settings, set ) residual_at_int_point_vj = vmap(func_at_int_pt, (0,)) residual_contributions = residual_at_int_point_vj(int_point_numbers) elif variational_scheme == "weak form galerkin": # Pass local_dofs twice (assuming Bubnov Galerkin...) def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return variational_schemes.residual_from_deriv_at_int_point( x_i, w_i, int_point_number, local_dofs, local_dofs, settings, static_settings, set ) residual_at_int_point_vj = vmap(func_at_int_pt, (0,)) residual_contributions = residual_at_int_point_vj(int_point_numbers) else: raise KeyError("Variational scheme not or wrongly specified!") return _get_residual(residual_contributions, connectivity, dofs)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def sparse_assemble_tangent_diagonal(dofs, settings, static_settings, set): """ Sparse assembly of the diagonal of the tangent matrix for specified set. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The diagonal of the assembled tangent matrix. """ connectivity, variational_scheme, x_int, w_int, int_point_numbers = _get_element_quantities_2(dofs, settings, static_settings, set) # Compute tangent contributions if ( variational_scheme == "least square pde loss" or variational_scheme == "least square function approximation" ): def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return jacfwd(jacrev(variational_schemes.functional_at_int_point, argnums=3), argnums=3)( x_i, w_i, int_point_number, local_dofs, settings, static_settings, set ) at_int_point_vj = vmap(func_at_int_pt, (0,)) tangent_contributions = at_int_point_vj(int_point_numbers) elif variational_scheme == "strong form galerkin": def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return jacfwd(variational_schemes.direct_residual_at_int_point, argnums=3)( x_i, w_i, int_point_number, local_dofs, settings, static_settings, set ) at_int_point_vj = vmap(func_at_int_pt, (0,)) tangent_contributions = at_int_point_vj(int_point_numbers) elif variational_scheme == "weak form galerkin": def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return jacfwd(variational_schemes.residual_from_deriv_at_int_point, argnums=3)( x_i, w_i, int_point_number, local_dofs, local_dofs, settings, static_settings, set ) at_int_point_vj = vmap(func_at_int_pt, (0,)) tangent_contributions = at_int_point_vj(int_point_numbers) else: raise KeyError("Variational scheme mode not or wrongly specified!") return _get_tangent_diagonal(tangent_contributions, connectivity, dofs)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def sparse_assemble_tangent(dofs, settings, static_settings, set): """ Sparse assembly of the full tangent matrix of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jax.experimental.sparse.BCOO: The assembled tangent matrix. """ connectivity, variational_scheme, x_int, w_int, int_point_numbers = _get_element_quantities_2(dofs, settings, static_settings, set) # Compute tangent contributions if ( variational_scheme == "least square pde loss" or variational_scheme == "least square function approximation" ): def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return jacfwd(jacrev(variational_schemes.functional_at_int_point, argnums=3), argnums=3)( x_i, w_i, int_point_number, local_dofs, settings, static_settings, set ) at_int_point_vj = vmap(func_at_int_pt, (0,)) tangent_contributions = at_int_point_vj(int_point_numbers) elif variational_scheme == "strong form galerkin": def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return jacfwd(variational_schemes.direct_residual_at_int_point, argnums=3)( x_i, w_i, int_point_number, local_dofs, settings, static_settings, set ) at_int_point_vj = vmap(func_at_int_pt, (0,)) tangent_contributions = at_int_point_vj(int_point_numbers) elif variational_scheme == "weak form galerkin": def func_at_int_pt(int_point_number): x_i, w_i, local_dofs = _extract_local_dofs_and_coor_2(dofs, int_point_number, x_int, w_int, connectivity) return jacfwd(variational_schemes.residual_from_deriv_at_int_point, argnums=3)( x_i, w_i, int_point_number, local_dofs, local_dofs, settings, static_settings, set ) at_int_point_vj = vmap(func_at_int_pt, (0,)) tangent_contributions = at_int_point_vj(int_point_numbers) else: raise KeyError("Variational scheme mode not or wrongly specified!") # Assembling (without summing duplicates) data = dict_flatten(tangent_contributions) indices = _get_indices(connectivity, dofs) num_dofs = ( dofs.size if not isinstance(dofs, dict) else sum(v.size for v in dofs.values()) ) tangent_matrix = sparse.BCOO((data, indices), shape=(num_dofs, num_dofs)) return tangent_matrix
### Assembling for user potentials
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_potential_integrate_functional(dofs, settings, static_settings, set): """ Assembly of potential for custom user definition of specified domain. Args: dofs (jnp.ndarray or dict or callable): Degrees of freedom. Can be a function of time for transient problems. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: float: value of functional integrated over set of elements """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the DOFs from the global dofs and vmap only over connectivity def element_residual(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return model_fun(local_dofs, local_node_coor, elem_number, settings, static_settings, set) functional_contributions = vmap(element_residual, (0, 0), (0))(elem_numbers, connectivity) return functional_contributions.sum()
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_potential_assemble_residual(dofs, settings, static_settings, set): """ Assembly of residual for custom user potential of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The assembled residual. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the DOFs from the global dofs and vmap only over connectivity def element_residual(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return jacrev(model_fun)(local_dofs, local_node_coor, elem_number, settings, static_settings, set) # residual_contributions = vmap(element_residual, (0, 0))(elem_numbers, connectivity) residual_contributions = _batched_map(element_residual, elem_numbers, connectivity) return _get_residual(residual_contributions, connectivity, dofs)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_potential_assemble_tangent_diagonal(dofs, settings, static_settings, set): """ Assembly of the diagonal of the tangent matrix for custom user potential of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The diagonal of the assembled tangent matrix. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the dofs from the global dofs and vmap only over connectivity def element_tangent(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return jacfwd(jacrev(model_fun))(local_dofs, local_node_coor, elem_number, settings, static_settings, set) tangent_contributions = vmap(element_tangent, (0, 0), (0))(elem_numbers, connectivity) return _get_tangent_diagonal(tangent_contributions, connectivity, dofs)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_potential_assemble_tangent(dofs, settings, static_settings, set): """ Assembly of the full (sparse) tangent matrix for custom user potential of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jax.experimental.sparse.BCOO: The assembled tangent matrix. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the dofs from the global dofs and vmap only over connectivity def element_tangent(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return jacfwd(jacrev(model_fun))(local_dofs, local_node_coor, elem_number, settings, static_settings, set) # tangent_contributions = vmap(element_tangent, (0, 0), (0))(elem_numbers, connectivity) body_fun = lambda i: element_tangent(elem_numbers[i], jax.tree.map(lambda x: x[i], connectivity)) num_dofs_per_elem = jax.eval_shape(lambda i: dict_flatten(body_fun(i)), 0).shape[0] tangent_contributions = jax.lax.map(body_fun, jnp.arange(elem_numbers.shape[0]), batch_size=int(64000/num_dofs_per_elem)) data = dict_flatten(tangent_contributions) indices = _get_indices(connectivity, dofs) num_dofs = ( dofs.size if not isinstance(dofs, dict) else sum(v.size for v in dofs.values()) ) tangent_matrix = sparse.BCOO((data, indices), shape=(num_dofs, num_dofs)) return tangent_matrix
@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def _user_potential_assemble_r_and_t(dofs, settings, static_settings, set): model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the dofs from the global dofs and vmap only over connectivity def element_r_and_t(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) residual_fun = lambda x: jacrev(model_fun)(x, local_node_coor, elem_number, settings, static_settings, set) elem_res = residual_fun(local_dofs) elem_tan = jacfwd(residual_fun)(local_dofs) # def residual_and_tangent_linearize(residual_fun, local_dofs): # primals, lin_fun = linearize(residual_fun, local_dofs) # flat_local_dofs = dict_flatten(local_dofs) # n = flat_local_dofs.shape[0] # identity = jnp.eye(n) # jacobian = vmap(lambda v: lin_fun(reshape_as(v, local_dofs)))(identity) # return primals, jacobian # elem_res, elem_tan = residual_and_tangent_linearize(residual_fun, local_dofs) # def residual_and_tangent(residual_fun, local_dofs): # flat_local_dofs = dict_flatten(local_dofs) # n = flat_local_dofs.shape[0] # identity = jnp.eye(n) # def jvp_with_flat_tangent(v): # tangent_pytree = reshape_as(v, local_dofs) # return jvp(residual_fun, (local_dofs,), (tangent_pytree,)) # primals, elem_tan = vmap(jvp_with_flat_tangent)(identity) # elem_res = treemap(lambda x: x[0], primals) # return elem_res, elem_tan # elem_res, elem_tan = residual_and_tangent(residual_fun, local_dofs) # def residual_and_tangent_vjp(residual_fun, local_dofs): # primals, vjp_fun = vjp(residual_fun, local_dofs) # flat_res = dict_flatten(primals) # m = flat_res.shape[0] # flat_local = dict_flatten(local_dofs) # n = flat_local.shape[0] # identity = jnp.eye(m) # jacobian_rows = vmap( # lambda v: dict_flatten(vjp_fun(reshape_as(v, primals))[0]) # )(identity) # jacobian = jacobian_rows.T # return primals, jacobian # elem_res, elem_tan = residual_and_tangent_vjp(residual_fun, local_dofs) return elem_res, elem_tan all_contributions = vmap(element_r_and_t, (0, 0), (0, 0)) residual_contributions, tangent_contributions = all_contributions(elem_numbers, connectivity) residual = _get_residual(residual_contributions, connectivity, dofs) data = dict_flatten(tangent_contributions) indices = _get_indices(connectivity, dofs) num_dofs = ( dofs.size if not isinstance(dofs, dict) else sum(v.size for v in dofs.values()) ) tangent_matrix = sparse.BCOO((data, indices), shape=(num_dofs, num_dofs)) return residual, tangent_matrix ### Assembling for user residuals
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_residual_assemble_residual(dofs, settings, static_settings, set): """ Assembly of residual for custom user residual of specified domain. Args: dofs (jnp.ndarray or dict or callable): Degrees of freedom. Can be a function of time for transient problems. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The assembled residual. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the DOFs from the global dofs and vmap only over connectivity def element_residual(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return model_fun(local_dofs, local_node_coor, elem_number, settings, static_settings, set) # residual_contributions = vmap(element_residual, (0, 0), (0))(elem_numbers, connectivity) body_fun = lambda i: element_residual(elem_numbers[i], jax.tree.map(lambda x: x[i], connectivity)) num_dofs_per_elem = jax.eval_shape(lambda i: dict_flatten(body_fun(i)), 0).shape[0] residual_contributions = jax.lax.map(body_fun, jnp.arange(elem_numbers.shape[0]), batch_size=int(64000/num_dofs_per_elem)) return _get_residual(residual_contributions, connectivity, dofs(0.) if callable(dofs) else dofs)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_residual_assemble_tangent_diagonal(dofs, settings, static_settings, set): """ Assembly of the diagonal of the tangent matrix for custom user residual of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The diagonal of the assembled tangent matrix. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the DOFs from the global dofs and vmap only over connectivity def element_residual(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return jacfwd(model_fun)(local_dofs, local_node_coor, elem_number, settings, static_settings, set) tangent_contributions = vmap(element_residual, (0, 0), (0))(elem_numbers, connectivity) return _get_tangent_diagonal(tangent_contributions, connectivity, dofs)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_residual_assemble_tangent(dofs, settings, static_settings, set): """ Assembly of the full (sparse) tangent matrix for custom user residual of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jax.experimental.sparse.BCOO: The assembled tangent matrix. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the DOFs from the global dofs and vmap only over connectivity def element_residual(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return jacfwd(model_fun)(local_dofs, local_node_coor, elem_number, settings, static_settings, set) tangent_contributions = vmap(element_residual, (0, 0), (0))(elem_numbers, connectivity) data = dict_flatten(tangent_contributions) indices = _get_indices(connectivity, dofs) if callable(dofs): dofs0 = dofs(0.) num_dofs = ( dofs0.size if not isinstance(dofs0, dict) else sum(v.size for v in dofs0.values()) ) else: num_dofs = ( dofs.size if not isinstance(dofs, dict) else sum(v.size for v in dofs.values()) ) tangent_matrix = sparse.BCOO((data, indices), shape=(num_dofs, num_dofs)) return tangent_matrix
### Assembling for user elements
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_element_assemble_residual(dofs, settings, static_settings, set): """ Assembly of residual for custom user element of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The assembled residual. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the DOFs from the global dofs and vmap only over connectivity def element_residual(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return model_fun(local_dofs, local_node_coor, elem_number, settings, static_settings, "residual", set) residual_contributions = vmap(element_residual, (0, 0), (0))(elem_numbers, connectivity) return _get_residual(residual_contributions, connectivity, dofs)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_element_assemble_tangent_diagonal(dofs, settings, static_settings, set): """ Assembly of the diagonal of the tangent matrix for custom user element of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jnp.ndarray: The diagonal of the assembled tangent matrix. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the DOFs from the global dofs and vmap only over connectivity def element_residual(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return model_fun(local_dofs, local_node_coor, elem_number, settings, static_settings, "tangent", set) tangent_contributions = vmap(element_residual, (0, 0), (0))(elem_numbers, connectivity) return _get_tangent_diagonal(tangent_contributions, connectivity, dofs)
[docs]@jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def user_element_assemble_tangent(dofs, settings, static_settings, set): """ Assembly of the full (sparse) tangent matrix for custom user element of specified domain. Args: dofs (jnp.ndarray or dict): Degrees of freedom. settings (dict): Settings dictionary. static_settings (flax.core.FrozenDict): Static settings as frozen dictionary. set (int): The domain number. Returns: jax.experimental.sparse.BCOO: The assembled tangent matrix. """ model_fun, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) # Modify the model_fun such that it extracts the DOFs from the global dofs and vmap only over connectivity def element_residual(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return model_fun(local_dofs, local_node_coor, elem_number, settings, static_settings, "tangent", set) tangent_contributions = vmap(element_residual, (0, 0))(elem_numbers, connectivity) data = dict_flatten(tangent_contributions) indices = _get_indices(connectivity, dofs) num_dofs = ( dofs.size if not isinstance(dofs, dict) else sum(v.size for v in dofs.values()) ) tangent_matrix = sparse.BCOO((data, indices), shape=(num_dofs, num_dofs)) return tangent_matrix
### Internal variable update function @jit_with_docstring(static_argnames=["static_settings", "set"], possibly_static_argnames=['dofs']) def get_int_var_updates(dofs, settings, static_settings, set): """ Get the internal variables for a specified domain for all elements and integration points. Similar Structure as the assemble_residual functions, but uses 'int var updates' instead of 'model' in order to compute the per element and Gauss point internal variables. """ _, x_nodes, elem_numbers, connectivity = _get_element_quantities(dofs, settings, static_settings, set) int_var_updates = static_settings['int var updates'][set] def local_update_fun(elem_number, node_list): local_dofs, local_node_coor = _extract_local_dofs_and_coor(dofs, node_list, x_nodes) return int_var_updates(local_dofs, local_node_coor, elem_number, settings, static_settings, set) return jax.vmap(local_update_fun, (0, 0), (0))(elem_numbers, connectivity)