# dae.py
# Copyright (C) 2025 Tobias Bode
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
"""Module for solving differential algebraic systems and transient PDEs."""
# TODO: make sure the integrator supports changing the step size if necessary! E.g. Adams-Moulton: root iteration controller not possible...
# TODO: generate tests from examples
# TODO: translate all comments to english
# TODO: staggered policies, explicit diagonal modes
# TODO: add information about algebraic equations and add different treatments, e.g. projection for explicit modes
from typing import Any
from abc import ABC, abstractmethod
from dataclasses import dataclass
from math import isclose
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
import jax
import jax.numpy as jnp
from jax import tree_util, custom_jvp
from jax.experimental import sparse
from autopdex.utility import dict_flatten, reshape_as, jit_with_docstring
from autopdex import solver, assembler, implicit_diff
## helper functions
@custom_jvp
def _no_derivative(t, q):
return q
@_no_derivative.defjvp
def _no_derivative_jvp(primals, tangents):
raise ValueError("\n\nYour chosen integrator does not support derivatives up to the order you are using!\
Consider using a different integrator or convert your system to a system of lower order.\n")
@custom_jvp
def discrete_value_with_derivatives(t, q, q_derivs):
    """
    Return a discrete value ``q`` whose time derivatives are supplied explicitly.

    The primal result is simply ``q``. A custom JVP rule attaches derivative information with
    respect to ``t``:

    - ``q_derivs`` empty: the output tangent is just ``q_dot`` (nothing is propagated along ``t``).
    - otherwise: the output tangent is ``q_dot + D * t_dot``. Here ``D`` is
      ``discrete_value_with_derivatives(t, q_derivs[0], q_derivs[1:])``, so differentiating again
      exposes the next entry of ``q_derivs``. This is how higher-order derivatives are supported.

    Typical use is to build a differentiable ``q_fun`` from the values and derivatives produced
    by an integration rule:

    .. code-block:: python

        def diffable_q_fun(t):  # q_ts is a tuple of (q, q_t, q_tt, ...) coming from the integrator
            return {key: discrete_value_with_derivatives(t, q_ts[key][0], q_ts[key][1:]) for key in template.keys()}

    Args:
        t: Scalar time variable.
        q: The discrete state value.
        q_derivs: Sequence of derivative values of ``q``. Entry 0 is the first derivative and
            later entries (if any) are successively higher derivatives.

    Returns:
        ``q`` (derivative behaviour as described above).
    """
    return q


@discrete_value_with_derivatives.defjvp
def discrete_value_with_derivatives_jvp(primals, tangents):
    """JVP rule: chain rule along ``t`` using the stored derivative sequence."""
    t, value, derivs = primals
    t_dot, value_dot, _unused_derivs_dot = tangents

    # No further derivative information: only the direct tangent of the value propagates.
    if not len(derivs):
        return value, value_dot

    primal_out = discrete_value_with_derivatives(t, value, derivs)
    # The first stored derivative acts as d(value)/dt. It is wrapped recursively so that a
    # further differentiation picks up the remaining derivatives.
    slope = discrete_value_with_derivatives(t, derivs[0], derivs[1:])
    return primal_out, value_dot + slope * t_dot
## butcher tableau inversion
def detect_stage_dependencies(A):
    """
    Detects coupled structures (strongly connected components) in the Butcher matrix A and identifies explicit stages.

    Stage ``i`` depends on stage ``j`` whenever ``A[i, j] != 0``. Stages that depend on each other
    (directly or transitively) end up in the same strongly connected component ("block").

    Parameters:
        A (array_like): Butcher matrix of stage coefficients (s x s).

    Returns:
        stage_blocks (list): A list of lists containing the indices of coupled stages.
        explicit_stages (list): Indices of explicit stages. These are stages that form a block on
            their own and have a vanishing diagonal entry ``A[i, i] == 0``. A block of several
            mutually coupled stages is never explicit, even if all its diagonal entries vanish.
        block_dependencies (dict): Maps each block index to the set of block indices it depends on.
    """
    A = np.asarray(A)
    s = A.shape[0]
    dependency_matrix = (A != 0).astype(int)

    # Find SCCs of the stage dependency graph.
    graph = csr_matrix(dependency_matrix)
    n_components, labels = connected_components(csgraph=graph, directed=True, connection='strong')

    # Group stages by their SCC labels (indices stay sorted within each block).
    stage_blocks = [[] for _ in range(n_components)]
    for i in range(s):
        stage_blocks[labels[i]].append(i)

    # Explicit stages: uncoupled (singleton block) and not depending on themselves.
    # A coupled block such as A = [[0, 1], [1, 0]] has a zero diagonal but is still implicit.
    explicit_stages = [block[0] for block in stage_blocks if len(block) == 1 and A[block[0], block[0]] == 0]

    # Dependencies between different blocks.
    block_dependencies = {i: set() for i in range(n_components)}
    for i, j in zip(*np.nonzero(dependency_matrix)):
        block_i, block_j = int(labels[i]), int(labels[j])
        if block_i != block_j:
            block_dependencies[block_i].add(block_j)

    return stage_blocks, explicit_stages, block_dependencies


def invert_butcher_with_order(A):
    """
    Computes the blockwise linear mapping matrix ``A_`` that maps U to U_dot without inter-block coupling,
    and determines the execution order of the blocks.

    For coupled (implicit) blocks, the diagonal sub-block of ``A`` is inverted. For explicit stages,
    the row of ``A_`` is a unit diagonal entry plus the negated strictly lower part of the row of ``A``.
    All other entries are zero.

    Parameters:
        ``A`` (array_like): Butcher matrix of stage coefficients (s x s).

    Returns:
        ``A_`` (ndarray): Matrix mapping U to U_dot (s x s). It has the dtype of ``A`` if ``A`` is
            floating point/complex, and ``float64`` otherwise, so block inverses are never truncated.
        execution_order (list): List specifying the order of operations. Explicit stages appear as
            ``(stage_index, "explicit")`` and coupled blocks as ``(list_of_stage_indices, "implicit")``.
            Every entry comes after the blocks it depends on.
    """
    A = np.asarray(A)
    # An integer tableau (e.g. jnp.array([[0, 0], [1, 0]])) must not force an integer result,
    # otherwise the inverse of an implicit block would be truncated when stored.
    dtype = A.dtype if np.issubdtype(A.dtype, np.inexact) else np.float64
    A_ = np.zeros(A.shape, dtype=dtype)

    # Detect explicit stages and coupled blocks
    stage_blocks, explicit_stages, block_dependencies = detect_stage_dependencies(A)
    explicit_set = set(explicit_stages)

    def is_explicit(block):
        return all(i in explicit_set for i in block)

    # Coupled blocks: invert the diagonal sub-block (coupling to other blocks is ignored).
    for block in stage_blocks:
        if not is_explicit(block):
            block_idx = np.ix_(block, block)
            A_[block_idx] = np.linalg.inv(A[block_idx])

    # Explicit stages: set diagonal to 1 and invert the sign of the lower triangle
    for i in explicit_stages:
        A_[i, :i] = -A[i, :i]
        A_[i, i] = 1

    # Determine ordering of the blocks (depth-first, dependencies first)
    execution_order = []
    resolved = set()

    def resolve_block(block_idx):
        if block_idx in resolved:
            return
        for dep in block_dependencies[block_idx]:
            resolve_block(dep)
        resolved.add(block_idx)
        block = stage_blocks[block_idx]
        if is_explicit(block):
            execution_order.extend((i, "explicit") for i in block)
        else:
            execution_order.append((block, "implicit"))

    for block_idx in range(len(stage_blocks)):
        resolve_block(block_idx)
    return A_, execution_order
## integrator class
class TimeIntegrator(ABC):
    """
    Common base of the time integration schemes in this module.

    A concrete scheme supplies two callables, stored by ``__init__`` as ``value_and_derivatives``
    and ``update``. They are normally the subclass' ``_rule`` and ``_update`` methods.
    """

    def __init__(self,
                 name,
                 value_and_derivatives,
                 update,
                 stage_list,
                 stage_types,
                 stage_positions,
                 num_steps=1,
                 num_derivs=1,
                 num_stages=1):
        """
        Store the description of an integration scheme.

        Parameters:
            name (str): Identifier of the method.
            value_and_derivatives (callable): Computes state values and their time derivatives.
            update (callable): Computes the new state from the stage results.
            stage_list (ndarray): Indices / order of the stages.
            stage_types (tuple): 'explicit' or 'implicit' for each stage.
            stage_positions (ndarray): Positions of the stages inside a step (e.g. Butcher nodes).
            num_steps (int): Number of previous steps used (multi-step methods).
            num_derivs (int): Highest supported derivative order.
            num_stages (int): Number of stages per step (e.g. Runge-Kutta stages).
        """
        # Identification and the two callables that drive the scheme.
        self.name = name
        self.value_and_derivatives = value_and_derivatives
        self.update = update
        # Stage layout.
        self.stage_list, self.stage_types, self.stage_positions = stage_list, stage_types, stage_positions
        # Sizes of the scheme.
        self.num_steps, self.num_derivs, self.num_stages = num_steps, num_derivs, num_stages

    def _update(self, q_stages, q_n, q_t_n, dt):
        """
        Post-solve update rule; concrete integrators must override it.

        Parameters:
            q_stages: Stage results returned by the solver.
            q_n: Values of the previous time steps.
            q_t_n: Derivatives of the previous time steps.
            dt: Step size.

        Returns:
            Tuple (new state, new derivatives).
        """
        raise NotImplementedError

    def _rule(self, q, q_n, q_t_n, dt):
        """
        Map unknowns to function values and time derivatives; concrete integrators must override it.

        Example: backward Euler returns ``q, (q - q_n[0]) / dt``.

        Parameters:
            q: Unknown stage values.
            q_n: Previous values; q_n[0] belongs to time n, q_n[1] to time n-1, and so on.
            q_t_n: Previous derivatives; q_t_n[i, j] is the (j+1)-th derivative at time n-i.
            dt: Time step size.

        Returns:
            Values and derivatives for the stages or for the next time step.
        """
        raise NotImplementedError

    def _error_estimate(self, q, q_n, q_t_n, dt):
        """
        Error estimate of the step (same arguments as ``_update``).

        The base class has none and returns ``None``. Integrators with embedded error estimation
        (e.g. Kvaerno) override this.
        """
        return None
## specific integrator classes
class BackwardEuler(TimeIntegrator):
    """
    Backward (implicit) Euler method.

    Accuracy: 1st order.
    Stability: L-stable.
    Number of steps: 1.
    Number of stages: 1, implicit.
    Number of derivatives: 1.
    """

    def __init__(self):
        super().__init__(name="backward_euler",
                         value_and_derivatives=self._rule,
                         update=self._update,
                         stage_list=jnp.asarray([[0]]),
                         stage_types=('implicit',),
                         stage_positions=jnp.array([1.]),
                         num_steps=1,
                         num_derivs=1,
                         num_stages=1)
        self.butcher_b = jnp.array([1.])
        self.order = 1

    def _update(self, q_stages, q_n, q_t_n, dt):
        # The single stage value is the new state; its derivative comes from re-evaluating the rule.
        new_state = q_stages[0]
        return new_state, self.value_and_derivatives(new_state, q_n, q_t_n, dt)[1:]

    def _rule(self, q, q_n, q_t_n, dt):
        # Evaluate at the unknown end-of-step value with the backward difference quotient.
        return q, (q - q_n[0]) / dt
class ForwardEuler(TimeIntegrator):
    """
    Forward (explicit) Euler method.

    Accuracy: 1st order.
    Stability: unstable for stiff problems.
    Number of steps: 1.
    Number of stages: 1, explicit.
    Number of derivatives: 1.
    """

    def __init__(self):
        super().__init__(name="forward_euler",
                         value_and_derivatives=self._rule,
                         update=self._update,
                         stage_list=jnp.asarray([[0]]),
                         stage_types=('explicit',),
                         stage_positions=jnp.array([0.]),
                         num_steps=1,
                         num_derivs=1,
                         num_stages=1)
        self.butcher_b = jnp.array([1.])
        self.order = 1

    def _update(self, q_stages, q_n, q_t_n, dt):
        # The single stage value is the new state; its derivative comes from re-evaluating the rule.
        new_state = q_stages[0]
        return new_state, self.value_and_derivatives(new_state, q_n, q_t_n, dt)[1:]

    def _rule(self, q, q_n, q_t_n, dt):
        # Explicit: the value is taken at the known state q_n[0] (with a leading stage axis),
        # while the unknown q only enters through the difference quotient.
        slope = (q - q_n[0]) / dt
        return jnp.asarray([q_n[0]]), slope
