Source code for autopdex.models

# models.py
# Copyright (C) 2024 Tobias Bode
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

"""
This module contains premade jax-transformable partial differential equation (PDE) formulations, weak forms, user-defined elements, and time integration procedures.

The models-functions return callables with appropriate arguments for the use required in the other modules of AutoPDEx. These functions can then be passed e.g. to the solver via static_settings['model'].

For instance, a typical weak form function has the following arguments:
  - x (jnp.ndarray): Spatial coordinates at the integration point.
  - ansatz (function): Ansatz function representing the field variable (e.g., displacement or temperature).
  - test_ansatz (function): Test ansatz function representing the virtual displacement or virtual temperature.
  - settings (dict): Settings for the computation.
  - static_settings (dict): Static settings for the computation.
  - int_point_number (int): Integration point number.
  - set: Number of the domain.

Note: Some of the models work for degrees of freedom (DOFs) given as a jnp.ndarray
and some work for dicts of jnp.ndarrays. This is specified in each docstring.
"""

from inspect import signature

import jax
import jax.numpy as jnp

from autopdex import solution_structures, utility


### Linear equations
def transport_equation(c):
    """
    Transport equation: du/dt + c * du/dx = 0

    Args:
      c (float): Transport coefficient.

    Returns:
      pde_fun (function): Function to compute the PDE residual.

    Notes:
      - The first coordinate of x is the spatial one, the second is time.
      - This model works for DOFs as a jnp.ndarray.
    """

    def pde_fun(x, ansatz, settings, static_settings, int_point_number, set):
        # Gradient of the scalar field with respect to (x, t)
        gradient = jax.jacrev(ansatz)(x)
        spatial_rate, temporal_rate = gradient[0], gradient[1]
        return temporal_rate + c * spatial_rate

    return pde_fun
def poisson(coefficient_fun=lambda x, settings: 1.0, source_fun=None):
    """
    Poisson equation in n dimensions: coefficient * Laplace(Theta) + source_term = 0

    Args:
      coefficient_fun (function): Function of (x, settings) returning the
        coefficient of the Laplacian, defaults to 1.
      source_fun (function, optional): Function of the position x returning the
        source term. Defaults to None (no source term).

    Returns:
      pde_fun (function): Function to compute the PDE residual.

    Notes:
      - This model works for DOFs as a jnp.ndarray.
    """

    def pde_fun(x, ansatz, settings, static_settings, int_point_number, set):
        # The Laplacian is the trace of the Hessian of the scalar field
        hessian = jax.jacfwd(jax.jacrev(ansatz))(x)
        pde = coefficient_fun(x, settings) * jnp.trace(hessian)

        # Identity check: a callable must never be compared to None with '=='
        if source_fun is not None:
            pde = pde + source_fun(x)
        return pde

    return pde_fun
def poisson_weak(coefficient_fun=lambda x, settings: 1.0, source_fun=None):
    """
    Poisson equation in n dimensions as weak form.

    The integrand is -coefficient * grad(theta) . grad(var_theta), plus
    source * var_theta if a source function is given.

    Args:
      coefficient_fun (function): Function of (x, settings) returning the
        coefficient of the Laplacian, defaults to 1.
      source_fun (function, optional): Function of the position x returning the
        source term. Defaults to None (no source term).

    Returns:
      pde_fun (function): Function to compute the weak form of the PDE.

    Notes:
      - This model works for DOFs as a jnp.ndarray.
    """

    def pde_fun(
        x, ansatz, test_ansatz, settings, static_settings, int_point_number, set
    ):
        # Test function (virtual temperature) and its gradient
        var_theta = test_ansatz(x)
        var_theta_grad = jax.jacrev(test_ansatz)(x)

        # Gradient of the trial field
        theta_grad = jax.jacrev(ansatz)(x)

        pde = -coefficient_fun(x, settings) * jnp.dot(theta_grad, var_theta_grad)
        if source_fun is not None:
            pde = pde + source_fun(x) * var_theta
        return pde

    return pde_fun
def poisson_fos(spacing, coefficient_fun=lambda x, settings: 1.0, source_fun=None):
    """
    Poisson equation as a first order system augmented by curl(v) = 0.

    The first field of the solution space has to be the scalar function; the
    remaining fields form its gradient, i.e.
      theta = lambda x: ansatz(x)[0]
      grad_theta = lambda x: ansatz(x)[1:]
    Supported: n_dim = 2 or 3.

    Residuals (in this order):
      1. coefficient * div(v) + source_term
      2. v - grad(theta), weighted by 1 / spacing
      3. curl(v), weighted by 1 / spacing

    Args:
      spacing (float): Spacing used for the weights of residuals 2 and 3.
      coefficient_fun (function): Function of (x, settings) returning the
        coefficient of the Laplacian, defaults to 1.
      source_fun (function, optional): Function of the position x returning the
        source term. Defaults to None (no source term).

    Returns:
      pde_fun (function): Function to compute the first order PDE system residuals.

    Raises:
      AssertionError: When the returned function is evaluated at a point whose
        dimension is neither 2 nor 3.

    Notes:
      - The source term is added unscaled, matching the convention of
        `coefficient * Laplace(Theta) + source_term = 0`.
      - This model works for DOFs as a jnp.ndarray.
    """

    def pde_fun(x, ansatz, settings, static_settings, int_point_number, set):
        n_dim = x.shape[0]
        coefficient = coefficient_fun(x, settings)

        # Primary fields
        def theta(t):
            return ansatz(t)[0]

        def grad_theta(t):
            return ansatz(t)[1:]  # v

        grad_v = jax.jacfwd(grad_theta)(x)

        # Poisson equation with substituted gradient field
        balance = coefficient * jnp.trace(grad_v)
        if source_fun is not None:
            # Source is not multiplied by the coefficient
            balance = balance + source_fun(x)
        pde_1 = jnp.asarray([balance])

        # Substituted derivatives
        pde_2 = grad_theta(x) - jax.jacrev(theta)(x)

        # Additional constraint in order to make the functional fully H1 coercive
        if n_dim == 2:
            curl_v = jnp.asarray([grad_v[1, 0] - grad_v[0, 1]])
        elif n_dim == 3:
            curl_v = jnp.asarray(
                [
                    grad_v[2, 1] - grad_v[1, 2],
                    grad_v[0, 2] - grad_v[2, 0],
                    grad_v[1, 0] - grad_v[0, 1],
                ]
            )
        else:
            # Explicit raise: a bare 'assert False' would vanish under 'python -O'
            raise AssertionError(
                "This dimensionality is not supported by poisson_fos"
            )
        pde_3 = curl_v

        w1 = 1
        w2 = 1 / spacing
        w3 = 1 / spacing
        return jnp.concatenate([w1 * pde_1, w2 * pde_2, w3 * pde_3]).flatten()

    return pde_fun
def heat_equation(diffusivity_fun):
    """
    Space-time heat equation with thermal diffusivity alpha: du/dt - alpha * Laplace(u) = 0

    The last dimension is the time dimension, the remaining dimensions are
    spatial dimensions.

    Args:
      diffusivity_fun (function): Function of the position x returning the
        thermal diffusivity.

    Returns:
      heat_fun (function): Function to compute the PDE residual.

    Notes:
      - This model works for DOFs as a jnp.ndarray.
    """

    def heat_fun(x, ansatz, settings, static_settings, int_point_number, set):
        n_spatial = x.shape[0] - 1
        alpha = diffusivity_fun(x)

        # Time derivative is the last entry of the gradient
        gradient = jax.jacfwd(ansatz)(x)
        du_dt = gradient[-1]

        # Laplacian: sum of the second derivatives along the spatial axes only
        hessian = jax.jacfwd(jax.jacrev(ansatz))(x)
        second_derivatives = hessian.diagonal()[:n_spatial]
        laplace_u = second_derivatives.sum()

        return du_dt - alpha * laplace_u

    return heat_fun
def heat_equation_fos(diffusivity_fun):
    """
    Space-time heat equation with thermal diffusivity alpha as a first order
    system with curl-augmentation for more than one spatial dimension.

    - Supports n_dim = 2 to n_dim = 4, i.e. 1 to 3 spatial dimensions plus one
      time dimension.
    - Number of fields: n_dim (1 temperature field, n_dim - 1 spatial gradients).
    - The last dimension is the time dimension, the remaining ones are spatial
      dimensions.

    E.g. n_dim = 2:
    - The first dimension is the x-dimension, the second is the time dimension.
    - The first field is the temperature, the second field is the derivative of
      the temperature with respect to x.

    Args:
      diffusivity_fun (function): Function of the position x returning the
        thermal diffusivity.

    Returns:
      heat_fun (function): Function to compute the PDE residual.

    Raises:
      AssertionError: When the returned function is evaluated at a point whose
        dimension is not 2, 3 or 4.

    Notes:
      - For n_dim = 2 the curl residual is a constant zero entry.
      - This model works for DOFs as a jnp.ndarray.
    """

    def heat_fun(x, ansatz, settings, static_settings, int_point_number, set):
        n_dim = x.shape[0]
        alpha = diffusivity_fun(x)

        def temperature(t):
            return ansatz(t)[0]  # u

        def derivative_field(t):
            return ansatz(t)[1:]  # v

        derivatives = jax.jacrev(temperature)(x)  # [du/dx..., du/dt] (flat)

        # Heat equation with substituted spatial derivatives v
        rate = derivatives[-1]  # du / dt
        grad_v = jax.jacfwd(derivative_field)(x)[:, : n_dim - 1]
        laplacian = grad_v.diagonal().sum()  # div(v)
        pde_1 = jnp.asarray([rate - alpha * laplacian])

        # Substituted spatial derivatives: v - du / dx
        pde_2 = derivative_field(x) - derivatives[: n_dim - 1]

        # Augmentation for stabilization
        if n_dim == 2:
            curl_v = jnp.asarray([0.0])
        elif n_dim == 3:
            curl_v = jnp.asarray([grad_v[1, 0] - grad_v[0, 1]])
        elif n_dim == 4:
            curl_v = jnp.asarray(
                [
                    grad_v[2, 1] - grad_v[1, 2],
                    grad_v[0, 2] - grad_v[2, 0],
                    grad_v[1, 0] - grad_v[0, 1],
                ]
            )
        else:
            # Explicit raise: a bare 'assert False' would vanish under 'python -O'
            raise AssertionError(
                "This dimensionality is not supported by heat_equation_fos"
            )
        pde_3 = curl_v

        return jnp.concatenate([pde_1, pde_2, pde_3]).flatten()

    return heat_fun
def d_alembert(wave_number_fun):
    """
    Space-time d'Alembert operator with one temporal [-1] and (n-1) spatial
    [:n-1] dimensions: (1 / c**2) * d2u/dt2 - Laplace(u)

    Args:
      wave_number_fun (function): Function of the position x returning the
        value c that enters the residual as the factor 1 / c**2 on the second
        time derivative.

    Returns:
      d_alembert_fun (function): Function to compute the PDE residual.

    Notes:
      - This model works for DOFs as a jnp.ndarray.
    """

    def d_alembert_fun(x, ansatz, settings, static_settings, int_point_number, set):
        c = wave_number_fun(x)
        hessian = jax.jacfwd(jax.jacrev(ansatz))(x)

        # Only the pure second derivatives enter the operator
        hessian_diagonal = jnp.diagonal(hessian)

        # -1 for every spatial axis, 1 / c**2 for the time axis (last entry)
        coefficients = -jnp.ones_like(hessian_diagonal)
        coefficients = coefficients.at[-1].set(1 / c**2)

        return jnp.dot(hessian_diagonal, coefficients)

    return d_alembert_fun
def d_alembert_fos(wave_number_fun, spacing):
    """
    Constructs a first-order system for the space-time d'Alembert operator.

    The returned PDE function represents the d'Alembert operator for one temporal
    dimension and (n-1) spatial dimensions, with augmented conditions to ensure a
    fully H1 coercive functional.

    Fields: ansatz(x)[0] is the primary field u, ansatz(x)[1:] is the derivative
    field v. The last entry of v is the time derivative scaled by 1 / c, the
    other entries are the spatial derivatives.

    Residuals (in this order):
      1. (1 / c) * dv_t/dt - sum_i dv_i/dx_i
      2. v - [du/dx..., (1 / c) * du/dt], weighted by 1 / spacing
      3. curl of the spatial part of v, weighted by 1 / spacing

    Args:
      wave_number_fun (function): Function of the position x returning the
        value c; 1 / c scales the time derivative in residuals 1 and 2.
      spacing (float): The spatial spacing parameter used to scale the augmented
        conditions.

    Returns:
      function: A function that evaluates the first-order system of PDEs.

    Raises:
      AssertionError: When the returned function is evaluated at a point whose
        dimension is not 2, 3 or 4.

    Notes:
      - The last dimension of the input coordinates x is the time dimension;
        the remaining dimensions are spatial dimensions.
      - For n_dim = 2 the curl residual is a constant zero entry.
      - This model works for DOFs as a jnp.ndarray.
    """

    def d_alembert_fun(x, ansatz, settings, static_settings, int_point_number, set):
        n_dim = x.shape[0]
        c = wave_number_fun(x)

        def field(t):
            return ansatz(t)[0]

        def derivative_field(t):
            return ansatz(t)[1:]

        # Jacobian of the derivative field; used for both the operator and the curl
        grad_v = jax.jacfwd(derivative_field)(x)

        # First pde: d'Alembert operator. The 1 / c factor belongs to the time
        # axis, which is the last one; the second 1 / c is already contained in
        # the scaled time derivative of the second pde.
        hessian_diagonal = grad_v.diagonal()
        coefficients = -jnp.ones_like(hessian_diagonal)
        coefficients = coefficients.at[-1].set(1 / c)
        pde_1 = jnp.asarray([jnp.dot(hessian_diagonal, coefficients)])

        # Second pde: gradient field with scaled time derivative
        scaled_grad = jax.jacrev(field)(x)
        scaled_grad = scaled_grad.at[-1].multiply(1 / c)
        pde_2 = derivative_field(x) - scaled_grad

        # Additional constraint in order to make the functional fully H1 coercive
        if n_dim == 2:
            curl_v = jnp.asarray([0.0])
        elif n_dim == 3:
            curl_v = jnp.asarray([grad_v[1, 0] - grad_v[0, 1]])
        elif n_dim == 4:
            curl_v = jnp.asarray(
                [
                    grad_v[2, 1] - grad_v[1, 2],
                    grad_v[0, 2] - grad_v[2, 0],
                    grad_v[1, 0] - grad_v[0, 1],
                ]
            )
        else:
            # Explicit raise: a bare 'assert False' would vanish under 'python -O'
            raise AssertionError(
                "This dimensionality is not supported by d_alembert_fos"
            )
        pde_3 = curl_v

        w1 = 1
        w2 = 1 / spacing
        w3 = 1 / spacing
        return jnp.concatenate([w1 * pde_1, w2 * pde_2, w3 * pde_3]).flatten()

    return d_alembert_fun
def linear_elasticity(youngs_mod_fun, poisson_ratio_fun, mode, volume_load_fun=None):
    """
    Constructs the strong form of linear elasticity in Voigt notation, displacement based.

    Supports 2D plain strain, 2D plain stress, and 3D elasticity problems. The
    residual is div(stress) (+ volume load), with the stress computed from the
    engineering strains [e11, e22, (e33,) 2*e12, (2*e13, 2*e23)].

    Args:
      youngs_mod_fun (function): Function to compute Young's modulus given the
        spatial coordinates (and optionally the settings as second argument).
      poisson_ratio_fun (function): Function to compute Poisson's ratio given the
        spatial coordinates (and optionally the settings as second argument).
      mode (str): 'plain strain', 'plain stress', or '3d'.
      volume_load_fun (function, optional): Function to compute the volume load
        given the spatial coordinates. Default is None.

    Returns:
      function: A function that evaluates the strong form PDE for linear elasticity.

    Raises:
      AssertionError: If `mode` is not one of the supported strings.

    Notes:
      - The material functions are evaluated at the differentiation point, so
        spatially varying material parameters contribute to div(stress).
      - This model works for DOFs as a jnp.ndarray.
    """
    if mode in ("plain strain", "plain stress"):
        n_dim = 2
    elif mode == "3d":
        n_dim = 3
    else:
        # Explicit raise: a bare 'assert False' would vanish under 'python -O'
        raise AssertionError("'mode' for linear elasticity not properly set.")

    def evaluate_material(material_fun, point, settings):
        # Material functions may take (x) or (x, settings)
        if len(signature(material_fun).parameters) == 1:
            return material_fun(point)
        return material_fun(point, settings)

    def material_tangent(Em, nu):
        # Voigt stiffness acting on engineering shear strains
        if mode == "plain strain":
            mu = Em / (2 * (1 + nu))
            c1 = 1 - 2 * nu
            c2 = 1 - nu
            coeff = 2 * mu / c1
            # Shear entry coeff * c1 / 2 = mu, consistent with the other modes
            return coeff * jnp.asarray(
                [[c2, nu, 0], [nu, c2, 0], [0, 0, c1 / 2]]
            )
        if mode == "plain stress":
            coeff = Em / (1 - nu**2)
            return coeff * jnp.asarray(
                [[1, nu, 0], [nu, 1, 0], [0, 0, (1 - nu) / 2]]
            )
        # mode == "3d"
        coeff = Em / (1 + nu)
        c1 = 1 - 2 * nu
        c2 = (1 - nu) / c1
        c3 = nu / c1
        c4 = 1 / 2
        return coeff * jnp.asarray(
            [
                [c2, c3, c3, 0, 0, 0],
                [c3, c2, c3, 0, 0, 0],
                [c3, c3, c2, 0, 0, 0],
                [0, 0, 0, c4, 0, 0],
                [0, 0, 0, 0, c4, 0],
                [0, 0, 0, 0, 0, c4],
            ]
        )

    def lin_el_fun(x, ansatz, settings, static_settings, int_point_number, set):
        def get_stress(xf):
            # Kinematics
            displ_grad = jax.jacfwd(ansatz)(xf)
            strain = (1 / 2) * (displ_grad + displ_grad.transpose())
            if n_dim == 2:
                strain_voigt = jnp.asarray(
                    [strain[0, 0], strain[1, 1], 2 * strain[0, 1]]
                )
            else:
                strain_voigt = jnp.asarray(
                    [
                        strain[0, 0],
                        strain[1, 1],
                        strain[2, 2],
                        2 * strain[0, 1],
                        2 * strain[0, 2],
                        2 * strain[1, 2],
                    ]
                )

            # Constitutive model, evaluated at xf so that jacfwd(get_stress)
            # also sees the spatial variation of the material parameters
            nu = evaluate_material(poisson_ratio_fun, xf, settings)
            Em = evaluate_material(youngs_mod_fun, xf, settings)
            return material_tangent(Em, nu) @ strain_voigt

        # Internal forces: divergence of the stress from the Voigt-stress Jacobian
        grad_stress = jax.jacfwd(get_stress)(x)
        if n_dim == 2:
            div_stress = jnp.asarray(
                [
                    grad_stress[0, 0] + grad_stress[2, 1],
                    grad_stress[2, 0] + grad_stress[1, 1],
                ]
            )
        else:
            div_stress = jnp.asarray(
                [
                    grad_stress[0, 0] + grad_stress[3, 1] + grad_stress[4, 2],
                    grad_stress[3, 0] + grad_stress[1, 1] + grad_stress[5, 2],
                    grad_stress[4, 0] + grad_stress[5, 1] + grad_stress[2, 2],
                ]
            )

        # Balance equation
        if volume_load_fun is None:
            return div_stress
        return div_stress + volume_load_fun(x)

    return lin_el_fun