class Newmark(TimeIntegrator):
    """Newmark-beta method.

    Args:
        gamma (float): Newmark parameter (not used by the explicit variant).
        beta (float): Newmark parameter. ``|beta| <= _EXPLICIT_BETA_TOL`` selects the explicit
            central-difference variant.

    Explicit central differences:
        gamma = 0.5
        beta = 0
    Average constant acceleration (middle point rule, unconditionally stable):
        gamma = 0.5
        beta = 0.25

    Number of steps: 1.
    Number of stages: 1, explicit or implicit.
    Number of derivatives: 2.
    """

    # math.isclose(beta, 0.) without abs_tol is only True for beta == 0 exactly (relative
    # tolerance of zero), so an absolute tolerance is required to recognise "beta ~ 0". This keeps
    # a tiny nonzero beta from taking the implicit branch, which divides by beta.
    _EXPLICIT_BETA_TOL = 1e-12

    def __init__(self, gamma=0.5, beta=0.25):
        self.gamma = gamma
        self.beta = beta
        explicit = self._is_explicit()
        super().__init__("newmark",
                         self._rule,
                         self._update,
                         jnp.asarray([[0]]),
                         ('explicit',) if explicit else ('implicit',),
                         jnp.array([0.]) if explicit else jnp.array([1.]),
                         num_steps=1,
                         num_derivs=2,
                         num_stages=1)
        self.butcher_b = jnp.array([1.])
        self.order = 2

    def _is_explicit(self):
        """Whether ``beta`` is (numerically) zero, i.e. the explicit variant is used."""
        return isclose(self.beta, 0., abs_tol=self._EXPLICIT_BETA_TOL)

    def _update(self, q_stages, q_n, q_t_n, dt):
        # New state is the stage value; velocity and acceleration come from re-evaluating the rule.
        q_n1 = q_stages[0]
        q_t_n1 = self.value_and_derivatives(q_n1, q_n, q_t_n, dt)[1:]
        return q_n1, q_t_n1

    def _rule(self, q, q_n, q_t_n, dt):
        """Return (value, velocity, acceleration) for the unknown ``q``.

        ``q_t_n[0, 0]`` and ``q_t_n[0, 1]`` are the velocity and acceleration of the last step.
        """
        dq = q - q_n[0]
        v_n = q_t_n[0, 0]
        if self._is_explicit():
            # Central differences: the value is taken at the known state q_n[0]; gamma is not used.
            q_t = dq / dt
            q_tt = (dq - v_n * dt) / (dt**2)
            return jnp.asarray([q_n[0]]), q_t, q_tt
        gamma = self.gamma
        beta = self.beta
        a_n = q_t_n[0, 1]
        # Solve q_{n+1} = q_n + dt v_n + dt^2 ((1/2 - beta) a_n + beta a_{n+1}) for a_{n+1}.
        q_tt = (dq / dt**2 - v_n / dt - a_n * (1 / 2 - beta)) / beta
        q_t = v_n + dt * ((1 - gamma) * a_n + gamma * q_tt)
        return q, q_t, q_tt
class AdamsMoulton(TimeIntegrator):
    """Adams-Moulton method.

    Args:
        num_steps (int): Number of previous steps (1 to 6).

    Accuracy: order num_steps + 1.
    Number of stages: 1, implicit.
    Number of derivatives: 1.

    The weights assume a constant step size ``dt``.

    Raises:
        ValueError: If ``num_steps`` is not supported. This is checked on construction and in ``_rule``.
    """

    # Weights b_0, ..., b_k of q_{n+1} = q_n + dt * sum_j b_j * q_t_{n+1-j}, keyed by num_steps k.
    # Kept as plain floats so the table is built once, not on every rule evaluation.
    _COEFFS = {
        1: (1 / 2, 1 / 2),
        2: (5 / 12, 8 / 12, -1 / 12),
        3: (9 / 24, 19 / 24, -5 / 24, 1 / 24),
        4: (251 / 720, 646 / 720, -264 / 720, 106 / 720, -19 / 720),
        5: (475 / 1440, 1427 / 1440, -798 / 1440, 482 / 1440, -173 / 1440, 27 / 1440),
        6: (19087 / 60480, 65112 / 60480, -46461 / 60480, 37504 / 60480, -20211 / 60480, 6312 / 60480,
            -863 / 60480),
    }

    def __init__(self, num_steps):
        self._coefficients(num_steps)  # fail early on an unsupported num_steps
        super().__init__("adams_moulton",
                         self._rule,
                         self._update,
                         jnp.asarray([[0]]), ('implicit',),
                         jnp.array([1.]),
                         num_steps=num_steps,
                         num_derivs=1,
                         num_stages=1)
        self.butcher_b = jnp.array([1.])
        self.order = num_steps + 1

    @classmethod
    def _coefficients(cls, num_steps):
        """Return the Adams-Moulton weights for ``num_steps`` or raise ``ValueError``."""
        if num_steps not in cls._COEFFS:
            raise ValueError(f"num_steps={num_steps} is not supported. Supported: {list(cls._COEFFS.keys())}")
        return cls._COEFFS[num_steps]

    def _update(self, q_stages, q_n, q_t_n, dt):
        # New state is the stage value; its derivative comes from re-evaluating the rule.
        q_n1 = q_stages[0]
        q_t_n1 = self.value_and_derivatives(q_n1, q_n, q_t_n, dt)[1:]
        return q_n1, q_t_n1

    def _rule(self, q, q_n, q_t_n, dt):
        coeffs = jnp.asarray(self._coefficients(self.num_steps))
        a_0 = coeffs[0]
        # Solve the Adams-Moulton formula for the implicit derivative q_t at the new time:
        # difference quotient minus the weighted previous derivatives, divided by a_0.
        # q_t_n[:, 0] are the first derivatives at times n, n-1, ...
        history = jnp.einsum("j,j...->...", coeffs[1:], q_t_n[:, 0])
        q_t = ((q - q_n[0]) / dt - history) / a_0
        return q, q_t
class AdamsBashforth(TimeIntegrator):
    """Adams-Bashforth time integrator.

    Args:
        num_steps (int): Number of previous steps (1 to 6).

    Accuracy: order num_steps.
    Number of stages: 1, explicit.
    Number of derivatives: 1.

    The weights assume a constant step size ``dt``.

    Raises:
        ValueError: If ``num_steps`` is not supported. This is checked on construction and in ``_rule``.
    """

    # Weights b_0, ..., b_{k-1} keyed by num_steps k. b_0 belongs to the slope at the current
    # state q_n[0]; the others belong to stored previous derivatives.
    # Kept as plain numbers so the table is built once, not on every rule evaluation.
    _COEFFS = {
        1: (1,),
        2: (3 / 2, -1 / 2),
        3: (23 / 12, -16 / 12, 5 / 12),
        4: (55 / 24, -59 / 24, 37 / 24, -9 / 24),
        5: (1901 / 720, -2774 / 720, 2616 / 720, -1274 / 720, 251 / 720),
        6: (4277 / 1440, -7923 / 1440, 9982 / 1440, -7298 / 1440, 2877 / 1440, -475 / 1440),
    }

    def __init__(self, num_steps):
        self._coefficients(num_steps)  # fail early on an unsupported num_steps
        super().__init__("adams_bashforth",
                         self._rule,
                         self._update,
                         jnp.asarray([[0]]), ('explicit',),
                         jnp.array([0.]),
                         num_steps=num_steps,
                         num_derivs=1,
                         num_stages=1)
        self.butcher_b = jnp.array([1.])
        self.order = num_steps

    @classmethod
    def _coefficients(cls, num_steps):
        """Return the Adams-Bashforth weights for ``num_steps`` or raise ``ValueError``."""
        if num_steps not in cls._COEFFS:
            raise ValueError(f"num_steps={num_steps} is not supported. Supported number of steps: 1 to 6.")
        return cls._COEFFS[num_steps]

    def _update(self, q_stages, q_n, q_t_n, dt):
        # New state is the stage value; the stored derivative is the rule's q_t, i.e. the slope
        # solved for at the old state q_n[0].
        q_n1 = q_stages[0]
        q_t_n1 = self.value_and_derivatives(q_n1, q_n, q_t_n, dt)[1:]
        return q_n1, q_t_n1

    def _rule(self, q, q_n, q_t_n, dt):
        coeffs = jnp.asarray(self._coefficients(self.num_steps))
        a_0 = coeffs[0]
        # NOTE(review): the history slice starts at index 1. Because _update stores the slope of
        # the *old* state, whether slot 0 or slot 1 holds the slope needed here depends on how the
        # caller shifts q_t_n between steps -- confirm against the driver.
        history = jnp.einsum("j,j...->...", coeffs[1:], q_t_n[1:, 0])
        # Solve for the slope at the known state q_n[0] (explicit: value evaluated at q_n[0]).
        q_t = ((q - q_n[0]) / dt - history) / a_0
        return jnp.asarray([q_n[0]]), q_t