def linear_elasticity_weak(
    youngs_mod_fun, poisson_ratio_fun, mode, volume_load_fun=None
):
    """
    Constructs the weak form of linear elasticity in Voigt notation, displacement based.

    Supports 2D plain strain, 2D plain stress, and 3D elasticity problems. The
    integrand is stress : virtual strain, minus volume load . virtual
    displacement if a volume load is given.

    Args:
      youngs_mod_fun (function): Function to compute Young's modulus given the
        spatial coordinates (and optionally the settings as second argument).
      poisson_ratio_fun (function): Function to compute Poisson's ratio given the
        spatial coordinates (and optionally the settings as second argument).
      mode (str): 'plain strain', 'plain stress', or '3d'.
      volume_load_fun (function, optional): Function to compute the volume load
        given the spatial coordinates. Default is None.

    Returns:
      function: A function that evaluates the weak form PDE for linear elasticity.

    Raises:
      AssertionError: If `mode` is not one of the supported strings.

    Notes:
      - This model works for DOFs as a jnp.ndarray.
    """
    if mode in ("plain strain", "plain stress"):
        n_dim = 2
    elif mode == "3d":
        n_dim = 3
    else:
        # Explicit raise: a bare 'assert False' would vanish under 'python -O'
        raise AssertionError("'mode' for linear elasticity not properly set.")

    def evaluate_material(material_fun, point, settings):
        # Material functions may take (x) or (x, settings)
        if len(signature(material_fun).parameters) == 1:
            return material_fun(point)
        return material_fun(point, settings)

    def to_voigt(gradient):
        # Symmetrize and collect the engineering strains in Voigt order
        strain = (1 / 2) * (gradient + gradient.transpose())
        if n_dim == 2:
            return jnp.asarray([strain[0, 0], strain[1, 1], 2 * strain[0, 1]])
        return jnp.asarray(
            [
                strain[0, 0],
                strain[1, 1],
                strain[2, 2],
                2 * strain[0, 1],
                2 * strain[0, 2],
                2 * strain[1, 2],
            ]
        )

    def material_tangent(Em, nu):
        # Voigt stiffness acting on engineering shear strains
        if mode == "plain strain":
            mu = Em / (2 * (1 + nu))
            c1 = 1 - 2 * nu
            c2 = 1 - nu
            coeff = 2 * mu / c1
            # Shear entry coeff * c1 / 2 = mu, consistent with the other modes
            return coeff * jnp.asarray(
                [[c2, nu, 0], [nu, c2, 0], [0, 0, c1 / 2]]
            )
        if mode == "plain stress":
            coeff = Em / (1 - nu**2)
            return coeff * jnp.asarray(
                [[1, nu, 0], [nu, 1, 0], [0, 0, (1 - nu) / 2]]
            )
        # mode == "3d"
        coeff = Em / (1 + nu)
        c1 = 1 - 2 * nu
        c2 = (1 - nu) / c1
        c3 = nu / c1
        c4 = 1 / 2
        return coeff * jnp.asarray(
            [
                [c2, c3, c3, 0, 0, 0],
                [c3, c2, c3, 0, 0, 0],
                [c3, c3, c2, 0, 0, 0],
                [0, 0, 0, c4, 0, 0],
                [0, 0, 0, 0, c4, 0],
                [0, 0, 0, 0, 0, c4],
            ]
        )

    def lin_el_fun(
        x, ansatz, test_ansatz, settings, static_settings, int_point_number, set
    ):
        # Stress from the trial field
        strain_voigt = to_voigt(jax.jacfwd(ansatz)(x))
        nu = evaluate_material(poisson_ratio_fun, x, settings)
        Em = evaluate_material(youngs_mod_fun, x, settings)
        stress_voigt = material_tangent(Em, nu) @ strain_voigt

        # Test function
        virtual_disp = test_ansatz(x)
        virtual_strain_voigt = to_voigt(jax.jacfwd(test_ansatz)(x))

        # Balance equation: internal minus external virtual work
        internal = jnp.dot(stress_voigt, virtual_strain_voigt)
        if volume_load_fun is None:
            return internal
        return internal - jnp.dot(volume_load_fun(x), virtual_disp)

    return lin_el_fun
def linear_elasticity_fos(
    youngs_mod_fun, poisson_ratio_fun, mode, spacing, volume_load_fun=None
):
    """
    Constructs the first order system (FOS) of linear elasticity, displacement
    and stress based.

    Supports 2D plain strain and 3D elasticity problems. The formulation is based
    on the first order system described in
    https://doi.org/10.1137/S0036142902418357 with additional asymmetry
    penalization.

    Residuals (in this order):
      1. div(stress) (+ volume load)
      2. compliance applied to the stress minus the kinematic strain
      3. skew part of the stress, weighted by 10 / mu

    Args:
      youngs_mod_fun (function): Function to compute Young's modulus given the
        spatial coordinates (and optionally the settings as second argument).
      poisson_ratio_fun (function): Function to compute Poisson's ratio given the
        spatial coordinates (and optionally the settings as second argument).
      mode (str): 'plain strain' or '3d'.
      spacing (float): Spatial spacing parameter. It is accepted for interface
        compatibility but is not used by the residual weights.
      volume_load_fun (function, optional): Function to compute the volume load
        given the spatial coordinates. Default is None.

    Returns:
      function: A function that evaluates the first order system of PDEs for
      linear elasticity.

    Raises:
      AssertionError: If `mode` is neither 'plain strain' nor '3d'.

    Notes:
      - In 2D, the number of fields is 2 (displacement) + 4 (stress) = 6.
      - In 3D, the number of fields is 3 (displacement) + 9 (stress) = 12.
      - This model works for DOFs as a jnp.ndarray.
    """
    if mode == "plain strain":
        n_dim = 2
    elif mode == "3d":
        n_dim = 3
    else:
        # Explicit raise: a bare 'assert False' would vanish under 'python -O'
        raise AssertionError("'mode' for linear elasticity not properly set.")

    def evaluate_material(material_fun, point, settings):
        # Material functions may take (x) or (x, settings)
        if len(signature(material_fun).parameters) == 1:
            return material_fun(point)
        return material_fun(point, settings)

    def lin_el_fun(x, ansatz, settings, static_settings, int_point_number, set):
        # Primary fields
        def displacement(t):
            return ansatz(t)[:n_dim]

        def stress(t):
            return ansatz(t)[n_dim:].reshape((n_dim, n_dim))

        # Kinematics
        displ_grad = jax.jacfwd(displacement)(x)
        strain = (1 / 2) * (displ_grad + displ_grad.transpose())

        # Constitutive parameters
        nu = evaluate_material(poisson_ratio_fun, x, settings)
        Em = evaluate_material(youngs_mod_fun, x, settings)
        mu = Em / (2 * (1 + nu))  # Shear modulus
        lam = Em * nu / ((1 + nu) * (1 - 2 * nu))  # First Lame constant

        # Compliance tensor times stress
        sigma = stress(x)
        A_sigma = (1 / (2 * mu)) * (
            sigma
            - (lam / (n_dim * lam + 2 * mu)) * jnp.trace(sigma) * jnp.identity(n_dim)
        )

        # Second pde: constitutive model
        pde_2 = (A_sigma - strain).flatten()

        # Internal forces: grad_stress[i, j, k] = d sigma_ij / d x_k
        grad_stress = jax.jacfwd(stress)(x)
        if n_dim == 2:
            div_stress = jnp.asarray(
                [
                    grad_stress[0, 0, 0] + grad_stress[0, 1, 1],
                    grad_stress[1, 0, 0] + grad_stress[1, 1, 1],
                ]
            )
        else:
            div_stress = jnp.asarray(
                [
                    grad_stress[0, 0, 0] + grad_stress[0, 1, 1] + grad_stress[0, 2, 2],
                    grad_stress[1, 0, 0] + grad_stress[1, 1, 1] + grad_stress[1, 2, 2],
                    grad_stress[2, 0, 0] + grad_stress[2, 1, 1] + grad_stress[2, 2, 2],
                ]
            )

        # First pde: balance equation
        if volume_load_fun is None:
            pde_1 = div_stress
        else:
            pde_1 = div_stress + volume_load_fun(x)

        # Third pde: weak enforcement of symmetry of the Cauchy stress
        pde_3 = ((1 / 2) * (sigma - sigma.transpose())).flatten()

        # Weighted first order system
        w_1 = 1.0
        w_2 = 1.0
        w_3 = 10.0 / mu
        return jnp.concatenate([w_1 * pde_1, w_2 * pde_2, w_3 * pde_3]).flatten()

    return lin_el_fun
def neumann_weak(neumann_fun):
    """
    Constructs the weak form of Neumann boundary conditions, i.e. the virtual
    work of surface tractions or heat inflow.

    It can be used to impose Neumann boundary conditions in the weak form of
    problems such as linear momentum balance or heat conduction.

    Args:
      neumann_fun (callable): Function that computes the Neumann boundary value
        (e.g. surface tractions or heat inflow) given the spatial coordinates.
        It may optionally take the settings as a second argument.

    Returns:
      function: A function that evaluates the negative dot product of the test
      function value (virtual displacement or temperature) and the Neumann
      boundary value.

    Notes:
      - This model works for DOFs as a jnp.ndarray.
    """

    def fun(x, ansatz, test_ansatz, settings, static_settings, int_point_number, set):
        # Value of the test function at the integration point
        test_value = test_ansatz(x)

        # A single-parameter boundary function is called without the settings
        wants_settings = len(signature(neumann_fun).parameters) != 1
        if wants_settings:
            boundary_value = neumann_fun(x, settings)
        else:
            boundary_value = neumann_fun(x)

        # Virtual work
        return -jnp.dot(test_value, boundary_value)

    return fun
### Nonlinear equations
def burgers_equation_inviscid():
    """
    Constructs the inviscid Burgers' equation: du/dt + u * du/dx = 0.

    Args:
      None

    Returns:
      function: A function that evaluates the residual of the inviscid Burgers'
      equation.

    Notes:
      - First dimension: x
      - Second dimension: t
      - This model works for DOFs as a jnp.ndarray.
    """

    def pde_fun(x, ansatz, settings, static_settings, int_point_number, set):
        u = ansatz(x)

        # Differentiate once and reuse the gradient for both derivatives
        gradient = jax.jacrev(ansatz)(x)
        u_x = gradient[0]
        u_t = gradient[1]
        return u_t + u * u_x

    return pde_fun
def hyperelastic_steady_state_fos(
    strain_energy_fun,
    youngs_mod_fun,
    poisson_ratio_fun,
    mode,
    spacing,
    volume_load_fun=None,
):
    """
    Constructs a first-order PDE system for a given strain energy function.

    Based on equation (28) in https://doi.org/10.1002/nme.5951; returns a
    function representing the first-order PDE system for steady-state
    hyperelasticity with the fields displacement u = ansatz(x)[:n_dim] and first
    Piola-Kirchhoff stress P = ansatz(x)[n_dim:] reshaped to (n_dim, n_dim).

    Residuals (in this order):
      1. Div(P) (+ volume load)
      2. sym(P F^T) - (dpsi/dF) F^T, weighted by 1 / mu
      3. skew(P F^T), weighted by 10 / mu

    Args:
      strain_energy_fun (function): Function of (F, (lam, mu)) returning the
        strain energy for the deformation gradient F and the Lame parameters.
        It is used in both the 'plain strain' and the '3d' mode.
      youngs_mod_fun (function): Function to compute Young's modulus given the
        spatial coordinates (and optionally the settings as second argument).
      poisson_ratio_fun (function): Function to compute Poisson's ratio given the
        spatial coordinates (and optionally the settings as second argument).
      mode (str): 'plain strain' or '3d'.
      spacing (float): Spatial spacing parameter. It is accepted for interface
        compatibility but is not used by the residual weights.
      volume_load_fun (function, optional): Function to compute the volume load
        given the spatial coordinates. Default is None.

    Returns:
      function: A function that evaluates the first-order PDE system for
      hyperelasticity.

    Raises:
      AssertionError: When the returned function is evaluated with a `mode`
        that is neither 'plain strain' nor '3d'.

    Notes:
      - The volume load function, if provided, must be volume-specific.
      - In 'plain strain' mode the 2x2 displacement gradient is embedded into a
        3x3 deformation gradient with unit out-of-plane stretch.
      - This model works for DOFs as a jnp.ndarray.
    """

    def evaluate_material(material_fun, point, settings):
        # Material functions may take (x) or (x, settings)
        if len(signature(material_fun).parameters) == 1:
            return material_fun(point)
        return material_fun(point, settings)

    def pde_fun(x, ansatz, settings, static_settings, int_point_number, set):
        n_dim = x.shape[0]
        nu = evaluate_material(poisson_ratio_fun, x, settings)
        Em = evaluate_material(youngs_mod_fun, x, settings)

        # Lame parameters
        lam = Em * nu / ((1 + nu) * (1 - 2 * nu))
        mu = Em / (2 * (1 + nu))

        # Primary fields
        def u(t):
            return ansatz(t)[:n_dim]  # Displacement

        def P(t):
            return ansatz(t)[n_dim:].reshape((n_dim, n_dim))  # 1. Piola-Kirchhoff stress

        # Linear momentum balance: Div_P[i] = sum_j dP_ij / dx_j
        Grad_P = jax.jacfwd(P)(x)
        Div_P = jnp.trace(Grad_P, axis1=1, axis2=2)

        # First pde: balance equation
        if volume_load_fun is None:
            pde_1 = Div_P
        else:
            pde_1 = Div_P + volume_load_fun(x)

        if mode == "plain strain":
            # Deformation gradient embedded in 3D
            H = jax.jacfwd(u)(x)
            F = jnp.asarray(
                [
                    [H[0, 0], H[0, 1], 0.0],
                    [H[1, 0], H[1, 1], 0.0],
                    [0.0, 0.0, 0.0],
                ]
            ) + jnp.identity(3)

            # Stress from the user-supplied strain energy, restricted to the plane
            dpsi_dF = jax.jacrev(strain_energy_fun)(F, (lam, mu))
            dpsi_dF_FT = (dpsi_dF @ jnp.transpose(F))[:n_dim, :n_dim]
            PFT = P(x)[:n_dim, :n_dim] @ jnp.transpose(F[:n_dim, :n_dim])
        elif mode == "3d":
            # Deformation gradient
            F = jax.jacfwd(u)(x) + jnp.identity(3)

            # Stress from the user-supplied strain energy
            dpsi_dF = jax.jacrev(strain_energy_fun)(F, (lam, mu))
            dpsi_dF_FT = dpsi_dF @ jnp.transpose(F)
            PFT = P(x) @ F.transpose()
        else:
            # Explicit raise: a bare 'assert False' would vanish under 'python -O'
            raise AssertionError(
                "Hyperelastic model supports only 'plain strain' and '3d' modes."
            )

        # Second pde: constitutive law
        pde_2 = ((1 / 2) * (PFT + jnp.transpose(PFT)) - dpsi_dF_FT).flatten()

        # Third pde: enforce symmetry of P F^T
        PFT_asym = (1 / 2) * (PFT - PFT.transpose())
        pde_3 = PFT_asym.flatten()

        # Weightings
        w_1 = 1.0
        w_2 = 1.0 / mu
        w_3 = 10.0 / mu
        return jnp.concatenate([w_1 * pde_1, w_2 * pde_2, w_3 * pde_3])

    return pde_fun
def hyperelastic_steady_state_weak(
    strain_energy_fun, youngs_mod_fun, poisson_ratio_fun, mode, volume_load_fun=None
):
    """
    Constructs the weak form of hyperelasticity for a given strain energy function.

    The integrand is P : virtual F, with P = dpsi/dF, minus volume load . virtual
    displacement if a volume load is given.

    Args:
      strain_energy_fun (function): Function of (F, (lam, mu)) returning the
        strain energy for the deformation gradient F and the Lame parameters.
      youngs_mod_fun (function): Function to compute Young's modulus given the
        spatial coordinates (and optionally the settings as second argument).
      poisson_ratio_fun (function): Function to compute Poisson's ratio given the
        spatial coordinates (and optionally the settings as second argument).
      mode (str): 'plain strain' or '3d'.
      volume_load_fun (function, optional): Function to compute the volume load
        given the spatial coordinates. Default is None.

    Returns:
      function: A function that evaluates the weak form for hyperelasticity.

    Raises:
      AssertionError: When the returned function is evaluated with a `mode`
        that is neither 'plain strain' nor '3d'.

    Notes:
      - The volume load function, if provided, must be volume-specific.
      - In 'plain strain' mode the 2x2 displacement gradient is embedded into a
        3x3 deformation gradient with unit out-of-plane stretch.
      - This model works for DOFs as a jnp.ndarray.
    """

    def evaluate_material(material_fun, point, settings):
        # Material functions may take (x) or (x, settings)
        if len(signature(material_fun).parameters) == 1:
            return material_fun(point)
        return material_fun(point, settings)

    def pde_fun(
        x, ansatz, test_ansatz, settings, static_settings, int_point_number, set
    ):
        nu = evaluate_material(poisson_ratio_fun, x, settings)
        Em = evaluate_material(youngs_mod_fun, x, settings)

        # Lame parameters
        lam = Em * nu / ((1 + nu) * (1 - 2 * nu))
        mu = Em / (2 * (1 + nu))

        # (Virtual) displacement
        u = ansatz
        virt_u = test_ansatz

        # Deformation gradient and its variation
        if mode == "plain strain":
            H = jax.jacfwd(u)(x)
            F = jnp.asarray(
                [[H[0, 0], H[0, 1], 0.0], [H[1, 0], H[1, 1], 0.0], [0.0, 0.0, 0.0]]
            ) + jnp.identity(3)
            virt_H = jax.jacfwd(virt_u)(x)
            virt_F = jnp.asarray(
                [
                    [virt_H[0, 0], virt_H[0, 1], 0.0],
                    [virt_H[1, 0], virt_H[1, 1], 0.0],
                    [0.0, 0.0, 0.0],
                ]
            )
        elif mode == "3d":
            F = jax.jacfwd(u)(x) + jnp.identity(3)
            virt_F = jax.jacfwd(virt_u)(x)
        else:
            # Explicit raise: a bare 'assert False' would vanish under 'python -O'
            raise AssertionError(
                "Hyperelastic model supports only 'plain strain' and '3d' modes."
            )

        # First Piola-Kirchhoff stress
        P = jax.jacrev(strain_energy_fun)(F, (lam, mu))

        # Virtual work of internal forces
        PI_int = jnp.einsum("ij, ij -> ", P, virt_F)
        if volume_load_fun is None:
            return PI_int

        # Virtual work of volume forces: the test function has to be evaluated
        # at x before it can be contracted with the load vector
        b = volume_load_fun(x)
        PI_ext = jnp.dot(b, virt_u(x))
        return PI_int - PI_ext

    return pde_fun
### Strain energy functions
def neo_hooke(F, param):
    """
    Computes the strain energy of a neo-Hookean material model of Ciarlet type
    for a given deformation gradient:

      psi = (mu / 2) * (tr(C) - 3 - 2 * ln(J)) + (lam / 4) * (J**2 - 1 - 2 * ln(J))

    Args:
      F (jnp.ndarray): Deformation gradient tensor.
      param (tuple): Lame parameters (lambda, mu).

    Returns:
      float: Strain energy value for the given deformation gradient.

    Notes:
      - J is the determinant of F, computed via utility.matrix_det.
      - tr(C) is evaluated as the full contraction F_ij F_ij.
    """
    lam, mu = param

    # Volume ratio and its logarithm
    J = utility.matrix_det(F)
    ln_J = jnp.log(J)

    # Trace of the right Cauchy-Green tensor
    tr_C = jnp.einsum("ij, ij", F, F)

    # Shear and volumetric contributions of the Ciarlet-type energy
    shear_part = (mu / 2) * (tr_C - 3 - 2 * ln_J)
    volumetric_part = (lam / 4) * (J**2 - 1 - 2 * ln_J)
    return shear_part + volumetric_part
def isochoric_neo_hooke(F, mu):
    """
    Computes the strain energy (mu / 2) * (tr(C) - 3) of an isochoric
    neo-Hookean material model.

    Args:
      F (jnp.ndarray): Deformation gradient tensor.
      mu (float): Shear modulus.

    Returns:
      float: Strain energy value for the given deformation gradient.
    """
    # Trace of the right Cauchy-Green tensor as the full contraction F_ij F_ij
    trace_of_c = jnp.einsum("ij, ij", F, F)

    half_mu = mu / 2
    return half_mu * (trace_of_c - 3)