[docs]class ExplicitRungeKutta(TimeIntegrator):
"""Explicit Runge-Kutta method.
Args:
num_stages (int): Number of stages (1, 2, 3, 4, 5, 6, 7, 9, 11).
"""
[docs] def __init__(self, num_stages):
match num_stages: # From JC Butcher 2008: Numerical Methods for Ordinary Differential Equations, ISBN: 978-0-470-72335-7
case 1: # Forward Euler
butcher_c = jnp.array([0])
butcher_b = jnp.array([1])
butcher_A = jnp.array([[0]])
self.order = 1
case 2: # Heun's method
butcher_c = jnp.array([0, 1])
butcher_b = jnp.array([1 / 2, 1 / 2])
butcher_A = jnp.array([[0, 0], [1, 0]])
self.order = 2
case 3: # Kutta's third-order method
butcher_c = jnp.array([0, 1 / 2, 1])
butcher_b = jnp.array([1 / 6, 2 / 3, 1 / 6])
butcher_A = jnp.array([[0, 0, 0], [1 / 2, 0, 0], [-1, 2, 0]])
self.order = 3
case 4: # Classic Runge-Kutta method
butcher_c = jnp.array([0, 1 / 2, 1 / 2, 1])
butcher_b = jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6])
butcher_A = jnp.array([[0, 0, 0, 0], [1 / 2, 0, 0, 0], [0, 1 / 2, 0, 0], [0, 0, 1, 0]])
self.order = 4
case 6:
butcher_c = jnp.array([0, 1 / 4, 1 / 4, 1 / 2, 3 / 4, 1])
butcher_b = jnp.array([7 / 90, 0, 16 / 45, 2 / 15, 16 / 45, 7 / 90])
butcher_A = jnp.array([[0, 0, 0, 0, 0, 0], [1 / 4, 0, 0, 0, 0, 0], [1 / 8, 1 / 8, 0, 0, 0, 0],
[0, 0, 1 / 2, 0, 0, 0], [3 / 16, -3 / 8, 3 / 8, 9 / 16, 0, 0],
[-3 / 7, 8 / 7, 6 / 7, -12 / 7, 8 / 7, 0]])
self.order = 5
case 7:
butcher_c = jnp.array([0, 1 / 3, 2 / 3, 1 / 3, 5 / 6, 1 / 6, 1])
butcher_b = jnp.array([13 / 200, 0, 11 / 40, 11 / 40, 4 / 25, 4 / 25, 13 / 200])
butcher_A = jnp.array([[0, 0, 0, 0, 0, 0, 0], [1 / 3, 0, 0, 0, 0, 0, 0], [0, 2 / 3, 0, 0, 0, 0, 0],
[1 / 12, 1 / 3, -1 / 12, 0, 0, 0, 0], [25 / 48, -55 / 24, 35 / 48, 15 / 8, 0, 0, 0],
[3 / 20, -11 / 24, -1 / 8, 1 / 2, 1 / 10, 0, 0],
[-261 / 260, 33 / 13, 43 / 156, -118 / 39, 32 / 195, 80 / 39, 0]])
self.order = 6
case 9:
butcher_c = jnp.array([0, 1 / 6, 1 / 3, 1 / 2, 2 / 11, 2 / 3, 6 / 7, 0, 1])
butcher_b = jnp.array([0, 0, 0, 32 / 105, 1771561 / 6289920, 243 / 2560, 16807 / 74880, 77 / 1440, 11 / 270])
butcher_A = jnp.array([[0, 0, 0, 0, 0, 0, 0, 0, 0], [1 / 6, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1 / 3, 0, 0, 0, 0, 0, 0, 0], [1 / 8, 0, 3 / 8, 0, 0, 0, 0, 0, 0],
[148 / 1331, 0, 150 / 1331, -56 / 1331, 0, 0, 0, 0, 0],
[-404 / 243, 0, -170 / 27, 4024 / 1701, 10648 / 1701, 0, 0, 0, 0],
[2466 / 2401, 0, 1242 / 343, -19176 / 16807, -51909 / 16807, 1053 / 2401, 0, 0, 0],
[5 / 154, 0, 0, 96 / 539, -1815 / 20384, -405 / 2464, 49 / 1144, 0, 0],
[-113 / 32, 0, -195 / 22, 32 / 7, 29403 / 3584, -729 / 512, 1029 / 1408, 21 / 16, 0]])
self.order = 7
case 11:
sqrt_21 = jnp.sqrt(21)
butcher_c = jnp.array([
0, 1 / 2, 1 / 2, (7 + sqrt_21) / 14, (7 + sqrt_21) / 14, 1 / 2, (7 - sqrt_21) / 14, (7 - sqrt_21) / 14,
1 / 2, (7 + sqrt_21) / 14, 1
])
butcher_b = jnp.array([1 / 20, 0, 0, 0, 0, 0, 0, 49 / 180, 16 / 45, 49 / 180, 1 / 20])
butcher_A = jnp.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1 / 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1 / 4, 1 / 4, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1 / 7, (-7 - 3 * sqrt_21) / 98, (21 + 5 * sqrt_21) / 49, 0, 0, 0, 0, 0, 0, 0, 0],
[(11 + sqrt_21) / 84, 0, (18 + 4 * sqrt_21) / 63, (21 - sqrt_21) / 252, 0, 0, 0, 0, 0, 0,
0],
[(5 + sqrt_21) / 48, 0, (9 + sqrt_21) / 36, (-231 + 14 * sqrt_21) / 360,
(63 - 7 * sqrt_21) / 80, 0, 0, 0, 0, 0, 0],
[(10 - sqrt_21) / 42, 0, (-432 + 92 * sqrt_21) / 315, (633 - 145 * sqrt_21) / 90,
(-504 + 115 * sqrt_21) / 70, (63 - 13 * sqrt_21) / 35, 0, 0, 0, 0, 0],
[1 / 14, 0, 0, 0, (14 - 3 * sqrt_21) / 126, (13 - 3 * sqrt_21) / 63, 1 / 9, 0, 0, 0, 0],
[
1 / 32, 0, 0, 0, (91 - 21 * sqrt_21) / 576, 11 / 72, (-385 - 75 * sqrt_21) / 1152,
(63 + 13 * sqrt_21) / 128, 0, 0, 0
],
[
1 / 14, 0, 0, 0, 1 / 9, (-733 - 147 * sqrt_21) / 2205, (515 + 111 * sqrt_21) / 504,
(-51 - 11 * sqrt_21) / 56, (132 + 28 * sqrt_21) / 245, 0, 0
],
[
0, 0, 0, 0, (-42 + 7 * sqrt_21) / 18, (-18 + 28 * sqrt_21) / 45,
(-273 - 53 * sqrt_21) / 72, (301 + 53 * sqrt_21) / 72, (28 - 28 * sqrt_21) / 45,
(49 - 7 * sqrt_21) / 18, 0
]])
self.order = 8
case _:
raise ValueError("num_stages not supported for ExplicitRungeKutta. Supported: 1, 2, 3, 4, 6, 7, 9, 11")
stages = jnp.asarray([[i] for i in range(num_stages)])
stage_types = tuple('explicit' for i in range(num_stages))
super().__init__("explicit_runge_kutta",
self._rule,
self._update,
stages,
stage_types,
butcher_c,
num_steps=1,
num_derivs=1,
num_stages=num_stages)
self.num_stages = num_stages
self.butcher_A_inv = jnp.asarray(invert_butcher_with_order(butcher_A)[0])
self.butcher_A = butcher_A
self.butcher_b = butcher_b
self.butcher_c = butcher_c
def _update(self, q_stages, q_n, q_t_n, dt):
# Linear combination of stage results
q_s_t = self.value_and_derivatives(q_stages, q_n, q_t_n, dt)[1]
q_n1 = q_n[0] + dt * jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
q_t_n1 = jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
return q_n1, q_t_n1
def _rule(self, q, q_n, q_t_n, dt):
q_t = (1 / dt) * jnp.einsum("ij,j...->i...", self.butcher_A_inv, (q - jnp.stack([q_n[0]] * self.num_stages)))
q = jnp.stack([q_n[0]] * self.num_stages) + dt * jnp.einsum("ij,j...->i...", self.butcher_A, q_t)
return q, q_t
[docs]class DiagonallyImplicitRungeKutta(TimeIntegrator):
"""Diagonally implicit Runge-Kutta method.
Args:
num_stages (int): Number of stages (1, 2, 3).
"""
[docs] def __init__(self, num_stages):
match num_stages: # From JC Butcher 2008: Numerical Methods for Ordinary Differential Equations, ISBN: 978-0-470-72335-7
case 1: # Implicit midpoint (GauĂź-Legendre, 2nd order, symplectic)
butcher_c = jnp.array([1 / 2])
butcher_b = jnp.array([1])
butcher_A = jnp.array([[1 / 2]])
self.order = 2
case 2: # Crouzeix's method (3rd order)
sqrt_3 = jnp.sqrt(3)
one_half = 1 / 2
butcher_c = jnp.array([one_half + sqrt_3 / 6, one_half - sqrt_3 / 6])
butcher_b = jnp.array([one_half, one_half])
butcher_A = jnp.array([[one_half + sqrt_3 / 6, 0], [-sqrt_3 / 3, one_half + sqrt_3 / 6]])
self.order = 3
case 3: # Crouzeix's method (4rd order)
alpha = 2 * jnp.cos(jnp.pi / 18) / jnp.sqrt(3)
butcher_c = jnp.array([(1 + alpha) / 2, 1 / 2, (1 - alpha) / 2])
butcher_b = jnp.array([1 / (6 * alpha**2), 1 - 1 / (3 * alpha**2), 1 / (6 * alpha**2)])
butcher_A = jnp.array([[(1 + alpha) / 2, 0, 0], [-alpha / 2, (1 + alpha) / 2, 0],
[1 + alpha, -(1 + 2 * alpha), (1 + alpha) / 2]])
self.order = 4
case _:
raise ValueError("num_stages not supported for ExplicitRungeKutta. Supported: 1, 2, 3")
stages = jnp.asarray([[i] for i in range(num_stages)])
stage_types = tuple('implicit' for i in range(num_stages))
super().__init__("diagonally_implicit_runge_kutta",
self._rule,
self._update,
stages,
stage_types,
butcher_c,
num_steps=1,
num_derivs=1,
num_stages=num_stages)
self.num_stages = num_stages
self.butcher_A_inv = jnp.asarray(invert_butcher_with_order(butcher_A)[0])
self.butcher_A = butcher_A
self.butcher_b = butcher_b
self.butcher_c = butcher_c
def _update(self, q_stages, q_n, q_t_n, dt):
# Linear combination of stage results
q_s_t = self.value_and_derivatives(q_stages, q_n, q_t_n, dt)[1]
q_n1 = q_n[0] + dt * jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
q_t_n1 = jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
return q_n1, q_t_n1
def _rule(self, q, q_n, q_t_n, dt): # q contains all stages (first dimension)
q_t = (1 / dt) * jnp.einsum("ij,j...->i...", self.butcher_A_inv, (q - jnp.stack([q_n[0]] * self.num_stages)))
q = jnp.stack([q_n[0]] * self.num_stages) + dt * jnp.einsum("ij,j...->i...", self.butcher_A, q_t)
return q, q_t
[docs]class Kvaerno(TimeIntegrator):
"""Kvaerno method (explicit first stage diagonally implicit Runge-Kutta with embedded error estimation).
Args:
order (int): Order of the method (3, 4, 5).
Supports PID control.
"""
[docs] def __init__(self, order):
match order: # Coefficients from https://github.com/patrick-kidger/diffrax/blob/0a59c9dbd34f580efb3505386f38ce9fcedb120b/diffrax/_solver -> kvaerno{3,4,5}.py
case 3:
Îł = 0.43586652150
a21 = Îł
a31 = (-4 * Îł**2 + 6 * Îł - 1) / (4 * Îł)
a32 = (-2 * Îł + 1) / (4 * Îł)
a41 = (6 * Îł - 1) / (12 * Îł)
a42 = -1 / ((24 * Îł - 12) * Îł)
a43 = (-6 * Îł**2 + 6 * Îł - 1) / (6 * Îł - 3)
butcher_c = jnp.array([0., 2 * Îł, 1.0, 1.0])
butcher_b = jnp.array([a41, a42, a43, Îł])
error_b = jnp.array([a41 - a31, a42 - a32, a43 - Îł, Îł])
butcher_A = jnp.array([[0, 0, 0, 0], [a21, Îł, 0, 0], [a31, a32, Îł, 0], [a41, a42, a43, Îł]])
self.order = 3
case 4:
Îł = 0.5728160625
def poly(*args):
return jnp.polyval(jnp.asarray(args), Îł)
a21 = Îł
a31 = poly(144, -180, 81, -15, 1) * Îł / poly(12, -6, 1)**2
a32 = poly(-36, 39, -15, 2) * Îł / poly(12, -6, 1)**2
a41 = poly(-144, 396, -330, 117, -18, 1) / (12 * Îł**2 * poly(12, -9, 2))
a42 = poly(72, -126, 69, -15, 1) / (12 * Îł**2 * poly(3, -1))
a43 = (poly(-6, 6, -1) * poly(12, -6, 1)**2) / (12 * Îł**2 * poly(12, -9, 2) * poly(3, -1))
a51 = poly(288, -312, 120, -18, 1) / (48 * Îł**2 * poly(12, -9, 2))
a52 = poly(24, -12, 1) / (48 * Îł**2 * poly(3, -1))
a53 = -(poly(12, -6, 1)**3) / (48 * Îł**2 * poly(3, -1) * poly(12, -9, 2) * poly(6, -6, 1))
a54 = poly(-24, 36, -12, 1) / poly(24, -24, 4)
c2 = Îł + a21
c3 = Îł + a31 + a32
c4 = 1.0
c5 = 1.0
butcher_c = jnp.array([0, c2, c3, c4, c5])
butcher_b = jnp.array([a51, a52, a53, a54, Îł])
error_b = jnp.array([a51 - a41, a52 - a42, a53 - a43, a54 - Îł, Îł])
butcher_A = jnp.array([[0, 0, 0, 0, 0], [a21, Îł, 0, 0, 0], [a31, a32, Îł, 0, 0], [a41, a42, a43, Îł, 0],
[a51, a52, a53, a54, Îł]])
self.order = 4
case 5:
Îł = 0.26
a21 = Îł
a31 = 0.13
a32 = 0.84033320996790809
a41 = 0.22371961478320505
a42 = 0.47675532319799699
a43 = -0.06470895363112615
a51 = 0.16648564323248321
a52 = 0.10450018841591720
a53 = 0.03631482272098715
a54 = -0.13090704451073998
a61 = 0.13855640231268224
a62 = 0
a63 = -0.04245337201752043
a64 = 0.02446657898003141
a65 = 0.61943039072480676
a71 = 0.13659751177640291
a72 = 0
a73 = -0.05496908796538376
a74 = -0.04118626728321046
a75 = 0.62993304899016403
a76 = 0.06962479448202728
butcher_c = jnp.array([0, 0.52, 1.230333209967908, 0.8957659843500759, 0.43639360985864756, 1.0, 1.0])
butcher_b = jnp.array([a71, a72, a73, a74, a75, a76, Îł])
error_b = jnp.array([a71 - a61, a72 - a62, a73 - a63, a74 - a64, a75 - a65, a76 - Îł, Îł])
butcher_A = jnp.array([[0, 0, 0, 0, 0, 0, 0], [a21, Îł, 0, 0, 0, 0, 0], [a31, a32, Îł, 0, 0, 0, 0],
[a41, a42, a43, Îł, 0, 0, 0], [a51, a52, a53, a54, Îł, 0, 0],
[a61, a62, a63, a64, a65, Îł, 0], [a71, a72, a73, a74, a75, a76, Îł]])
self.order = 5
case _:
raise ValueError("order not supported for Kvaerno. Supported: 3, 4, 5")
num_stages = butcher_c.shape[0]
stages = jnp.asarray([[i] for i in range(num_stages)])
stage_types = ('explicit', *tuple('implicit' for i in range(num_stages - 1)))
super().__init__("Kvaerno",
self._rule,
self._update,
stages,
stage_types,
butcher_c,
num_steps=1,
num_derivs=1,
num_stages=num_stages)
self.num_stages = num_stages
self.butcher_A_inv = jnp.asarray(invert_butcher_with_order(butcher_A)[0])
self.butcher_A = butcher_A
self.butcher_b = butcher_b
self.butcher_c = butcher_c
self.error_b = error_b
def _update(self, q_stages, q_n, q_t_n, dt):
# Linear combination of stage results
q_s_t = self.value_and_derivatives(q_stages, q_n, q_t_n, dt)[1]
q_n1 = q_n[0] + dt * jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
q_t_n1 = jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
return q_n1, q_t_n1
def _error_estimate(self, q_stages, q_n, q_t_n, dt):
# Linear combination of stage results
q_s_t = self.value_and_derivatives(q_stages, q_n, q_t_n, dt)[1]
weights = self.error_b
error_estimate = dt * jnp.einsum("j,j...->...", weights, q_s_t)
return error_estimate
def _rule(self, q, q_n, q_t_n, dt): # q contains all stages (first dimension)
q_t = (1 / dt) * jnp.einsum("ij,j...->i...", self.butcher_A_inv, (q - jnp.stack([q_n[0]] * self.num_stages)))
q = jnp.stack([q_n[0]] * self.num_stages) + dt * jnp.einsum("ij,j...->i...", self.butcher_A, q_t)
return q, q_t
[docs]class DormandPrince(TimeIntegrator):
"""Dormand-Prince method (explicit with embedded error estimation).
Args:
order (int): Order of the method (5, 8).
Supports PID control.
"""
[docs] def __init__(self, order):
# Coefficients from https://github.com/patrick-kidger/diffrax/blob/0a59c9dbd34f580efb3505386f38ce9fcedb120b/diffrax/_solver -> dopri{5,8}.py
match order:
case 5:
butcher_c = jnp.array([0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0])
butcher_b = jnp.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0])
error_b = jnp.array([
35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085, 125 / 192 - 451 / 720,
-2187 / 6784 - -12231 / 42400, 11 / 84 - 649 / 6300, -1.0 / 60.0
])
butcher_A = jnp.array([[0, 0, 0, 0, 0, 0, 0], [1 / 5, 0, 0, 0, 0, 0, 0], [3 / 40, 9 / 40, 0, 0, 0, 0, 0],
[44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
[19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729, 0, 0, 0],
[9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]])
self.order = 5
case 8:
butcher_c = jnp.array([
0, 1 / 18, 1 / 12, 1 / 8, 5 / 16, 3 / 8, 59 / 400, 93 / 200, 5490023248 / 9719169821, 13 / 20,
1201146811 / 1299019798, 1, 1, 1
])
butcher_b = jnp.array([
14005451 / 335480064, 0, 0, 0, 0, -59238493 / 1068277825, 181606767 / 758867731, 561292985 / 797845732,
-1041891430 / 1371343529, 760417239 / 1151165299, 118820643 / 751138087, -528747749 / 2220607170, 1 / 4, 0
])
error_b = jnp.array([
14005451 / 335480064 - 13451932 / 455176623, 0, 0, 0, 0, -59238493 / 1068277825 - -808719846 / 976000145,
181606767 / 758867731 - 1757004468 / 5645159321, 561292985 / 797845732 - 656045339 / 265891186,
-1041891430 / 1371343529 - -3867574721 / 1518517206, 760417239 / 1151165299 - 465885868 / 322736535,
118820643 / 751138087 - 53011238 / 667516719, -528747749 / 2220607170 - 2 / 45, 1 / 4, 0
])
butcher_A = jnp.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1 / 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1 / 48, 1 / 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1 / 32, 0, 3 / 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[5 / 16, 0, -75 / 64, 75 / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[3 / 80, 0, 0, 3 / 16, 3 / 20, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[
29443841 / 614563906, 0, 0, 77736538 / 692538347, -28693883 / 1125000000,
23124283 / 1800000000, 0, 0, 0, 0, 0, 0, 0, 0
],
[
16016141 / 946692911, 0, 0, 61564180 / 158732637, 22789713 / 633445777,
545815736 / 2771057229, -180193667 / 1043307555, 0, 0, 0, 0, 0, 0, 0
],
[
39632708 / 573591083, 0, 0, -433636366 / 683701615, -421739975 / 2616292301,
100302831 / 723423059, 790204164 / 839813087, 800635310 / 3783071287, 0, 0, 0, 0, 0,
0
],
[
246121993 / 1340847787, 0, 0, -37695042795 / 15268766246, -309121744 / 1061227803,
-12992083 / 490766935, 6005943493 / 2108947869, 393006217 / 1396673457,
123872331 / 1001029789, 0, 0, 0, 0, 0
],
[
-1028468189 / 846180014, 0, 0, 8478235783 / 508512852, 1311729495 / 1432422823,
-10304129995 / 1701304382, -48777925059 / 3047939560, 15336726248 / 1032824649,
-45442868181 / 3398467696, 3065993473 / 597172653, 0, 0, 0, 0
],
[
185892177 / 718116043, 0, 0, -3185094517 / 667107341, -477755414 / 1098053517,
-703635378 / 230739211, 5731566787 / 1027545527, 5232866602 / 850066563,
-4093664535 / 808688257, 3962137247 / 1805957418, 65686358 / 487910083, 0, 0, 0
],
[
403863854 / 491063109, 0, 0, -5068492393 / 434740067, -411421997 / 543043805,
652783627 / 914296604, 11173962825 / 925320556, -13158990841 / 6184727034,
3936647629 / 1978049680, -160528059 / 685178525, 248638103 / 1413531060, 0, 0, 0
],
[
14005451 / 335480064, 0, 0, 0, 0, -59238493 / 1068277825, 181606767 / 758867731,
561292985 / 797845732, -1041891430 / 1371343529, 760417239 / 1151165299,
118820643 / 751138087, -528747749 / 2220607170, 1 / 4, 0
]])
self.order = 8
case _:
raise ValueError("order not supported for DormandPrince. Supported: 5, 8")
num_stages = butcher_c.shape[0]
stages = jnp.asarray([[i] for i in range(num_stages)])
stage_types = tuple('explicit' for i in range(num_stages))
super().__init__("DormandPrince",
self._rule,
self._update,
stages,
stage_types,
butcher_c,
num_steps=1,
num_derivs=1,
num_stages=num_stages)
self.num_stages = num_stages
self.butcher_A_inv = jnp.asarray(invert_butcher_with_order(butcher_A)[0])
self.butcher_A = butcher_A
self.butcher_b = butcher_b
self.butcher_c = butcher_c
self.error_b = error_b
def _update(self, q_stages, q_n, q_t_n, dt):
# Linear combination of stage results
q_s_t = self.value_and_derivatives(q_stages, q_n, q_t_n, dt)[1]
q_n1 = q_n[0] + dt * jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
q_t_n1 = jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
return q_n1, q_t_n1
def _error_estimate(self, q_stages, q_n, q_t_n, dt):
    """Return the embedded local error estimate ``dt * sum_j error_b[j] * k_j``.

    Args:
        q_stages: Stage values, stage index along the leading axis.
        q_n: Previous values (forwarded to ``value_and_derivatives``).
        q_t_n: Previous time derivatives (forwarded to ``value_and_derivatives``).
        dt: Time step size.

    Returns:
        Array with the same trailing shape as a single stage.
    """
    stage_slopes = self.value_and_derivatives(q_stages, q_n, q_t_n, dt)[1]
    return dt * jnp.einsum("j,j...->...", self.error_b, stage_slopes)
def _rule(self, q, q_n, q_t_n, dt):
    """Relate stage values and stage slopes through the Butcher tableau.

    Args:
        q: All stage values stacked along the first axis.
        q_n: Previous values; ``q_n[0]`` is the start value of the step.
        q_t_n: Unused by this rule.
        dt: Time step size.

    Returns:
        Tuple ``(q, q_t)``: the stage values rebuilt as ``q_n[0] + dt * A @ q_t`` and
        the stage slopes ``q_t = (1 / dt) * A_inv @ (q - q_n[0])``.
    """
    # Start value repeated once per stage.
    start = jnp.stack([q_n[0]] * self.num_stages)
    slopes = (1 / dt) * jnp.einsum("ij,j...->i...", self.butcher_A_inv, (q - start))
    stage_values = start + dt * jnp.einsum("ij,j...->i...", self.butcher_A, slopes)
    return stage_values, slopes
class GaussLegendreRungeKutta(TimeIntegrator):
    """Gauss-Legendre Runge-Kutta method (fully implicit collocation method).

    Args:
        num_stages (int): Number of stages ``s``.

    Accuracy: order ``2 * num_stages`` (stored in ``self.order``).
    """

    def __init__(self, num_stages):
        def get_gauss_legendre_coefficients(s):
            """Return the Butcher tableau ``(A, b, c)`` of the s-stage Gauss-Legendre method as JAX arrays."""
            # 1. Gauss-Legendre nodes and weights on [-1, 1].
            nodes, weights = np.polynomial.legendre.leggauss(s)
            # 2. Map nodes and weights from [-1, 1] to [0, 1].
            c = 0.5 * (nodes + 1)
            b = 0.5 * weights
            # 3. Collocation matrix A[i, j] = int_0^{c_i} L_j(x) dx, where L_j is the Lagrange
            #    basis polynomial of node c_j. L_j and its antiderivative do not depend on i,
            #    so they are built once per column instead of once per matrix entry.
            A = np.zeros((s, s))
            for j in range(s):
                L_j = np.poly1d([1.0])
                for k in range(s):
                    if k != j:
                        L_j = np.poly1d(np.convolve(L_j.coeffs, [1.0, -c[k]])) / (c[j] - c[k])
                Lj_int = np.polyint(L_j)
                # Evaluate the antiderivative at all nodes at once.
                A[:, j] = Lj_int(c) - Lj_int(0.0)
            # 4. Convert the tableau to JAX arrays.
            return jnp.array(A), jnp.array(b), jnp.array(c)

        butcher_A, butcher_b, butcher_c = get_gauss_legendre_coefficients(num_stages)
        butcher_A_inv, stage_order = invert_butcher_with_order(butcher_A)
        # Each entry of stage_order is used below as a (stage indices, stage type) pair.
        stage_types = tuple(st[1] for st in stage_order)
        stage_order = jnp.asarray([st[0] for st in stage_order])
        super().__init__("GaussLegendreRungeKutta",
                         self._rule,
                         self._update,
                         stage_order,
                         stage_types,
                         butcher_c,
                         num_steps=1,
                         num_derivs=1,
                         num_stages=num_stages)
        self.num_stages = num_stages
        self.butcher_A_inv = jnp.asarray(butcher_A_inv)
        self.butcher_A = butcher_A
        self.butcher_b = butcher_b
        self.butcher_c = butcher_c
        self.order = 2 * num_stages

    def _update(self, q_stages, q_n, q_t_n, dt):
        """Combine the stage slopes with the weights ``butcher_b``.

        Returns ``(q_n1, q_t_n1)`` with ``q_t_n1 = sum_j b_j k_j`` and
        ``q_n1 = q_n[0] + dt * q_t_n1``.
        """
        q_s_t = self.value_and_derivatives(q_stages, q_n, q_t_n, dt)[1]
        # The weighted slope is needed for both outputs; compute it only once.
        q_t_n1 = jnp.einsum("j,j...->...", self.butcher_b, q_s_t)
        q_n1 = q_n[0] + dt * q_t_n1
        return q_n1, q_t_n1

    def _rule(self, q, q_n, q_t_n, dt):
        """Relate stage values ``q`` (stages along the first axis) and stage slopes.

        Returns ``(q, q_t)`` with ``q_t = (1 / dt) * A_inv @ (q - q_n[0])`` and the stage
        values rebuilt as ``q_n[0] + dt * A @ q_t``. ``q_t_n`` is not used.
        """
        q_start = jnp.stack([q_n[0]] * self.num_stages)
        q_t = (1 / dt) * jnp.einsum("ij,j...->i...", self.butcher_A_inv, (q - q_start))
        q = q_start + dt * jnp.einsum("ij,j...->i...", self.butcher_A, q_t)
        return q, q_t
## Saving policies
[docs]@jax.tree_util.register_dataclass
@dataclass
class HistoryState:
"""
Container for storing the history state data.
Attributes:
t: Dictionary of time data arrays.
q: Dictionary of state variable arrays.
user: Dictionary of additional user data.
"""
t: dict[str, jnp.ndarray]
q: dict[str, jnp.ndarray]
user: Any
class SavePolicy(ABC):
    """Interface for strategies that record the simulation history.

    A policy works in three phases: ``initialize`` creates a storage state,
    ``save_step`` takes that state and returns an updated one, and ``finalize`` turns
    the state into the history data handed back to the user.
    """

    @abstractmethod
    def initialize(self, q, t_max, num_time_steps, user_data={}):
        """Create the initial storage state.

        Args:
            q: Dictionary of state variables.
            t_max: Final simulation time.
            num_time_steps: Number of time steps.
            user_data: Dictionary of additional user data.

        Returns:
            The initial state of the save policy.
        """

    @abstractmethod
    def save_step(self, state, t, q, user_data={}):
        """Record the data of one time point.

        Args:
            state: Current storage state.
            t: Current time.
            q: Dictionary of current state variables.
            user_data: Dictionary of additional user data.

        Returns:
            The updated storage state.
        """

    @abstractmethod
    def finalize(self, state):
        """Convert the storage state into the history returned to the user.

        Args:
            state: Current storage state.

        Returns:
            The finalized history data.
        """
class SaveNothingPolicy(SavePolicy):
    """Save policy that records nothing: every hook returns ``None``."""

    def initialize(self, q, t_max, num_time_steps, user_data={}):
        # No storage is allocated, so there is no state.
        return None

    def save_step(self, state, t, q, user_data={}):
        # Discard the data; the (empty) state stays None.
        return None

    def finalize(self, state):
        # There is no history to hand back.
        return None
[docs]@jax.tree_util.register_dataclass
@dataclass
class SaveEquidistantHistoryState:
"""
State for the SaveAllPolicy.
"""
t_max: float
num_points: int
target_times: jnp.ndarray
t: dict[str, jnp.ndarray]
q: dict[str, jnp.ndarray]
current_save_idx: int
user: Any
class SaveEquidistantPolicy(SavePolicy):
    """
    Saves data at (approximately) equidistant time points using pre-allocated arrays.

    The target times are ``linspace(0, t_max, num_points + 1)``. ``save_step`` writes one
    snapshot whenever the given time reaches the next pending target time (up to ``tol``).
    At most one snapshot is written per call. If a single step passes several target
    times, the following calls catch up one slot at a time. The stored times are the
    actual times passed to ``save_step``, not the target times.

    Args:
        num_points: Number of intervals between target times (``num_points + 1`` slots).
            If None, ``max_steps`` from ``initialize`` is used.
        tol: A time ``t`` counts as having reached a target if ``t >= target - tol``.
    """

    def __init__(self, num_points=None, tol=1e-6):
        self.tol = tol
        self.num_points = num_points

    def initialize(self, q, t_max, max_steps, user_data=None):
        """
        Allocate NaN-filled history arrays and the equidistant target times.

        Args:
            q: Dictionary of state arrays (used for their shapes).
            t_max: Final simulation time; targets are ``linspace(0, t_max, num_points + 1)``.
            max_steps: Used as ``num_points`` if none was given to the constructor.
            user_data: Optional dictionary of user data arrays (used for their shapes).

        Returns:
            SaveEquidistantHistoryState with write index 0.
        """
        user_data = {} if user_data is None else user_data
        num_points = self.num_points if self.num_points is not None else max_steps
        num_slots = num_points + 1
        history_t = jnp.full(num_slots, jnp.nan)
        history_q = {key: jnp.full((num_slots,) + q[key].shape, jnp.nan) for key in q.keys()}
        history_user = {key: jnp.full((num_slots,) + user_data[key].shape, jnp.nan) for key in user_data.keys()}
        target_times = jnp.linspace(0, t_max, num_slots)
        return SaveEquidistantHistoryState(t_max,
                                           num_points,
                                           target_times,
                                           history_t,
                                           history_q,
                                           0,
                                           user=history_user)

    def save_step(self, state, t, q, user_data=None):
        """
        Save the current data if the next target time has been reached.

        Args:
            state: Current SaveEquidistantHistoryState.
            t: Current time.
            q: Dictionary of current state arrays.
            user_data: Optional dictionary of user data arrays; if None, the stored user
                data is left unchanged.

        Returns:
            A new SaveEquidistantHistoryState; the input state is not modified.
        """
        # Has the current time reached the next pending target time (within tol)?
        condition = t >= (state.target_times[state.current_save_idx] - self.tol)

        def do_save(state):
            idx = state.current_save_idx
            new_q = {key: state.q[key].at[idx].set(q[key]) for key in state.q}
            new_user = state.user
            if user_data is not None:
                new_user = {key: state.user[key].at[idx].set(user_data[key]) for key in state.user}
            return SaveEquidistantHistoryState(
                t_max=state.t_max,
                num_points=state.num_points,
                target_times=state.target_times,
                t=state.t.at[idx].set(t),
                q=new_q,
                # Advance to the next slot, but never past the last one (index num_points).
                current_save_idx=jnp.minimum(idx + 1, state.num_points),
                user=new_user,
            )

        def do_nothing(state):
            return state

        return jax.lax.cond(condition, do_save, do_nothing, state)

    def finalize(self, state):
        """Return the recorded history as a HistoryState."""
        return HistoryState(state.t, state.q, state.user)
[docs]@jax.tree_util.register_dataclass
@dataclass
class SaveAllHistoryState:
"""
State for the SaveAllPolicy.
"""
t_max: float
max_steps: int
t: dict[str, jnp.ndarray]
q: dict[str, jnp.ndarray]
current_save_idx: int
user: Any
class SaveAllPolicy(SavePolicy):
    """
    Saves data at every accepted time step.

    Each call to ``save_step`` writes into the next slot of NaN-initialized arrays with
    ``max_steps + 1`` slots. Once the last slot is reached, further calls overwrite it.
    """

    def unpack_state(self, state):
        """Return ``(t_max, max_steps, t, q, current_save_idx)`` of a SaveAllHistoryState."""
        return state.t_max, state.max_steps, state.t, state.q, state.current_save_idx

    def initialize(self, q, t_max, max_steps, user_data=None):
        """
        Prepare the history state; all arrays are initialized with NaNs.

        Args:
            q: Dictionary of state arrays (used for their shapes).
            t_max: Final simulation time (stored in the state).
            max_steps: Maximum number of steps; ``max_steps + 1`` slots are allocated.
            user_data: Optional dictionary of user data arrays (used for their shapes).

        Returns:
            SaveAllHistoryState with write index 0.
        """
        user_data = {} if user_data is None else user_data
        num_slots = max_steps + 1
        history_t = jnp.full(num_slots, jnp.nan)
        history_q = {key: jnp.full((num_slots,) + q[key].shape, jnp.nan) for key in q.keys()}
        history_user = {key: jnp.full((num_slots,) + user_data[key].shape, jnp.nan) for key in user_data.keys()}
        return SaveAllHistoryState(t_max, max_steps, history_t, history_q, 0, user=history_user)

    def save_step(self, state, t, q, user_data=None):
        """
        Write ``t``, ``q`` (and ``user_data``) into the current slot and advance the index.

        A new state is returned; the input state and its dictionaries are not modified.
        If ``user_data`` is None, the stored user data is left unchanged.
        """
        idx = state.current_save_idx
        new_user = state.user
        if user_data is not None:
            new_user = {key: state.user[key].at[idx].set(user_data[key]) for key in state.user}
        return SaveAllHistoryState(
            t_max=state.t_max,
            max_steps=state.max_steps,
            t=state.t.at[idx].set(t),
            q={key: state.q[key].at[idx].set(q[key]) for key in state.q},
            # Advance the write index, but clip it to the last slot (index max_steps).
            current_save_idx=jnp.minimum(idx + 1, state.max_steps),
            user=new_user,
        )

    def finalize(self, state):
        """Return the recorded history as a HistoryState."""
        return HistoryState(state.t, state.q, state.user)
## Step size controllers
class StepSizeController(ABC):
    """
    Abstract base class for adaptive step size controllers.

    The controllers in this module keep only configuration on the instance. Their
    evolving data lives in a separate state object (a registered dataclass) that is
    passed in and returned. The built-in states expose at least ``step_scaler``,
    ``accept`` and ``interrupt``.
    """

    @abstractmethod
    def initialize(self, initial_error):
        """
        Initialize the controller state.

        Args:
            initial_error (float): Initial (scaled) error value; controllers may ignore it.

        Returns:
            Any: The initial controller state.
        """
        pass

    @abstractmethod
    def compute_scaler(self, error, q, q_n, state, order, converged, num_iterations, dt, verbose):
        """
        Compute the step size scaling factor and update the controller state.

        Args:
            error: Dictionary of per-field error estimates (entries may be None if an
                integrator provides no estimate).
            q: Dictionary of new state values.
            q_n: Dictionary of previous state values.
            state: Current controller state.
            order: Order of the time integration scheme (used by error-based controllers).
            converged: Whether the root solver converged.
            num_iterations: Number of root solver iterations.
            dt: Current step size.
            verbose: Verbosity level.

        Returns:
            Any: The updated controller state; its ``step_scaler`` is the factor applied
            to the step size.
        """
        pass

    @abstractmethod
    def check_accept(self, state, converged, verbose):
        """
        Decide whether the current step is accepted.

        Args:
            state: Current controller state.
            converged: Whether the root solver converged.
            verbose: Verbosity level.

        Returns:
            Any: The controller state with ``accept`` (and possibly ``interrupt``) updated.
        """
        pass
[docs]@jax.tree_util.register_dataclass
@dataclass
class PIDControllerState:
"""
State for the Proportional-Integral-Derivative Controller.
"""
step_scaler: float
e_n: float # error_n
e_nn: float # error_{n-1}
accept: bool
interrupt: bool
class PIDController(StepSizeController):
    """Proportional-Integral-Derivative (PID) Step Size Controller.

    Inspired by https://docs.kidger.site/diffrax/api/stepsize_controller and https://docs.sciml.ai/DiffEqDocs/stable/extras/timestepping/
    """

    def __init__(self,
                 pcoeff: float = 0.0,
                 icoeff: float = 1.0,
                 dcoeff: float = 0.0,
                 limiter: Any = None,
                 atol=1e-6,
                 rtol=1e-3):
        """
        Initialize the PIDController.

        Args:
            pcoeff (float): The coefficient of the proportional part of the step size control.
            icoeff (float): The coefficient of the integral part of the step size control.
            dcoeff (float): The coefficient of the derivative part of the step size control.
            limiter (callable): Limiter function. If None the limiter is set to `1.0 + jnp.arctan(x - 1.0)`.
            atol (float or dict): Absolute tolerance, either one value or a dict keyed by field name.
            rtol (float or dict): Relative tolerance, either one value or a dict keyed by field name.
        """
        self.limiter = limiter if limiter is not None else (lambda x: 1.0 + jnp.arctan(x - 1.0))
        self.atol = atol
        self.rtol = rtol
        self.pcoeff = pcoeff
        self.icoeff = icoeff
        self.dcoeff = dcoeff

    def initialize(self, initial_error):
        """
        Create the initial state; ``initial_error`` seeds both stored error values.

        TODO: Initialization of time increment? See diffrax
        """
        return PIDControllerState(step_scaler=1., e_n=initial_error, e_nn=initial_error, accept=True, interrupt=False)

    def compute_scaler(self, error, q, q_n, state, order, converged, num_iterations, dt, verbose):
        """
        Compute the PID step size factor from the embedded error estimate.

        Args:
            error: Dictionary of error estimates per field (None if unsupported).
            q: Dictionary of new state values.
            q_n: Dictionary of state histories; ``q_n[key][0]`` is the previous value.
            state: Current PIDControllerState.
            order: Order used for the PID exponents (``k = order + 1``).
            converged: Whether the root solver converged (decides acceptance).
            num_iterations, dt: Not used by this controller.
            verbose: Forwarded to ``check_accept``.

        Returns:
            Updated PIDControllerState. The stored inverse error norms are shifted only
            if the step is accepted.

        Raises:
            ValueError: If the integrator of a field does not provide an error estimate.
        """
        keys = q.keys()
        atol = self.atol if isinstance(self.atol, dict) else {key: self.atol for key in keys}
        rtol = self.rtol if isinstance(self.rtol, dict) else {key: self.rtol for key in keys}
        # All integrators must support error estimation (integrators that do not return None as error).
        for key in keys:
            if error[key] is None:
                raise ValueError(
                    f"\n\nError estimation is not supported for the integrator of field '{key}'!\n Use a different integrator or a time step control that does not require error estimation."
                )
        # Scaled error estimate with the mixed tolerance sc = atol + max(|q_{n+1}|, |q_n|) * rtol.
        # Magnitudes are required: with signed values the scale shrinks, vanishes or even
        # becomes negative for negative states.
        scaled_error = {
            key: jnp.abs(
                jnp.divide(error[key],
                           atol[key] + jnp.maximum(jnp.abs(q[key]), jnp.abs(q_n[key][0])) * rtol[key]))
            for key in keys
        }
        scaled_error = dict_flatten(scaled_error)
        # Hairer (RMS) norm
        inv_error_norm = 1 / jnp.sqrt(jnp.mean(scaled_error**2))
        # PID coefficients
        k = order + 1
        beta1 = (self.pcoeff + self.icoeff + self.dcoeff) / k
        beta2 = -(self.pcoeff + 2 * self.dcoeff) / k
        beta3 = self.dcoeff / k
        # PID rule
        step_scaler = jnp.power(inv_error_norm, beta1) * jnp.power(state.e_n, beta2) * jnp.power(state.e_nn, beta3)
        # Zero error gives an infinite inverse norm; keep the step size in that case.
        step_scaler = jnp.where(jnp.isinf(inv_error_norm), 1., step_scaler)
        # Apply limiter function
        step_scaler = self.limiter(step_scaler)
        # Update state
        state = PIDControllerState(step_scaler, state.e_n, state.e_nn, state.accept, state.interrupt)
        # Shift the stored error history only if the step is accepted
        accepted = self.check_accept(state, converged, verbose).accept

        def do_accept(state):
            return PIDControllerState(step_scaler, inv_error_norm, state.e_n, state.accept, state.interrupt)

        def do_reject(state):
            return PIDControllerState(step_scaler, state.e_n, state.e_nn, state.accept, state.interrupt)

        state = jax.lax.cond(accepted, do_accept, do_reject, state)
        return state

    def check_accept(self, state, converged, verbose):
        """
        Accept the step if and only if the root solver converged.

        NOTE(review): the error estimate only influences ``step_scaler``; it does not cause a
        rejection here. Confirm whether the caller rejects steps based on the error.
        """
        return PIDControllerState(state.step_scaler, state.e_n, state.e_nn, converged, state.interrupt)
[docs]@jax.tree_util.register_dataclass
@dataclass
class CSSControllerState:
"""
State for the Constant Step Size Controller.
"""
step_scaler: float
accept: bool
interrupt: bool
class ConstantStepSizeController(StepSizeController):
    """Constant Step Size Controller.

    The step size is never changed: ``step_scaler`` stays at the initial value 1.
    """

    def __init__(self):
        pass

    def initialize(self, initial_error):
        # The initial error plays no role for a constant step size.
        return CSSControllerState(step_scaler=1., accept=True, interrupt=False)

    def compute_scaler(self, error, q, q_n, state, order, converged, num_iterations, dt, verbose):
        # Nothing to adapt: hand back the state unchanged.
        return state

    def check_accept(self, state, converged, verbose):
        """Accept iff converged; warn once and request an interrupt on the first failure."""

        def warn_and_interrupt():
            if verbose >= 0:
                jax.debug.print("Root solver did not converge, but stepsize controller can not reduce step size!")
            return True

        def keep_interrupt_flag():
            return state.interrupt

        # First failure: not converged and no interrupt requested yet (De Morgan form).
        first_failure = jnp.logical_not(jnp.logical_or(converged, state.interrupt))
        interrupt = jax.lax.cond(first_failure, warn_and_interrupt, keep_interrupt_flag)
        return CSSControllerState(step_scaler=state.step_scaler, accept=converged, interrupt=interrupt)
[docs]@jax.tree_util.register_dataclass
@dataclass
class RootIterationControllerState:
"""
State for the Root Iteration Controller.
"""
step_scaler: float # Scaling factor for the step size
dt: float
accept: bool
interrupt: bool
class RootIterationController(StepSizeController):
    """Root Iteration Controller for adjusting step size based on number of root solver iterations.

    Tries to achieve a target number of Newton iterations by adjusting the step size.
    Does not consider possible error estimates. Maximal and minimal step sizes can be set.
    """

    def __init__(self,
                 target_niters: int = 6,
                 gamma: float = 0.5,
                 max_step_size: float = 1e20,
                 min_step_size: float = 1e-6):
        """
        Initialize the RootIterationController.

        Args:
            target_niters (int): Desired number of Newton iterations.
            gamma (float): Proportionality factor for step size adjustment.
            max_step_size (float): Maximum allowable step size.
            min_step_size (float): Minimum allowable step size.
        """
        self.target_niters = target_niters
        self.gamma = gamma
        self.max_step_size = max_step_size
        self.min_step_size = min_step_size

    def initialize(self, initial_error):
        # dt=1.0 is a placeholder, not the actual step size; compute_scaler overwrites it.
        return RootIterationControllerState(step_scaler=1.0, dt=1.0, accept=True, interrupt=False)

    def compute_scaler(self, error, q, q_n, state, order, converged, num_iterations, dt, verbose):
        """
        Scale the step size towards ``target_niters`` root solver iterations.

        Converged steps use ``1 + gamma * (target - n) / target``, floored at 1/2;
        non-converged steps halve the step size. The result is clipped to
        ``[min_step_size, max_step_size]``.

        Returns:
            RootIterationControllerState with ``step_scaler = dt_new / dt`` and ``dt = dt_new``.
        """
        # Proportional control based on the deviation from target_niters.
        correction = 1 + self.gamma * (self.target_niters - num_iterations) / self.target_niters
        # For large iteration counts the linear law becomes <= 0, which would give a zero or
        # negative step that is silently clipped to min_step_size. A converged step should
        # never be shrunk more than a failed one (which is halved), so floor it at 1/2.
        correction = jnp.maximum(correction, 0.5)
        # If not converged, halve the step size.
        correction = jnp.where(converged, correction, 0.5)
        dt_old = dt
        # Change the step size and keep it within the bounds.
        dt = jnp.clip(dt * correction, self.min_step_size, self.max_step_size)
        step_scaler = dt / dt_old
        return RootIterationControllerState(step_scaler=step_scaler, dt=dt, accept=state.accept, interrupt=state.interrupt)

    def check_accept(self, state, converged, verbose):
        """Accept iff converged; warn and interrupt if not converged at the minimum step size."""

        def send_warning():
            if verbose >= 0:
                jax.debug.print("Root solver did not converge, but minimum step_size is reached!")
            return True

        do_nothing = lambda: state.interrupt
        # dt is clipped to >= min_step_size, so a purely relative comparison detects the
        # lower bound; the default absolute tolerance (1e-8) would misfire for tiny bounds.
        at_min_step = jnp.isclose(state.dt, self.min_step_size, atol=0.0)
        warn = jnp.logical_and(jnp.logical_and(jnp.logical_not(converged), at_min_step),
                               jnp.logical_not(state.interrupt))
        interrupt = jax.lax.cond(warn, send_warning, do_nothing)
        return RootIterationControllerState(state.step_scaler, state.dt, converged, interrupt)
## Root solvers
[docs]@jax.tree_util.register_dataclass
@dataclass
class RootSolverResult:
root: Any
num_iterations: int
converged: bool
[docs]def newton_solver(
func,
x0,
atol=1e-8,
max_iter=20,
damping_factor=1,
tangent_fun=None,
lin_solve_fun=None,
constrained_dofs=None,
constrained_values=None,
verbose=0,
termination_mode='residual',
):
"""
Newton-Raphson solver to find a root of F(x)=0.
Args:
func: Function F(x) whose zero is sought.
x0: Initial guess.
atol: Absolute tolerance.
max_iter: Maximum number of iterations.
tangent_fun: Function to compute the Jacobian. (Default: jax.jacfwd(func))
lin_solve_fun: Function to solve the linear system. (Default: jnp.linalg.solve)
constrained_dofs: Boolean mask for fixed degrees of freedom.
constrained_values: Fixed values for constrained DOFs.
verbose: If >=1, prints the residual norm each iteration.
termination_mode: 'residual' uses the residual norm; 'update' uses the update size.
Returns:
A RootSolverResult dataclass with fields:
- .root: the computed solution,
- .num_iterations: number of updates performed,
- .converged: convergence flag.
"""
# Set constraints if provided.
free_dofs = None
if constrained_dofs is not None:
free_dofs = ~constrained_dofs
if constrained_values is None:
raise ValueError("constrained_values must be provided if constrained_dofs is not None!")
if constrained_dofs.shape != x0.shape:
raise ValueError("constrained_dofs must have the same shape as x0!")
if constrained_values.shape != x0.shape:
raise ValueError("constrained_values must have the same shape as x0!")
# Enforce constraints in the initial guess.
x0 = jnp.where(constrained_dofs, constrained_values, x0)
if lin_solve_fun is None:
if constrained_dofs is None:
def lin_solve_fun(J, b, free_dofs):
return jnp.linalg.solve(J, b)
else:
def lin_solve_fun(J, b, free_dofs):
I = jnp.eye(J.shape[0])
J_mod = jnp.where(constrained_dofs[:, None], I, J)
return jnp.linalg.solve(J_mod, b)
use_residual = termination_mode == 'residual'
def body(i, state):
x, count, fx_norm, converged, stop = state
def update_fn(_):
# Apply boundary conditions
x_mod = x if constrained_dofs is None else jnp.where(constrained_dofs, constrained_values, x)
# Residual and tangent
if tangent_fun is not None:
fx = func(x_mod)
J = tangent_fun(x_mod)
else:
# Option useing jvp + vmap.
# Create the standard basis for the tangent space.
eye = jnp.eye(x_mod.shape[0])
# Map jax.jvp over the identity basis.
# Each call returns (f(x_mod), f'(x_mod) @ v)
primals, tangents = jax.vmap(lambda v: jax.jvp(func, (x_mod,), (v,)))(eye)
# All primals are identical; extract the first one.
fx = primals[0]
# Each tangent is one row of the Jacobian, so we transpose.
J = tangents.T
# Apply boundary conditions
fx_mod = fx if constrained_dofs is None else jnp.where(constrained_dofs, 0.0, fx)
fx_norm = jnp.linalg.norm(fx_mod)
# Stop newton iterations if NaN or Inf values are encountered.
stop = jnp.any(jnp.logical_or(jnp.isnan(fx_norm), jnp.isinf(fx_norm)))
# In 'residual' mode, only call the linear solver if fx_norm >= atol.
if use_residual:
delta = jax.lax.cond(
# Ensure one iteration in case of constrained dofs
fx_norm < atol if constrained_dofs is None else jnp.logical_and(fx_norm < atol, count >= 0),
lambda: jnp.zeros_like(x_mod),
lambda: lin_solve_fun(J, -fx_mod, free_dofs))
else:
delta = lin_solve_fun(J, -fx_mod, free_dofs)
x_new = x_mod + damping_factor * delta
new_converged = fx_norm < atol if use_residual else jnp.linalg.norm(delta) < atol
new_count = count + 1 # increment only when update_fn is executed
if verbose >= 1:
# Print using the update counter rather than the loop counter.
jax.debug.print("Iteration {iter}, Residual norm: {res}", iter=new_count, res=fx_norm, ordered=True)
return (x_new, new_count, fx_norm, new_converged, stop)
# If already converged, carry the state unchanged.
new_state = jax.lax.cond(jnp.logical_or(converged, stop), lambda _: state, update_fn, operand=None)
return new_state
state = (x0, -1, jnp.inf, False, False)
# Define condition for while_loop: iterate while iterations remain and not all instances are converged or stopped.
def cond_fn(state):
_, count, _, converged, stop = state
return jnp.logical_and(count < max_iter - 1,
jnp.logical_not(jnp.all(jnp.logical_or(converged, stop))))
final_state = jax.lax.while_loop(cond_fn, lambda state: body(0, state), state)
x_final, iterations, _, conv, _ = final_state
# In case of invalid values, fallback to the initial guess and stop the gradients.
x_final = jnp.where(jnp.logical_or(~conv, jnp.logical_or(jnp.isnan(x_final), jnp.isinf(x_final))),
x0, x_final)
return RootSolverResult(x_final, iterations, conv)
## Time stepping manager
[docs]@jax.tree_util.register_dataclass
@dataclass
class TimeSteppingManagerState:
"""
State for the TimeSteppingManager.
Attributes:
q (dict[str, jnp.ndarray]): Final state variables after time stepping.
settings (dict[str, Any]): Simulation settings after the run.
history (Any): Recorded history data (if a save policy is used).
num_steps (int): Total number of steps taken.
num_accepted (int): Number of accepted time steps.
num_rejected (int): Number of rejected time steps.
"""
q: dict[str, jnp.ndarray]
settings: dict[str, Any]
history: Any
num_steps: int
num_accepted: int
num_rejected: int
[docs]class TimeSteppingManager:
"""
Manages the time stepping procedure for a simulation using multi-stage integration schemes.
This class orchestrates the simulation by coordinating various components such as:
- Time integrators for different fields,
- A root solver for implicit equations,
- An adaptive step size controller,
- A save policy for recording history,
- Pre-step and post-step update functions for custom processing.
The manager initializes the simulation state from given degrees of freedom (DOFs) and then runs a loop
over a specified number of time steps. At each step, it computes the new state using multi-stage methods,
applies error control and adaptive time stepping, and optionally records simulation history. The final state
along with simulation statistics is returned.
"""
def __init__(
    self,
    static_settings,
    settings={'current time': 0.0},
    root_solver=newton_solver,
    save_policy=None,
    step_size_controller=ConstantStepSizeController(),
    postprocessing_fun=lambda q_fun, t, settings: {},
    pre_step_updates=None,
    post_step_updates=None,
):
    """
    Configure the time stepping manager.

    Args:
        static_settings: Static settings; must contain ``'time integrators'`` (dict mapping
            field names to integrators). ``'verbose'`` (default 0) is also read here.
        settings: Not used by the constructor. NOTE(review): kept for API compatibility;
            confirm whether it was meant to be stored.
        root_solver: Root solver for the implicit stage equations (default ``newton_solver``).
        save_policy: SavePolicy used to record the history, or None.
        step_size_controller: StepSizeController (default: constant step size). The default
            instance is created once at definition time and shared between managers.
        postprocessing_fun: Function ``(q_fun, t, settings) -> dict``.
        pre_step_updates: Function ``(t, settings) -> settings``. The default returns a copy
            of ``settings`` with ``'current time'`` set to ``t``.
        post_step_updates: Function ``(q_fun, t, settings) -> settings``. The default returns
            ``settings`` unchanged.
    """
    self.integrators = static_settings['time integrators']
    self.root_solver = root_solver
    self.num_time_derivs = {key: integrator.num_derivs for key, integrator in self.integrators.items()}
    self.num_steps = {key: integrator.num_steps for key, integrator in self.integrators.items()}
    self.save_policy = save_policy
    self.step_size_controller = step_size_controller
    self.postprocessing_fun = postprocessing_fun
    self.static_settings = static_settings
    if pre_step_updates is None:
        # Update the current time. Return a new dict instead of mutating the caller's one,
        # so traced values are never written into dictionaries that outlive a trace.
        def pre_step_updates(t, settings):
            return {**settings, 'current time': t}
    self.pre_step_updates = pre_step_updates
    if post_step_updates is None:
        def post_step_updates(q_fun, t, settings):
            return settings
    self.post_step_updates = post_step_updates
    self.verbose = static_settings.get('verbose', 0)
@staticmethod
def _tree_flatten(obj):
    """Flatten for JAX: there are no dynamic children; every attribute is static aux data."""
    static_attribute_names = (
        'integrators', 'num_time_derivs', 'num_steps', 'root_solver', 'save_policy',
        'step_size_controller', 'postprocessing_fun', 'static_settings', 'verbose',
        'pre_step_updates', 'post_step_updates',
    )
    aux_data = tuple(getattr(obj, name) for name in static_attribute_names)
    return (), aux_data
@staticmethod
def _tree_unflatten(aux_data, children):
    """Rebuild a TimeSteppingManager from the static aux data of ``_tree_flatten``."""
    # Bypass __init__: every attribute is restored from aux_data below.
    obj = object.__new__(TimeSteppingManager)
    # There are no dynamic children; unpacking into () raises if any are passed.
    () = children
    static_attribute_names = (
        'integrators', 'num_time_derivs', 'num_steps', 'root_solver', 'save_policy',
        'step_size_controller', 'postprocessing_fun', 'static_settings', 'verbose',
        'pre_step_updates', 'post_step_updates',
    )
    for name, value in zip(static_attribute_names, aux_data, strict=True):
        setattr(obj, name, value)
    return obj
def _initialize(self, dofs):
    """
    Build the initial histories for multi-step methods and remember the DOF template.

    Args:
        dofs: Dictionary of initial DOF arrays per field.

    Returns:
        Tuple ``(q_n, q_der_n)``. ``q_n[key]`` has shape ``(num_steps, *dofs[key].shape)``
        and repeats the initial DOFs. ``q_der_n[key]`` is zero with shape
        ``(num_steps, num_time_derivs, *dofs[key].shape)``.
    """
    q_n = {}
    q_der_n = {}
    for key, value in dofs.items():
        steps = self.num_steps[key]
        # Every history slot starts from the initial DOFs ...
        q_n[key] = jnp.repeat(value[None, ...], steps, axis=0)
        # ... with vanishing time derivatives.
        q_der_n[key] = jnp.zeros((steps, self.num_time_derivs[key], *value.shape))
    # Template for unflattening flat DOF vectors later on.
    self._global_template = dofs
    return q_n, q_der_n
def _assemble_sparse_tangent(self, x_flat, q_n, q_t_n, dt, t, current_stages, settings, static_settings):
    """
    Assemble the global sparse tangent by summing the contributions of all domains.

    One domain is processed per entry of ``static_settings["assembling mode"]``; every
    entry must be ``'user residual'``.

    Returns:
        ``jax.experimental.sparse.BCOO`` of shape ``(num_dofs, num_dofs)``, where ``num_dofs``
        is the total size of ``self._global_template``.
    """
    modes = static_settings["assembling mode"]
    total_size = sum(v.size for v in self._global_template.values())
    accumulated = sparse.empty((total_size, total_size), dtype=float, index_dtype=jnp.int_)
    # 'user residual' is currently the only supported assembling mode.
    assert all(item == 'user residual'
               for item in modes), "Only 'user residual' assembling mode is supported."
    for domain_index in range(len(modes)):
        accumulated += self._assemble_sparse_tangent_domain(x_flat, q_n, q_t_n, dt, t, current_stages, settings,
                                                            static_settings, domain_index)
    return accumulated
def _assemble_sparse_tangent_domain(self, x_flat, q_n, q_t_n, dt, t, current_stages, settings, static_settings, domain):
    """
    Assemble the sparse tangent contribution of one domain (set of integration points).

    For every element, the residual returned by ``model_fun`` is differentiated with
    ``jax.jacfwd`` with respect to the element's local DOFs. Inside the residual, each
    field is represented through its integrator's ``value_and_derivatives`` wrapped in
    ``discrete_value_with_derivatives``.

    Args:
        x_flat: Flat vector of the current DOFs (unflattened via ``self._global_template``).
        q_n: Dictionary of value histories per field (node axis 1 is used for extraction).
        q_t_n: Dictionary of derivative histories per field (node axis 2 is used).
        dt: Time step size.
        t: Not used here.
        current_stages: Not used here.
        settings, static_settings: Passed to the assembler helpers and ``model_fun``.
        domain: Index of the domain.

    Returns:
        ``jax.experimental.sparse.BCOO`` of shape ``(num_dofs, num_dofs)``.
    """
    # TODO: other modes? e.g. for potential-based problems (like user potentials in assembler)
    # TODO: currently only single stages and blocks possible
    # Reconstruct global DOF structure from flat vector
    global_dofs = reshape_as(x_flat, self._global_template)
    # Get elementwise quantities (model_fun, node coordinates, elem_numbers, connectivity)
    model_fun, x_nodes, elem_numbers, connectivity = assembler._get_element_quantities(
        global_dofs, settings, static_settings, domain)

    def extract_dofs(dofs, node_list, axis=0):
        return jax.tree.map(lambda x, y: jnp.take(x, y, axis=axis), dofs, node_list)

    # Element residual as a function of the local DOFs (differentiated below)
    def element_tangent_wrapper(local_dofs, elem_number, node_list):
        # Extract local histories for the current element
        local_q_n = extract_dofs(q_n, node_list, axis=1)
        local_q_t_n = extract_dofs(q_t_n, node_list, axis=2)

        def diffable_q_fun(t):
            local_diffable = {}
            for key in local_dofs:
                tup = self.integrators[key].value_and_derivatives(local_dofs[key], local_q_n[key], local_q_t_n[key], dt)
                local_diffable[key] = discrete_value_with_derivatives(t, tup[0], tup[1:])  # todo: handle multiple blocks
            return local_diffable

        local_node_coor = extract_dofs(x_nodes, node_list)
        return model_fun(diffable_q_fun, local_node_coor, elem_number, settings, static_settings, domain)

    def element_tangent(elem_number, node_list):
        local_dofs = extract_dofs(global_dofs, node_list)
        return jax.jacfwd(lambda x: element_tangent_wrapper(x, elem_number, node_list))(local_dofs)

    body_fun = lambda i: element_tangent(elem_numbers[i], jax.tree.map(lambda x: x[i], connectivity))
    # Number of scalar entries of one element tangent. The lax.map batches are sized to
    # hold roughly 64000 entries; clamp to at least 1 so that elements with more than
    # 64000 tangent entries do not produce batch_size=0, which lax.map cannot handle.
    entries_per_elem = jax.eval_shape(lambda i: dict_flatten(body_fun(i)), 0).shape[0]
    tangent_contributions = jax.lax.map(body_fun,
                                        jnp.arange(elem_numbers.shape[0]),
                                        batch_size=max(1, 64000 // entries_per_elem))
    data = dict_flatten(tangent_contributions)
    indices = assembler._get_indices(connectivity, global_dofs)
    if isinstance(global_dofs, dict):
        num_dofs = sum(v.size for v in global_dofs.values())
    else:
        num_dofs = global_dofs.size
    tangent_matrix = sparse.BCOO((data, indices), shape=(num_dofs, num_dofs))
    return tangent_matrix
def _multi_stage_step(self, q, t, t_n, dt, q_n, q_t_n, settings):
# Assumption: All integrators have the same number of stages
num_stages = next(iter(self.integrators.values())).num_stages
template = jax.lax.stop_gradient(q)
stage_list = next(iter(self.integrators.values())).stage_list
stage_positions = next(iter(self.integrators.values())).stage_positions
num_blocks = stage_list.shape[0]
q_stages = {key: jnp.repeat(q[key][None, ...], num_stages, axis=0) for key in template.keys()}
# Run through all stages
converged = True
num_iterations = 0
state_init = (q_stages, num_iterations, converged, settings)
def block_body(block_number, state):
q_stages, num_iterations, converged, settings = state
current_stages = stage_list[block_number]
num_coupled_stages = current_stages.shape[0]
t_stage = t_n + dt * stage_positions[current_stages]
# Update e.g. boundary conditions and 'current time' for the assembler for PDE problems
settings = self.pre_step_updates(t_stage[0], settings)
# Getting some settings for solver and tangent specific settings
dirichlet_dofs = settings.get('dirichlet dofs', None)
dirichlet_conditions = settings.get('dirichlet conditions', None)
if dirichlet_dofs is not None:
assert dirichlet_conditions is not None, "Constrained values must be provided if constrained dofs are given."
dirichlet_dofs = dict_flatten(dirichlet_dofs)
dirichlet_conditions = dict_flatten(dirichlet_conditions)
free_dofs = ~dirichlet_dofs if dirichlet_dofs is not None else None
solver_backend = self.static_settings.get('solver backend', None)
solver_subtype = self.static_settings.get('solver', None)
verbose = self.static_settings.get('verbose', 0)
impl_diff_mode = self.static_settings.get('implicit diff mode', 'forward')
# Custom tangent via assembling for sparse problems
tangent_fun = None
# Todo: make it diffable via custom root wrapper as in solver.adaptive_load_stepping
if solver_backend is None:
lin_solve_fun = None
elif solver_backend == 'pardiso':
def lin_solve_fun(mat, rhs, free_dofs):
callback_fun = lambda mat, rhs, free_dofs: solver.linear_solve_pardiso(
mat, rhs, solver=solver_subtype, verbose=verbose, free_dofs=free_dofs)
return jax.pure_callback(callback_fun,
jnp.zeros(rhs.shape, rhs.dtype),
mat,
rhs,
free_dofs,
vmap_method='sequential')
elif solver_backend == 'scipy':
def lin_solve_fun(mat, rhs, free_dofs):
callback_fun = lambda mat, rhs, free_dofs: solver.linear_solve_scipy(
mat, rhs, free_dofs=free_dofs, solver=solver_subtype, verbose=verbose)
return jax.pure_callback(callback_fun,
jnp.zeros(rhs.shape, rhs.dtype),
mat,
rhs,
free_dofs,
vmap_method='sequential')
else:
raise ValueError(f"Unknown solver backend: {solver_backend}")
# @partial(jax.jit, inline=True)
@jax.jit
def residual_fun_flat(x, settings, q_stages, q_n, q_t_n, current_stages, t_stage, dt, template, dirichlet_conditions, dirichlet_dofs):
# print("Traced residual")
q_stages_current = {key: q_stages[key] for key in template.keys()}
x = jnp.reshape(x, (num_coupled_stages, -1))
def inner_body(s, q_sc):
q_s = reshape_as(x[s], template)
for key in template.keys():
q_sc[key] = q_sc[key].at[current_stages[s]].set(q_s[key])
return q_sc
if num_coupled_stages == 1:
q_stages_current = inner_body(0, q_stages_current)
else:
q_stages_current = jax.lax.fori_loop(0, num_coupled_stages, inner_body, q_stages_current)
def _residual_fun(s, t):
q_ts = {}
for key in template.keys():
tup = self.integrators[key].value_and_derivatives(q_stages_current[key], q_n[key], q_t_n[key], dt)
q_ts[key] = tuple(array[s] for array in tup)
def diffable_q_fun(t):
return {key: discrete_value_with_derivatives(t, q_ts[key][0], q_ts[key][1:]) for key in template.keys()}
if self.static_settings['dae'] == 'call pde':
return dict_flatten(assembler.assemble_residual(diffable_q_fun, settings, self.static_settings))
else:
return self.static_settings['dae'](diffable_q_fun, t, settings)
if current_stages.shape[0] == 1:
return _residual_fun(current_stages[0], t_stage[0]).flatten()
else:
residual_fun_vmap = jax.vmap(lambda s, t: _residual_fun(s, t), (0, 0))
return residual_fun_vmap(current_stages, t_stage).flatten()
q_flat = dict_flatten(q)
q_flat = jnp.tile(q_flat[None, ...], (num_coupled_stages, 1)).flatten()
# Root solve
if solver_backend is None:
def root_solve(fun, x0):
result = self.root_solver(
fun,
x0,
tangent_fun=tangent_fun,
constrained_dofs=dirichlet_dofs,
constrained_values=dirichlet_conditions,
lin_solve_fun=lin_solve_fun,
verbose=verbose,
)
float_iterations = jnp.astype(result.num_iterations, float)
float_conv = jnp.astype(result.converged, float)
return result.root, (float_iterations, float_conv)
if impl_diff_mode in ('forward', 'reverse', 'backward'):
# Todo: doesn't work for derivatives w.r.t. BCs, use custom_root decorator instead (has to be extended for dense matrices)
assert dirichlet_dofs is None, "Dirichlet BCs are not supported in implicit diff modes 'forward', 'reverse', 'backward'. Use None instead or use sparse solvers, e.g. 'pardiso' or 'scipy'."
q_root, (float_iterations, float_converged) = jax.lax.custom_root(
f=lambda x: residual_fun_flat(x, settings, q_stages, q_n, q_t_n, current_stages, t_stage, dt, template, dirichlet_conditions, dirichlet_dofs),
initial_guess=jax.lax.stop_gradient(q_flat),
solve=root_solve,
tangent_solve=lambda g, y: jnp.linalg.solve(jax.jacfwd(g)(y), y),
has_aux=True)
else:
# Works, but may be not as efficient as with custom_root. Can sometimes produce nans in reverse mode.
q_root, (float_iterations, float_converged) = root_solve(lambda x: residual_fun_flat(x, settings, q_stages, q_n, q_t_n, current_stages, t_stage, dt, template, dirichlet_conditions, dirichlet_dofs), jax.lax.stop_gradient(q_flat))
elif solver_backend in ['pardiso', 'scipy', 'pyamg', 'petsc']:
assert self.static_settings['dae'] == 'call pde', "Only 'call pde' mode is supported for sparse solvers."
@jax.jit
def tangent_fun(x, settings, q_stages, q_n, q_t_n, current_stages, t_stage, dt, template, dirichlet_conditions, dirichlet_dofs):
# print("Traced tangent")
return self._assemble_sparse_tangent(x, q_n, q_t_n, dt, t_stage, current_stages, settings, self.static_settings)
# Implicit diff of newton solver
@implicit_diff.custom_root(
residual_fun=residual_fun_flat,
mat_fun=tangent_fun,
solve=lin_solve_fun,
free_dofs=free_dofs,
has_aux=True,
mode=impl_diff_mode)
def root_solve(x0, settings, q_stages, q_n, q_t_n, current_stages, t_stage, dt, template, dirichlet_conditions, dirichlet_dofs):
result = self.root_solver(
lambda x: residual_fun_flat(x, settings, q_stages, q_n, q_t_n, current_stages, t_stage, dt, template, dirichlet_conditions, dirichlet_dofs),
x0,
tangent_fun=lambda x: tangent_fun(x, settings, q_stages, q_n, q_t_n, current_stages, t_stage, dt, template, dirichlet_conditions, dirichlet_dofs),
constrained_dofs=dirichlet_dofs,
constrained_values=dirichlet_conditions,
lin_solve_fun=lin_solve_fun,
verbose=verbose,
)
float_iterations = jnp.astype(result.num_iterations, float)
float_conv = jnp.astype(result.converged, float)
return result.root, (float_iterations, float_conv)
q_root, (float_iterations, float_converged) = root_solve(jax.lax.stop_gradient(q_flat), settings, q_stages, q_n, q_t_n, current_stages, t_stage, dt, template, dirichlet_conditions, dirichlet_dofs)
else:
raise ValueError(f"Unknown solver backend: {solver_backend}")
num_iterations = jnp.maximum(jnp.astype(float_iterations, jnp.int32), num_iterations)
converged = jnp.logical_and(jnp.astype(float_converged, jnp.bool), converged)
q_root = jnp.reshape(q_root, (num_coupled_stages, -1))
def update_body(s, q_stg):
q_s = reshape_as(q_root[s], q)
for key in template.keys():
q_stg[key] = q_stg[key].at[current_stages[s]].set(q_s[key])
return q_stg
if num_coupled_stages == 1:
q_stages = update_body(0, q_stages)
else:
q_stages = jax.lax.fori_loop(0, num_coupled_stages, update_body, q_stages)
return (q_stages, num_iterations, converged, settings)
if num_blocks == 1:
q_stages, num_iterations, converged, settings = block_body(0, state_init)
else:
q_stages, num_iterations, converged, settings = jax.lax.fori_loop(0, num_blocks, block_body, state_init)
# Call integrators for updating the solution and derivatives
q_n1 = {}
q_t_n1 = {}
for key in template.keys():
q_n1[key], q_t_n1[key] = self.integrators[key].update(q_stages[key], q_n[key], q_t_n[key], dt)
error_estimate = {
key: self.integrators[key]._error_estimate(q_stages[key], q_n[key], q_t_n[key], dt) for key in template.keys()
}
return q_n1, q_t_n1, error_estimate, num_iterations, converged, settings
@jit_with_docstring(static_argnames=['num_time_steps'])
def run(self, dofs, dt0, t_max, num_time_steps, settings=None):
    """
    Executes the time stepping loop for the simulation.

    This method performs the following operations:
    1. Verifies that all time integrators are compatible (i.e., they have the same number of stages,
       identical stage positions, and stage lists).
    2. Initializes simulation variables including the initial DOFs, time (t), step size (dt), and history state.
    3. Iteratively performs time steps using a multi-stage integration method:
       - Updates state with pre-step modifications.
       - Computes the new state and derivative estimates using the multi-stage step procedure.
       - Estimates the error and uses the step size controller to adjust dt.
       - Accepts or rejects the time step based on convergence criteria.
       - Optionally saves the current state using the save policy.
       - Performs post-step updates to settings.
    4. Continues the loop until the simulation time reaches t_max, the step size controller requests
       an interruption, or the maximum number of time steps is reached.
    5. Finalizes and returns the simulation state along with statistics such as the number of accepted
       and rejected steps.

    Args:
        dofs (dict[str, jnp.ndarray]): Initial degrees of freedom for the simulation. The keys must be
            alphabetically sorted.
        dt0 (float): Initial time step size.
        t_max (float): Maximum simulation time. The last step is clamped so that t never exceeds t_max.
        num_time_steps (int): Maximum number of loop iterations, i.e. accepted plus rejected steps.
            Static under jit.
        settings (dict[str, Any] | None): Dictionary containing dynamic simulation settings. If None
            (the default), a fresh ``{'current time': 0.0}`` is created for every call, so no
            settings dict is ever shared between calls.

    Returns:
        TimeSteppingManagerState: An object containing the final state (q), updated settings, simulation
        history, and step statistics (total steps, accepted steps, rejected steps). If the step size
        controller interrupted the simulation, all entries of q are NaN.

    Raises:
        AssertionError: If the integrators are incompatible or the DOF keys are not sorted.
    """
    # Check whether time integrators are compatible. All integrators are compared against the first
    # one; with no integrators at all, every check is vacuously true.
    integrators = list(self.integrators.values())
    reference = next(iter(integrators), None)
    assert all(integrator.num_stages == reference.num_stages for integrator in integrators), \
        "Number of stages must be the same for all fields."
    assert all(np.allclose(integrator.stage_positions, reference.stage_positions) for integrator in integrators), \
        "Stage positions must be the same for all fields."
    assert all(np.allclose(integrator.stage_list, reference.stage_list) for integrator in integrators), \
        "Stage list must be the same for all fields."
    # Alphabetic keywords
    assert list(dofs.keys()) == sorted(dofs.keys()), "The keys of the DOFs must be alphabetically sorted."

    if settings is None:
        # Create a new dict per call. A shared mutable default would let any in-place modification
        # made by user callbacks (e.g. postprocessing_fun below) leak into subsequent calls.
        settings = {'current time': 0.0}

    # Some initializations
    dt = dt0
    t = 0.
    q = dofs
    q_n, q_der_n = self._initialize(q)
    controler_state = self.step_size_controller.initialize(0.0)

    # TODO: get optional start values for derivatives for higher order problems. Currently just set to zero...
    def diffable_q_fun(t):
        return {key: discrete_value_with_derivatives(t, q[key], q_der_n[key][0]) for key in q.keys()}

    user_data = self.postprocessing_fun(diffable_q_fun, t, settings)
    history_state = (self.save_policy.initialize(q, t_max, num_time_steps, user_data)
                     if self.save_policy is not None else None)
    history_state = (self.save_policy.save_step(history_state, t, q, user_data)
                     if self.save_policy is not None else None)

    # Prepare the function for one time step
    def loop_body(step, state):
        def step_fun(state):
            (t, t_n, dt, _, q_n, q_der_n, history_state, controler_state,
             num_accepted, num_rejected, settings, last_printed, t_max) = state
            # Never step beyond t_max: clamp the target time and shrink dt accordingly.
            t = jnp.minimum(t_n + dt, t_max)
            dt = t - t_n
            # Initial guess for the new step: the most recent accepted solution.
            q = {key: q_n[key][0] for key in q_n.keys()}

            # Run step
            settings = self.pre_step_updates(t, settings)
            q, q_der, error_estimate, num_iterations, converged, settings = self._multi_stage_step(
                q, t, t_n, dt, q_n, q_der_n, settings)

            # Call time step controller with the lowest order among all field integrators.
            order = min(integrator.order for integrator in self.integrators.values())
            controler_state = self.step_size_controller.compute_scaler(
                error_estimate, q, q_n, controler_state, order, converged, num_iterations, dt, self.verbose
            )
            controler_state = self.step_size_controller.check_accept(controler_state, converged, self.verbose)
            # The step size control logic is not taken into account for the derivatives.
            # TODO: check whether integrators support changing the step size
            controler_state = jax.lax.stop_gradient(controler_state)
            dt_scaler = controler_state.step_scaler
            interrupt = controler_state.interrupt
            # An interrupted step is never accepted.
            accept = jnp.logical_and(controler_state.accept, jnp.logical_not(interrupt))

            def do_accept(x):
                # Shift the solution/derivative histories by one slot and store the new values in
                # slot 0, run user postprocessing and post-step updates, save, and rescale dt.
                (history_state, q, q_n, q_der_n, t, t_n, dt, num_a, num_r, settings) = x
                for key in q.keys():
                    q_n[key] = jnp.roll(q_n[key], shift=1, axis=0)
                    q_n[key] = q_n[key].at[0].set(q[key])
                    q_der_n[key] = jnp.roll(q_der_n[key], shift=1, axis=0)
                    q_der_n[key] = q_der_n[key].at[0].set(q_der[key])

                def diffable_q_fun(t):
                    return {key: discrete_value_with_derivatives(t, q[key], q_der[key]) for key in q.keys()}

                user_data = self.postprocessing_fun(diffable_q_fun, t, settings)
                settings = self.post_step_updates(diffable_q_fun, t, settings)
                if self.save_policy is not None:
                    history_state = self.save_policy.save_step(history_state, t, q, user_data)
                dt = dt_scaler * dt
                t_n = t
                return history_state, q_n, q_der_n, t, t_n, dt, num_a + 1, num_r, settings

            def do_reject(x):
                # Go back to the last accepted time (t = t_n) and rescale dt by the controller's factor.
                (history_state, _, q_n, q_der_n, t, t_n, dt, num_a, num_r, settings) = x
                t = t - dt
                dt = dt_scaler * dt
                return (history_state, q_n, q_der_n, t, t_n, dt, num_a, num_r + 1, settings)

            (history_state, q_n, q_der_n, t, t_n, dt,
             num_accepted, num_rejected, settings) = jax.lax.cond(
                accept, do_accept, do_reject,
                (history_state, q, q_n, q_der_n, t, t_n, dt, num_accepted, num_rejected, settings))

            if self.verbose >= 0:
                progress = (100 * t / t_max).astype(int)
                # Throttle output: print only after progress advanced by at least 5 percentage points.
                should_print = jnp.greater_equal(progress - last_printed, 5)

                def print_and_update(_):
                    jax.debug.print("Progress: {a}%, Time: {b:.2e}, accepted step: {d}, dt: {c:.2e}, iterations: {e}",
                                    a=progress,
                                    b=t,
                                    c=dt,
                                    d=accept,
                                    e=num_iterations,
                                    ordered=True)
                    return progress

                last_printed = jax.lax.cond(should_print, print_and_update, lambda _: last_printed, operand=None)
            if self.verbose >= 1:
                jax.debug.print(" ", ordered=True)
            return (t, t_n, dt, q, q_n, q_der_n, history_state, controler_state,
                    num_accepted, num_rejected, settings, last_printed, t_max)

        def do_nothing(state):
            return state

        t = state[0]
        t_max = state[-1]
        interrupt = state[7].interrupt
        # Only step while t_max has not been reached and no interruption was requested; the
        # remaining fori_loop iterations become no-ops.
        state = jax.lax.cond(jnp.logical_and(t < t_max, jnp.logical_not(interrupt)), step_fun, do_nothing, state)
        # In case of interruption, set all values in q to nan
        interrupt = state[7].interrupt
        q_or_nan = jax.lax.cond(
            interrupt,
            lambda x: jax.lax.stop_gradient(jax.tree.map(lambda a: jnp.full_like(a, jnp.nan), x)),
            lambda x: x,
            state[3])
        return (*state[:3], q_or_nan, *state[4:])

    # Run time stepping loop
    num_accepted = 0
    num_rejected = 0
    initial_state = (t, 0., dt, q, q_n, q_der_n, history_state, controler_state,
                     num_accepted, num_rejected, settings, -2, t_max)
    if num_time_steps == 1:
        final_state = loop_body(0, initial_state)
    else:
        final_state = jax.lax.fori_loop(0, num_time_steps, loop_body, initial_state)
    t, _, _, q, _, _, history_state, controler_state, num_accepted, num_rejected, settings, _, _ = final_state

    if self.verbose >= 0:
        reached_end = jnp.isclose(t, t_max)
        interrupted = controler_state.interrupt
        # Report why the loop stopped before t_max. An interruption by the step size controller is a
        # different cause than running out of time steps and must not be reported as the latter.
        jax.lax.cond(jnp.logical_and(interrupted, jnp.logical_not(reached_end)),
                     lambda _: jax.debug.print("Time stepping was interrupted before t_max!"),
                     lambda _: None,
                     operand=None)
        jax.lax.cond(jnp.logical_or(reached_end, interrupted),
                     lambda _: None,
                     lambda _: jax.debug.print("Maximum number of steps reached before t_max!"),
                     operand=None)

    if self.save_policy is not None:
        history_state = self.save_policy.finalize(history_state)
    return TimeSteppingManagerState(
        q=q,
        settings=settings,
        history=history_state,
        num_steps=num_accepted + num_rejected,
        num_accepted=num_accepted,
        num_rejected=num_rejected,
    )
# Make TimeSteppingManager a JAX pytree node; its jitted methods receive the
# manager instance (`self`) as an argument, which JAX must be able to flatten.
tree_util.register_pytree_node(
    TimeSteppingManager,
    TimeSteppingManager._tree_flatten,
    TimeSteppingManager._tree_unflatten,
)