def linear_elastic_strain_energy(F, param):
    """
    Computes the strain energy of a linear elastic material for small
    deformations:

      psi = (lam / 2) * tr(eps)**2 + mu * tr(eps @ eps^T)

    where eps is the symmetric part of the displacement gradient F - I.

    Args:
      F (jnp.ndarray): Deformation gradient tensor.
      param (tuple): Lame parameters (lambda, mu).

    Returns:
      float: Strain energy value for the given deformation gradient.
    """
    lam, mu = param

    # Displacement gradient and its symmetric part (linearized strain)
    grad_u = F - jnp.identity(F.shape[0])
    eps = (grad_u + grad_u.T) / 2

    volumetric_part = (lam / 2) * (jnp.trace(eps)) ** 2
    deviatoric_like_part = mu * jnp.trace(eps @ eps.T)
    return volumetric_part + deviatoric_like_part
### User potentials/elements
def mixed_reference_domain_potential(
    integrand_fun, ansatz_fun, ref_int_coor, ref_int_weights, mapping_key
):
    """
    Constructs a multi-field 'user potential' for the integration of a potential
    in the reference configuration of finite elements.

    Works for DOFs as a dict of jnp.ndarrays.

    Args:
      integrand_fun (callable): Function that evaluates the integrand. It is
        called as integrand_fun(x_int, trial_ansatz, settings, static_settings,
        elem_number, set).
      ansatz_fun (dict of callables): Functions that construct the ansatz
        functions for each field. The keys have to match those defined in the
        DOFs dictionary.
      ref_int_coor (jnp.ndarray): Reference integration coordinates for the
        isoparametric element.
      ref_int_weights (jnp.ndarray): Reference integration weights for the
        isoparametric element.
      mapping_key (string): Which solution space to use for the isoparametric
        mapping.

    Returns:
      function: A function that evaluates the integrated potential of one
      element. It has the following parameters:
        - fI (dict of jnp.ndarray): Element nodal values.
        - xI (dict of jnp.ndarray): Element nodal coordinates.
        - elem_number (int): Element number.
        - settings (dict): Settings for the computation.
        - static_settings (dict): Static settings for the computation.
        - set: Number of the domain.

    Exemplary integrand function:

    --> def integrand_fun(x_int, ansatz_fun, settings, static_settings, elem_number, set):
    -->     # Definition of custom functional
    -->     x = ansatz_fun['physical coor'](x_int)  # Physical coordinates via isoparametric mapping
    -->     phi_fun = ansatz_fun['phi']  # Field approximation function with key words as defined in the DOFs dictionary
    -->     phi = phi_fun(x_int)
    -->     dphi_dx = jax.jacrev(phi_fun)(x_int)
    -->
    -->     x_1 = x
    -->     x_2 = x - jnp.array([1., 0.5])
    -->     source_term = 20 * (jnp.sin(10 * x_1 @ x_1) - jnp.cos(10 * x_2 @ x_2))
    -->     return (1/2) * dphi_dx @ dphi_dx - source_term * phi  # Returns a scalar potential that is to be integrated in reference configuration

    Notes:
      - The isoparametric mapping is additionally passed to the integrand under
        the key 'physical coor'.
      - This model works for DOFs as a dict of jnp.ndarrays.
    """
    field_keys = list(ansatz_fun.keys())

    def user_potential(fI, xI, elem_number, settings, static_settings, set):
        # Spatial dimension is read from the nodal coordinates of the first field
        n_dim = xI[field_keys[0]].shape[-1]

        def make_trial_space(key):
            # One closure per field, bound to that field's nodal data
            def trial_space(x):
                return ansatz_fun[key](x, xI[key], fI[key], settings, True, n_dim)

            return trial_space

        trial_ansatz = {key: make_trial_space(key) for key in field_keys}

        # Isoparametric mapping: the nodal coordinates are interpolated themselves
        def mapping(x):
            return ansatz_fun[mapping_key](
                x, xI[mapping_key], xI[mapping_key], settings, False, n_dim
            )

        trial_ansatz["physical coor"] = mapping

        def gauss_point_contribution(x_int, w_int):
            # Integrand value at this integration point
            integrand_value = integrand_fun(
                x_int, trial_ansatz, settings, static_settings, elem_number, set
            )

            # Determinant of the Jacobian of the isoparametric mapping
            det_jacobian = utility.matrix_det(jax.jacfwd(mapping)(x_int))

            return integrand_value * w_int * det_jacobian

        # Evaluate all integration points at once and add them up
        per_point = jax.vmap(gauss_point_contribution, in_axes=(0, 0), out_axes=0)(
            ref_int_coor, ref_int_weights
        )
        return per_point.sum()

    return user_potential
def mixed_reference_domain_residual(
    integrand_fun, ansatz_fun, ref_int_coor, ref_int_weights, mapping_key
):
    """
    Build a multi-field 'user residual' that integrates a weak form over the reference configuration of a finite element.

    Works for DOFs as a dict of jnp.ndarrays.

    Args:
      integrand_fun (callable): Weak form integrand; it is called as
        integrand_fun(x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, set).
      ansatz_fun (dict of callables): One ansatz constructor per field. Keywords have to match with those
        defined in the DOFs dictionary.
      ref_int_coor (jnp.ndarray): Reference integration coordinates for the isoparametric element.
      ref_int_weights (jnp.ndarray): Reference integration weights for the isoparametric element.
      mapping_key (string): Which solution space to use for the isoparametric mapping.

    Returns:
      function
        A function that evaluates the element residual, i.e. the derivative of the integrated weak form
        with respect to the nodal values of the test functions. It has the following parameters:

        - fI (dict of jnp.ndarray): Element nodal values.
        - xI (dict of jnp.ndarray): Element nodal coordinates.
        - elem_number (int): Element number.
        - settings (dict): Settings for the computation.
        - static_settings (dict): Static settings for the computation.
        - set: Number of domain.

    Notes:
      This models works for DOFs as a dict of jnp.ndarrays.
    """
    field_names = list(ansatz_fun)

    def user_residual(fI, xI, elem_number, settings, static_settings, set):
        n_dim = xI[field_names[0]].shape[-1]

        def interpolate(name, nodal_values):
            # Field approximation of solution space `name` built on the given nodal values
            def field(x):
                return ansatz_fun[name](
                    x, xI[name], nodal_values[name], settings, True, n_dim
                )

            return field

        def mapping(x):
            # Isoparametric mapping: the nodal coordinates themselves are interpolated
            return ansatz_fun[mapping_key](
                x, xI[mapping_key], xI[mapping_key], settings, False, n_dim
            )

        def integrated_weak_form(test_fI):
            trial_ansatz = {name: interpolate(name, fI) for name in field_names}
            test_ansatz = {name: interpolate(name, test_fI) for name in field_names}
            # Only the trial dict carries the mapping to physical coordinates
            trial_ansatz["physical coor"] = mapping

            def gauss_point_contribution(x_int, w_int):
                weak_form = integrand_fun(
                    x_int,
                    trial_ansatz,
                    test_ansatz,
                    settings,
                    static_settings,
                    elem_number,
                    set,
                )

                # Determinant of the Jacobian of the isoparametric mapping
                det_jacobian = utility.matrix_det(jax.jacfwd(mapping)(x_int))

                return weak_form * w_int * det_jacobian

            per_point = jax.vmap(gauss_point_contribution, in_axes=(0, 0), out_axes=0)(
                ref_int_coor, ref_int_weights
            )
            return per_point.sum(axis=0)

        # Differentiate with respect to the test nodal values, evaluated at fI
        return jax.jacrev(integrated_weak_form)(fI)

    return user_residual
def mixed_reference_surface_potential(
    integrand_fun, ansatz_fun, ref_int_coor, ref_int_weights, mapping_key
):
    """
    Build a multi-field 'user potential' that integrates a potential over the reference configuration of a surface element.

    Works for DOFs as a dict of jnp.ndarrays.

    Args:
      integrand_fun (callable): Integrand; it is called as
        integrand_fun(x_int, ansatz, settings, static_settings, elem_number, set).
      ansatz_fun (dict of callables): One ansatz constructor per field. Keywords have to match with those
        defined in the DOFs dictionary.
      ref_int_coor (jnp.ndarray): Reference integration coordinates for the isoparametric element.
      ref_int_weights (jnp.ndarray): Reference integration weights for the isoparametric element.
      mapping_key (string): Which solution space to use for the isoparametric mapping.

    Returns:
      function
        A function that evaluates the integrated potential of one surface element. It has the following
        parameters:

        - fI (dict of jnp.ndarray): Element nodal values.
        - xI (dict of jnp.ndarray): Element nodal coordinates.
        - elem_number (int): Element number.
        - settings (dict): Settings for the computation.
        - static_settings (dict): Static settings for the computation.
        - set: Number of domain.

    Raises:
      ValueError: When the returned function is evaluated with nodal coordinates that are neither 2d nor 3d.

    Notes:
      This models works for DOFs as a dict of jnp.ndarrays.
    """
    field_names = list(ansatz_fun)

    def user_potential(fI, xI, elem_number, settings, static_settings, set):
        n_dim = xI[field_names[0]].shape[-1]

        def make_field(name):
            def field(x):
                return ansatz_fun[name](x, xI[name], fI[name], settings, True, n_dim)

            return field

        def mapping(x):
            # Isoparametric mapping with shape functions in n_dim-1 dimensions
            return ansatz_fun[mapping_key](
                x, xI[mapping_key], xI[mapping_key], settings, False, n_dim - 1
            )

        ansatz = {name: make_field(name) for name in field_names}
        ansatz["physical coor"] = mapping

        def gauss_point_contribution(x_int, w_int):
            integrand = integrand_fun(
                x_int, ansatz, settings, static_settings, elem_number, set
            )

            # Tangent vector(s) of the surface at the integration point
            jacobian = jax.jacfwd(mapping)(x_int)
            if n_dim == 2:
                # Line in 2d: pad the tangent with a zero and cross it with the out-of-plane unit vector
                G_1 = jnp.concatenate([jacobian, jnp.asarray([0])])
                E_3 = jnp.asarray([0, 0, 1])
                scaling = jnp.linalg.norm(jnp.cross(G_1, E_3))
            elif n_dim == 3:
                # 2d-surface in 3d: norm of the cross product of both tangents
                G_1 = jacobian[:, 0]
                G_2 = jacobian[:, 1]
                scaling = jnp.linalg.norm(jnp.cross(G_1, G_2))
            else:
                raise ValueError(
                    "Surface element supports currently only lines in 2d and 2d-surfaces in 3d."
                )

            return integrand * w_int * scaling

        per_point = jax.vmap(gauss_point_contribution, in_axes=(0, 0), out_axes=0)(
            ref_int_coor, ref_int_weights
        )
        return per_point.sum(axis=0)

    return user_potential
def mixed_reference_surface_residual(
    integrand_fun, ansatz_fun, ref_int_coor, ref_int_weights, mapping_key
):
    """
    Build a multi-field 'user residual' that integrates a weak form over the reference configuration of a surface element.

    Works for DOFs as a dict of jnp.ndarrays.

    Args:
      integrand_fun (callable): Weak form integrand; it is called as
        integrand_fun(x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, set).
      ansatz_fun (dict of callables): One ansatz constructor per field. Keywords have to match with those
        defined in the DOFs dictionary.
      ref_int_coor (jnp.ndarray): Reference integration coordinates for the isoparametric element.
      ref_int_weights (jnp.ndarray): Reference integration weights for the isoparametric element.
      mapping_key (string): Which solution space to use for the isoparametric mapping.

    Returns:
      function
        A function that evaluates the element residual, i.e. the derivative of the integrated weak form
        with respect to the nodal values of the test functions. It has the following parameters:

        - fI (dict of jnp.ndarray): Element nodal values.
        - xI (dict of jnp.ndarray): Element nodal coordinates.
        - elem_number (int): Element number.
        - settings (dict): Settings for the computation.
        - static_settings (dict): Static settings for the computation.
        - set: Number of domain.

    Raises:
      ValueError: When the returned function is evaluated with nodal coordinates that are neither 2d nor 3d.

    Notes:
      This models works for DOFs as a dict of jnp.ndarrays.
    """
    field_names = list(ansatz_fun)

    def user_residual(fI, xI, elem_number, settings, static_settings, set):
        n_dim = xI[field_names[0]].shape[-1]

        def interpolate(name, nodal_values):
            # Field approximation of solution space `name` built on the given nodal values
            def field(x):
                return ansatz_fun[name](
                    x, xI[name], nodal_values[name], settings, True, n_dim
                )

            return field

        def mapping(x):
            # Isoparametric mapping with shape functions in n_dim-1 dimensions
            return ansatz_fun[mapping_key](
                x, xI[mapping_key], xI[mapping_key], settings, False, n_dim - 1
            )

        def integrated_weak_form(test_fI):
            trial_ansatz = {name: interpolate(name, fI) for name in field_names}
            test_ansatz = {name: interpolate(name, test_fI) for name in field_names}
            # Only the trial dict carries the mapping to physical coordinates
            trial_ansatz["physical coor"] = mapping

            def gauss_point_contribution(x_int, w_int):
                weak_form = integrand_fun(
                    x_int,
                    trial_ansatz,
                    test_ansatz,
                    settings,
                    static_settings,
                    elem_number,
                    set,
                )

                # Tangent vector(s) of the surface at the integration point
                jacobian = jax.jacfwd(mapping)(x_int)
                if n_dim == 2:
                    # Line in 2d: pad the tangent with a zero and cross it with the out-of-plane unit vector
                    G_1 = jnp.concatenate([jacobian, jnp.asarray([0])])
                    E_3 = jnp.asarray([0, 0, 1])
                    scaling = jnp.linalg.norm(jnp.cross(G_1, E_3))
                elif n_dim == 3:
                    # 2d-surface in 3d: norm of the cross product of both tangents
                    G_1 = jacobian[:, 0]
                    G_2 = jacobian[:, 1]
                    scaling = jnp.linalg.norm(jnp.cross(G_1, G_2))
                else:
                    raise ValueError(
                        "Surface element supports currently only lines in 2d and 2d-surfaces in 3d."
                    )

                return weak_form * w_int * scaling

            per_point = jax.vmap(gauss_point_contribution, in_axes=(0, 0), out_axes=0)(
                ref_int_coor, ref_int_weights
            )
            return per_point.sum(axis=0)

        # Differentiate with respect to the test nodal values, evaluated at fI
        return jax.jacrev(integrated_weak_form)(fI)

    return user_residual
def isoparametric_domain_integrate_potential(
    integrand_fun, ansatz_fun, ref_int_coor, ref_int_weights, initial_config=True
):
    """
    Constructs a local integrand fun (user potential) for integration of functions in the reference configuration of isoparametric elements.

    Works for DOFs as a jnp.ndarray.

    Args:
      integrand_fun (callable): Integrand; it is called as
        integrand_fun(x_int, trial_ansatz, settings, static_settings, elem_number, set).
      ansatz_fun (callable): Function that constructs the ansatz functions.
      ref_int_coor (jnp.ndarray): Reference integration coordinates for the isoparametric element.
      ref_int_weights (jnp.ndarray): Reference integration weights for the isoparametric element.
      initial_config (bool, optional): If True, use the initial configuration for isoparametric mapping.
        Otherwise the first n_dim fields of fI are added to xI as displacements. Default is True.

    Returns:
      function
        A function that evaluates the integrated integrand for the isoparametric element. It has the
        following parameters:

        - fI (jnp.ndarray): Element nodal values.
        - xI (jnp.ndarray): Element nodal coordinates.
        - elem_number (int): Element number.
        - settings (dict): Settings for the computation.
        - static_settings (dict): Static settings for the computation.
        - set: Number of domain.

    Notes:
      - This models works for DOFs as a jnp.ndarray.
    """

    def user_potential(fI, xI, elem_number, settings, static_settings, set):
        n_dim = xI.shape[-1]

        # Both closures are independent of the integration point, so they are built once per element
        trial_ansatz = lambda x: ansatz_fun(x, xI, fI, settings, True, n_dim)

        if initial_config:
            mapping = lambda x: ansatz_fun(x, xI, xI, settings, False, n_dim)
        else:
            # Take first n_dim fields as displacements
            uI = fI[:, :n_dim]
            current_xI = xI + uI
            # NOTE(review): unlike the initial-configuration branch, n_dim is not passed here, so
            # ansatz_fun presumably needs a default for it - confirm against the ansatz in use.
            mapping = lambda x: ansatz_fun(
                x, current_xI, current_xI, settings, overwrite_diff=False
            )

        # One Gauß point contribution
        def functional(x_int, w_int):
            discrete_functional_fun = integrand_fun(
                x_int, trial_ansatz, settings, static_settings, elem_number, set
            )

            # Use ansatz for isoparametric mapping and compute determinant of Jacobian
            jacobian = jax.jacfwd(mapping)(x_int)
            det_jacobian = utility.matrix_det(jacobian)

            return w_int * det_jacobian * discrete_functional_fun

        # Sum over Gauß points
        functional_contrib = jax.vmap(functional, (0, 0), 0)(
            ref_int_coor, ref_int_weights
        )
        return functional_contrib.sum(axis=0)

    return user_potential
def isoparametric_domain_element_galerkin(
    weak_form_fun, ansatz_fun, ref_int_coor, ref_int_weights, initial_config=True
):
    """
    Constructs an isoparametric domain element for Galerkin methods using a given weak form function and DOFs as jnp.ndarray.

    The returned user element evaluates the residual or the tangent matrix of one isoparametric element
    from the provided weak form and ansatz functions. It can e.g. be used in combination with the
    weak_form_fun hyperelastic_steady_state_weak.

    Args:
      weak_form_fun (callable): Weak form; it is called as
        weak_form_fun(x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, set).
      ansatz_fun (callable): Function that constructs the ansatz (trial and test) functions.
      ref_int_coor (jnp.ndarray): Reference integration coordinates for the isoparametric element.
      ref_int_weights (jnp.ndarray): Reference integration weights for the isoparametric element.
      initial_config (bool, optional): If True, use the initial configuration for isoparametric mapping.
        Otherwise the first n_dim fields of fI are added to xI as displacements. Default is True.

    Returns:
      function
        A function that evaluates the residual or the tangent matrix for the isoparametric element.
        It has the following parameters:

        - fI (jnp.ndarray): Element nodal values.
        - xI (jnp.ndarray): Element nodal coordinates.
        - elem_number (int): Element number.
        - settings (dict): Settings for the computation.
        - static_settings (dict): Static settings for the computation.
        - mode (str): Specifies whether to compute the 'residual' or 'tangent'.
        - set: Number of domain.

    Raises:
      ValueError: When the returned function is called with a mode other than 'residual' or 'tangent'.

    Notes:
      - This models works for DOFs as a jnp.ndarray.
    """

    def user_element(fI, xI, elem_number, settings, static_settings, mode, set):
        n_dofs = fI.flatten().shape[0]
        n_dim = xI.shape[-1]

        def gauss_point_contribution(x_int, w_int):
            def weighted_weak_form(test_fI, trial_fI):
                def trial_ansatz(x):
                    return ansatz_fun(x, xI, trial_fI, settings, True, n_dim)

                def test_ansatz(x):
                    return ansatz_fun(x, xI, test_fI, settings, True, n_dim)

                discrete_weak_form = weak_form_fun(
                    x_int,
                    trial_ansatz,
                    test_ansatz,
                    settings,
                    static_settings,
                    elem_number,
                    set,
                )

                # Isoparametric mapping in the initial or in the displaced configuration
                if initial_config:

                    def mapping(x):
                        return ansatz_fun(x, xI, xI, settings, False, n_dim)

                else:
                    # The first n_dim fields are interpreted as displacements
                    uI = fI[:, : xI.shape[-1]]

                    def mapping(x):
                        return ansatz_fun(
                            x, xI + uI, xI + uI, settings, overwrite_diff=False
                        )

                det_jacobian = utility.matrix_det(jax.jacfwd(mapping)(x_int))
                return w_int * det_jacobian * discrete_weak_form

            # Residual: derivative with respect to the test nodal values
            residual_fun = jax.jacrev(weighted_weak_form, argnums=0)
            # Tangent: derivative of the residual with respect to the trial nodal values
            tangent_fun = jax.jacfwd(residual_fun, argnums=1)

            point_residual = residual_fun(fI, fI).flatten()
            point_tangent = tangent_fun(fI, fI).reshape(n_dofs, n_dofs)
            return (point_residual, point_tangent)

        # Sum over Gauß points
        residual_per_point, tangent_per_point = jax.vmap(
            gauss_point_contribution, in_axes=(0, 0), out_axes=(0, 0)
        )(ref_int_coor, ref_int_weights)
        residual = residual_per_point.sum(axis=0)
        tangent = tangent_per_point.sum(axis=0)

        if mode == "residual":
            return residual
        if mode == "tangent":
            return tangent
        raise ValueError("User elem mode has to be 'residual' or 'tangent'.")

    return user_element
def isoparametric_surface_element_galerkin(
    weak_form_fun,
    ansatz_fun,
    ref_int_coor,
    ref_int_weights,
    tangent_contributions,
    initial_config=True,
):
    """
    Constructs an isoparametric surface element for Galerkin methods using a given weak form function and DOFs as jnp.ndarray.

    The returned user element evaluates the residual or the tangent matrix of one isoparametric surface
    element from the provided weak form and ansatz functions. It can, for example, be used in combination
    with the weak_form_fun neumann_weak.

    Args:
      weak_form_fun (callable): Weak form; it is called as
        weak_form_fun(x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, set).
      ansatz_fun (callable): Function that constructs the ansatz (trial and test) functions.
      ref_int_coor (jnp.ndarray): Reference integration coordinates for the isoparametric element.
      ref_int_weights (jnp.ndarray): Reference integration weights for the isoparametric element.
      tangent_contributions (bool): If True, compute the tangent contributions; otherwise, return zero tangent.
      initial_config (bool, optional): If True, use the initial configuration for isoparametric mapping.
        Otherwise the first n_dim fields of fI are added to xI as displacements. Default is True.

    Returns:
      callable: A function that evaluates the residual or the tangent matrix for the isoparametric
      surface element. It has the following parameters:

        - fI (jnp.ndarray): Element nodal values.
        - xI (jnp.ndarray): Element nodal coordinates.
        - elem_number (int): Element number.
        - settings (dict): Settings for the computation.
        - static_settings (dict): Static settings for the computation.
        - mode (str): Specifies whether to compute the 'residual' or 'tangent'.
        - set: Number of domain.

    Raises:
      ValueError: When the returned function is evaluated with nodal coordinates that are neither 2d
        nor 3d, or with a mode other than 'residual' or 'tangent'.

    Notes:
      - Currently, the weak form function must not contain derivatives wrt. the primary fields.
      - This models works for DOFs as a jnp.ndarray.
    """

    def user_element(fI, xI, elem_number, settings, static_settings, mode, set):
        n_dofs = fI.flatten().shape[0]
        n_dim = xI.shape[-1]

        def gauss_point_contribution(x_int, w_int):
            def weighted_weak_form(test_fI, trial_fI):
                # Trial and test functions use shape functions in n_dim-1 dimensions
                def trial_ansatz(x):
                    return ansatz_fun(x, xI, trial_fI, settings, False, n_dim - 1)

                def test_ansatz(x):
                    return ansatz_fun(x, xI, test_fI, settings, False, n_dim - 1)

                discrete_weak_form = weak_form_fun(
                    x_int,
                    trial_ansatz,
                    test_ansatz,
                    settings,
                    static_settings,
                    elem_number,
                    set,
                )

                # Isoparametric mapping in the initial or in the displaced configuration
                if initial_config:

                    def mapping(x):
                        return ansatz_fun(x, xI, xI, settings, False, n_dim - 1)

                else:
                    # The first n_dim fields are interpreted as displacements
                    uI = fI[:, : xI.shape[-1]]

                    def mapping(x):
                        return ansatz_fun(
                            x, xI + uI, xI + uI, settings, overwrite_diff=False
                        )

                # Tangent vector(s) of the surface at the integration point
                jacobian = jax.jacrev(mapping)(x_int)
                if n_dim == 2:
                    # Line in 2d: pad the tangent with a zero and cross it with the out-of-plane unit vector
                    G_1 = jnp.concatenate([jacobian, jnp.asarray([0])])
                    E_3 = jnp.asarray([0, 0, 1])
                    scaling = jnp.linalg.norm(jnp.cross(G_1, E_3))
                elif n_dim == 3:
                    # 2d-surface in 3d: norm of the cross product of both tangents
                    G_1 = jacobian[:, 0]
                    G_2 = jacobian[:, 1]
                    scaling = jnp.linalg.norm(jnp.cross(G_1, G_2))
                else:
                    raise ValueError(
                        "Surface element supports currently only lines in 2d and 2d-surfaces in 3d."
                    )

                return w_int * scaling * discrete_weak_form

            # Residual: derivative with respect to the test nodal values
            residual_fun = jax.jacrev(weighted_weak_form, argnums=0)
            point_residual = residual_fun(fI, fI).flatten()

            # Tangent: derivative of the residual with respect to the trial nodal values (optional)
            if tangent_contributions:
                tangent_fun = jax.jacfwd(residual_fun, argnums=1)
                point_tangent = tangent_fun(fI, fI).reshape(n_dofs, n_dofs)
            else:
                point_tangent = jnp.zeros((n_dofs, n_dofs), dtype=jnp.float64)

            return (point_residual, point_tangent)

        # Sum over Gauß points
        residual_per_point, tangent_per_point = jax.vmap(
            gauss_point_contribution, in_axes=(0, 0), out_axes=(0, 0)
        )(ref_int_coor, ref_int_weights)
        residual = residual_per_point.sum(axis=0)
        tangent = tangent_per_point.sum(axis=0)

        if mode == "residual":
            return residual
        if mode == "tangent":
            return tangent
        raise ValueError("User elem mode has to be 'residual' or 'tangent'.")

    return user_element
## Functions for coupling with DAE solver
def mixed_reference_domain_residual_time(integrand_fun, ansatz_fun, ref_int_coor, ref_int_weights, mapping_key):
    """
    Constructs a multi-field 'user residual' for time-dependent weak forms.

    Args:
      integrand_fun (callable): Function evaluating the weak form integrand. Two call signatures are
        supported and selected by the number of parameters of integrand_fun:

        - 7 parameters: (x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, set)
        - 8 parameters: (x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, int_num, set)

        The entries of trial_ansatz are callables with args = (x, t), the entries of test_ansatz are
        callables with args = (x,). trial_ansatz additionally contains the isoparametric mapping under
        the key 'physical coor' (callable with args = (x,)).
      ansatz_fun (dict of callables): Each function is called as
        f(x, xI[key], nodal_values, settings, flag, n_dim), where nodal_values is fI(t)[key] for the
        trial functions, test_fI[key] for the test functions and xI[mapping_key] for the mapping.
      ref_int_coor (jnp.ndarray): Reference integration coordinates.
      ref_int_weights (jnp.ndarray): Integration weights.
      mapping_key (str): Key for the isoparametric mapping.

    Returns:
      A function with signature:
          user_residual(fI, xI, elem_number, settings, static_settings, set)
      that computes the element residual. Here, fI is a callable of the time t that returns the dict of
      element nodal values and xI is a dict of element nodal coordinates.

    Raises:
      ValueError: When the returned function is called and integrand_fun has neither 7 nor 8 parameters.
    """
    keys = ansatz_fun.keys()

    def user_residual(fI, xI, elem_number, settings, static_settings, set):
        # Validate the arity up front so that an unsupported integrand yields a clear error instead
        # of an unbound local variable during tracing.
        num_args = len(signature(integrand_fun).parameters)
        if num_args not in (7, 8):
            raise ValueError(
                "integrand_fun has to accept 7 (without integration point number) "
                f"or 8 (with integration point number) arguments, but accepts {num_args}."
            )

        n_dim = xI[next(iter(keys))].shape[-1]

        def integrated_weak_form(test_fI):
            trial_ansatz = {}
            test_ansatz = {}
            for key in keys:
                # key=key binds the current key, otherwise all lambdas would use the last one
                trial_ansatz[key] = lambda x, t, key=key: ansatz_fun[key](x, xI[key], fI(t)[key], settings, True, n_dim)
                test_ansatz[key] = lambda x, key=key: ansatz_fun[key](x, xI[key], test_fI[key], settings, True, n_dim)

            # Isoparametric mapping
            @jax.jit
            def mapping(x):
                return ansatz_fun[mapping_key](x, xI[mapping_key], xI[mapping_key], settings, False, n_dim)

            trial_ansatz['physical coor'] = mapping

            def functional(x_int, w_int, int_num):
                if num_args == 7:
                    weak_form = integrand_fun(x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, set)
                else:
                    weak_form = integrand_fun(x_int, trial_ansatz, test_ansatz, settings, static_settings, elem_number, int_num, set)

                # Calculate the jacobian and its determinant.
                jacobian = jax.jacfwd(mapping)(x_int)
                det_jacobian = utility.matrix_det(jacobian)
                return weak_form * w_int * det_jacobian

            # Numerical integration
            int_point_numbers = jnp.arange(ref_int_coor.shape[0])
            return jax.vmap(functional, (0, 0, 0), 0)(ref_int_coor, ref_int_weights, int_point_numbers).sum(axis=0)

        # Get residual by taking the derivative of the weak form wrt the test functions nodal values.
        test_fI = utility.dict_zeros_like(fI(0.))
        residual = jax.jacrev(integrated_weak_form)(test_fI)
        return residual

    return user_residual
def mixed_reference_domain_int_var_updates(local_int_var_updates_fun, ansatz_fun, ref_int_coor, ref_int_weights, mapping_key):
    """
    Similar as mixed_reference_domain_residual_time, but generates a function that doesn't compute the residual, but the internal variables based on the local_int_var_updates_fun.

    Args:
      local_int_var_updates_fun (callable): Called at every integration point as
        local_int_var_updates_fun(x_int, trial_ansatz, settings, static_settings, elem_number, int_num, set).
      ansatz_fun (dict of callables): One ansatz constructor per field, called as
        f(x, xI[key], fI(t)[key], settings, True, n_dim).
      ref_int_coor (jnp.ndarray): Reference integration coordinates.
      ref_int_weights (jnp.ndarray): Integration weights (accepted, but not used by this function).
      mapping_key (str): Key for the isoparametric mapping.

    Returns:
      A function with signature:
          int_var_fun(fI, xI, elem_number, settings, static_settings, set)
      that returns the results of local_int_var_updates_fun stacked over the integration points.
      Here, fI is a callable of the time t that returns the dict of element nodal values.
    """
    keys = ansatz_fun.keys()

    def int_var_fun(fI, xI, elem_number, settings, static_settings, set):
        first_key = next(iter(keys))
        n_dim = xI[first_key].shape[-1]

        def make_field(name):
            # Time-dependent field approximation of one solution space
            def field(x, t):
                return ansatz_fun[name](x, xI[name], fI(t)[name], settings, True, n_dim)

            return field

        trial_ansatz = {name: make_field(name) for name in keys}

        # Isoparametric mapping
        @jax.jit
        def mapping(x):
            nodes = xI[mapping_key]
            return ansatz_fun[mapping_key](x, nodes, nodes, settings, False, n_dim)

        trial_ansatz['physical coor'] = mapping

        def update_at_point(x_int, int_num):
            return local_int_var_updates_fun(
                x_int, trial_ansatz, settings, static_settings, elem_number, int_num, set
            )

        # Evaluate all integration points at once; each one also receives its running number
        n_int_points = ref_int_coor.shape[0]
        return jax.vmap(update_at_point, in_axes=(0, 0), out_axes=0)(
            ref_int_coor, jnp.arange(n_int_points)
        )

    return int_var_fun
### Time integration procedures
def forward_backward_euler_weak(inertia_coeff_fun):
    """
    Constructs a weak form time discretization function using the Forward or Backward Euler method.

    Deprecated:
      This function is deprecated and may be removed in future versions.
      The new time integration procedures are now implemented in the dae module.
      For an examplary use, see the examples with explicit and implicit time integration.

    Args:
      inertia_coeff_fun (callable): Function to compute the inertia coefficient given spatial coordinates and settings.

    Returns:
      callable: A function that evaluates the weak form time discretization for the PDE.

      The returned function has the following parameters:
        - x (jnp.ndarray): Spatial coordinates at the integration point.
        - ansatz (callable): Ansatz function representing the primary field at time step n+1.
        - test_ansatz (callable): Test ansatz function representing the test function.
        - settings (dict): Settings for the computation, including:
            - 'dofs n': Degrees of freedom at the previous time step.
            - 'connectivity': Connectivity information for the elements.
            - 'time increment': Time increment delta_t.
        - static_settings (dict): Static settings for the computation.
        - int_point_number (int): Integration point number.
        - set: Number of domain

    Raises:
      AssertionError: When the returned function is called and static_settings contains 'connectivity'.

    Notes:
      - The primary field is evaluated at both time steps n and n+1.
      - The time derivative is approximated by (field_{n+1} - field_n) / delta_t; whether this acts as a
        Forward or Backward Euler scheme depends on the use case.
      - The inertia coefficient is used to scale the time derivative term.
      - This models works for DOFs as a jnp.ndarray.
    """

    def time_discretization_fun(
        x, ansatz, test_ansatz, settings, static_settings, int_point_number, set
    ):
        # Test function
        test_function = test_ansatz(x)

        # Warning if it was defined in static_settings
        assert "connectivity" not in static_settings, (
            "'connectivity' has been moved to 'settings' in order to reduce compile time. "
            "Further, you should not transform it to a tuple of tuples anymore."
        )

        # Primary field evaluated at n and n+1
        dofs_n = settings["dofs n"]
        neighbor_list = settings["connectivity"][set][int_point_number]
        local_dofs_n = dofs_n[neighbor_list]
        ansatz_n = lambda x: solution_structures.solution_structure(
            x, int_point_number, local_dofs_n, settings, static_settings, set
        )
        field_n = ansatz_n(x)
        field = ansatz(x)

        # Forward/Backward Euler
        delta_t = settings["time increment"]
        time_derivative = (field - field_n) / delta_t

        # E.g. heat capacity or whatever is multiplied with first order time derivative
        inertia_coeff = inertia_coeff_fun(x, settings)
        # NOTE(review): the leading minus sign differs from the positive inertia term returned by
        # central_differences; confirm it matches the sign convention of the weak forms this term
        # is combined with.
        return -inertia_coeff * jnp.dot(test_function, time_derivative)

    return time_discretization_fun


def central_differences(density_fun, damping):
    """
    Constructs weak form contributions for acceleration (and optional damping) term based on central differences.

    Deprecated:
      This function is deprecated and may be removed in future versions.
      The new time integration procedures are now implemented in the dae module.

    Args:
      density_fun (callable): Function to compute the density coefficient given spatial coordinates and settings.
      damping (bool): Whether to turn on damping. Damping requires settings['damping coefficient'] to be set.

    Returns:
      callable: A function that evaluates the weak form contributions.

      The returned function has the following parameters:
        - x (jnp.ndarray): Spatial coordinates at the integration point.
        - ansatz (callable): Ansatz function representing the primary field at time step n+1.
        - test_ansatz (callable): Test ansatz function representing the test function.
        - settings (dict): Settings for the computation, including:
            - 'dofs n': Degrees of freedom at the previous time step.
            - 'dofs n-1': Degrees of freedom at the time step before the previous one.
            - 'connectivity': Connectivity information for the elements.
            - 'time increment': Time increment delta_t.
        - static_settings (dict): Static settings for the computation.
        - int_point_number (int): Integration point number.
        - set: Number of domain

    Raises:
      AssertionError: When the returned function is called and static_settings contains 'connectivity'.

    Notes:
      - Requires settings['dofs n'], settings['dofs n-1'] and settings['time increment'].
      - The acceleration is a_n = (u_{n+1} - 2 u_n + u_{n-1}) / delta_t**2 and the velocity is
        v_n = (u_{n+1} - u_{n-1}) / (2 delta_t).
      - This models works for DOFs as a jnp.ndarray.
    """

    def weak_form_fun(x, ansatz, test_ansatz, settings, static_settings, int_point_number, set):
        # Test function
        test_function = test_ansatz(x)

        # Warning if it was defined in static_settings
        assert "connectivity" not in static_settings, (
            "'connectivity' has been moved to 'settings' in order to reduce compile time. "
            "Further, you should not transform it to a tuple of tuples anymore."
        )

        # Primary field evaluated at n-1, n and n+1
        dofs_n = settings["dofs n"]
        dofs_n_min_1 = settings["dofs n-1"]
        neighbor_list = settings["connectivity"][set][int_point_number]
        ansatz_n = lambda x: solution_structures.solution_structure(
            x, int_point_number, dofs_n[neighbor_list], settings, static_settings, set
        )
        ansatz_n_min_1 = lambda x: solution_structures.solution_structure(
            x,
            int_point_number,
            dofs_n_min_1[neighbor_list],
            settings,
            static_settings,
            set,
        )
        u_n = ansatz_n(x)
        u_n_min_1 = ansatz_n_min_1(x)
        u_n_plus_1 = ansatz(x)

        # Central differences: the second-derivative stencil adds the value at n-1
        delta_t = settings["time increment"]
        a_n = (u_n_plus_1 - 2 * u_n + u_n_min_1) / (delta_t**2)
        v_n = (u_n_plus_1 - u_n_min_1) / (2 * delta_t)

        # Density, i.e. the factor that is multiplied with the acceleration (and damping) term
        density = density_fun(x, settings)
        if damping:
            damping_coeff = settings["damping coefficient"]
            return density * jnp.dot(test_function, a_n + damping_coeff * v_n)
        else:
            return density * jnp.dot(test_function, a_n)

    return weak_form_fun


### Convenience functions for modelling
@utility.jit_with_docstring(inline=True)
def aug_lag_potential(lag, constr, eps):
    """
    Augmented Lagrangian potential function for an inequality constraint (constr >= 0).

    The returned C1 continuous potential can be added to a potential function in order to enforce the
    inequality constraint. It is analogous to adding a term 'lambda * g' when imposing an equality
    constraint (g = 0). For maximization problems, the potential should be subtracted.

    With aug_lag = lag + eps * constr, the value is

      - constr * (aug_lag - eps / 2 * constr)        if aug_lag <= 0,
      - -(aug_lag - eps * constr)**2 / (2 * eps)     if aug_lag > 0.

    Args:
      lag: Scalar Lagrange multiplier.
      constr: Scalar evaluated constraint.
      eps: Scalar numerical parameter controlling the conditioning.

    Returns:
      The value of the augmented Lagrangian potential.
    """
    shifted_multiplier = lag + constr * eps

    # Value for shifted_multiplier <= 0
    nonpositive_value = constr * (shifted_multiplier - eps / 2 * constr)
    # Value for shifted_multiplier > 0
    positive_value = -(shifted_multiplier - eps * constr)**2 / (2 * eps)

    return jnp.piecewise(
        shifted_multiplier,
        [shifted_multiplier <= 0, shifted_multiplier > 0],
        [nonpositive_value, positive_value],
    )
@utility.jit_with_docstring(inline=True)
def kelvin_mandel_extract(a):
    """
    Converts a 6-component Kelvin-Mandel representation vector into a 3x3 symmetric matrix.

    In the Kelvin-Mandel notation, the off-diagonal components are stored with a factor that preserves
    norm equivalence between the tensor and its vector representation. This function recovers the full
    symmetric matrix by multiplying the last three components with 1/sqrt(2).

    Args:
      a: 1-dimensional array-like with 6 elements corresponding to the Kelvin-Mandel representation of a
         3x3 symmetric tensor. The expected ordering is:
         [a11, a22, a33, a12, a13, a23]

    Returns:
      A 3x3 jnp.ndarray representing the symmetric tensor in matrix form.
    """
    inv_sqrt_2 = 1 / jnp.sqrt(2.)
    a11, a22, a33 = a[0], a[1], a[2]
    a12 = inv_sqrt_2 * a[3]
    a13 = inv_sqrt_2 * a[4]
    a23 = inv_sqrt_2 * a[5]

    return jnp.asarray(
        [
            [a11, a12, a13],
            [a12, a22, a23],
            [a13, a23, a33],
        ]
    )