# solver.py
# Copyright (C) 2024 Tobias Bode
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
"""
This module is the central module of the analysis phase.
Based on the given entries in settings and static_settings, the functions solver and adaptive_load_stepping can be used to find the roots of the global residual vector.
Depending on the settings, linear equation solvers, the Newton-Raphson method, or nonlinear minimizers are utilized.
The residual vectors and (in the case of external solvers) the tangent matrix are automatically assembled according to the chosen settings.
The solver module uses the assembler, which in turn calls the variational_scheme, the solution_structure, and the spaces modules.
Additionally, automatic implicit differentiation in forward or reverse mode via the implicit_diff module is provided for the adaptive_load_stepping function.
For solving the linear equation systems, wrappers for different backends on CPU and GPU are available, including Pardiso and PETSc.
"""
import time
import sys
import jax
import jaxopt
import jax.numpy as jnp
from jax.experimental import sparse
from jax import lax
import numpy as np
import scipy as scp
from flax.core import FrozenDict
from autopdex import assembler, implicit_diff, utility
### Solvers as specified by the static_settings and settings
[docs]@utility.jit_with_docstring(static_argnames=["static_settings", "**kwargs"])
def solver(dofs, settings, static_settings, **kwargs):
"""
General solver function to solve a given problem based on provided settings.
This function chooses and runs the appropriate solver type (e.g., minimization, linear, Newton) based on the provided
`static_settings` and returns the solution and any additional information.
Args:
dofs (jnp.ndarray or dict): Degrees of freedom or initial guess for the solution.
settings (dict): Dictionary containing various settings and parameters required for assembling the problem.
static_settings (dict): Dictionary containing static settings such as solver type, verbose level, and variational schemes.
**kwargs (dict): Additional keyword arguments passed to the specific solver functions.
Returns:
jnp.ndarray: The solution obtained from the selected solver.
Any: Additional information from the solver, such as number of iterations or convergence status.
Solver Types:
- 'minimize' : Uses nonlinear minimization solvers (e.g., LBFGS, BFGS, etc.).
- 'linear' : Solves linear systems using specified backend (e.g., JAX, PETSc, PARDISO, PyAMG, Scipy).
- 'diagonal linear' : Solves linear systems assuming a diagonal tangent matrix.
- 'newton' : Uses the Newton method for solving nonlinear systems.
- 'damped newton' : Uses a damped Newton method for solving nonlinear systems.
For the different solvers, the function conducts the following functions in which more documentation is provided:
- 'minimize' : solver.solve_nonlinear_minimization
- 'linear' : solver.solve_linear
- 'diagonal linear' : solver.solve_diagonal_linear
- 'newton' : solver.solve_newton
- 'damped newton' : solver.solve_damped_newton
Notes:
- If all domains are using 'least square pde loss' variational scheme and verbosity level is >=1, it prints the L2 error before and after optimization.
"""
# Give global error estimator with unfitted dofs if all domains are 'least square pde loss'
verbose = static_settings["verbose"]
try:
if (
all(
[
scheme == "least square pde loss"
for scheme in static_settings["variational scheme"]
]
)
and verbose >= 1
):
jax.debug.print(
"L2 error unoptimized: {x}",
x=assembler.integrate_functional(dofs, settings, static_settings),
)
except KeyError:
pass
# Choose type of solver and call it
solver_type = static_settings["solver type"]
match solver_type:
case "minimize":
sol = solve_nonlinear_minimization(
dofs, settings, static_settings, **kwargs
)
infos = None
case "linear":
sol = solve_linear(dofs, settings, static_settings, **kwargs)
infos = None
case "diagonal linear":
sol = solve_diagonal_linear(dofs, settings, static_settings, **kwargs)
infos = None
# case 'diagonal newton':
# # ToDo
# sol, infos = solve_diagonal_newton(dofs, settings, static_settings, **kwargs)
case "newton":
sol, infos = solve_newton(dofs, settings, static_settings, **kwargs)
case "damped newton":
sol, infos = solve_damped_newton(dofs, settings, static_settings, **kwargs)
case _:
assert False, "Solver type not implemented!"
# Give global error estimator with fitted dofs if all domains are 'least square pde loss'
try:
if (
all(
[
scheme == "least square pde loss"
for scheme in static_settings["variational scheme"]
]
)
and verbose >= 1
):
jax.debug.print(
"L2 error optimized: {x}",
x=assembler.integrate_functional(sol, settings, static_settings),
)
except KeyError:
pass
return sol, infos
[docs]@utility.jit_with_docstring(
static_argnames=[
"static_settings",
"multiplier_settings",
"path_dependent",
"implicit_diff_mode",
"max_load_steps",
"max_multiplier",
"min_increment",
"max_increment",
"init_increment",
"target_num_newton_iter",
"newton_tol",
"**kwargs",
]
)
def adaptive_load_stepping(
dofs,
settings,
static_settings,
multiplier_settings=lambda settings, multiplier: (
settings.update({"load multiplier": multiplier}),
settings,
)[1],
path_dependent=True,
implicit_diff_mode=None,
max_multiplier=1.0,
min_increment=0.01,
max_increment=1.0,
init_increment=0.2,
max_load_steps=1000,
target_num_newton_iter=7,
newton_tol=1e-10,
**kwargs,
):
"""
Performs adaptive load stepping to solve a nonlinear system of equations.
This function iteratively adjusts the load increment to ensure convergence
using a Newton-Raphson solver. The increment size is adaptively controlled
based on the convergence behavior of the solver. Works currently only with
solver types 'newton' and 'damped newton'.
Args:
dofs (jnp.ndarray or dict): Initial degrees of freedom.
settings (dict): Dictionary of problem settings.
static_settings (dict): Dictionary of static settings that do not change during load steps.
multiplier_settings (callable): Function to update settings based on the current load multiplier.
path_dependent (bool): Specifies wether problem is path-dependent (experimental) or not (has an influence on the implicit differentiation).
implicit_diff_mode (string): Can be either \'reverse\', \'forward\' or None. In case of \'reverse\', only reverse mode differentiation is supported (jacrev), in case of \'forward\', only forward mode differentiation is supported (jacfwd).
max_multiplier (float): Maximum value for the load multiplier.
min_increment (float): Minimum allowable increment size.
max_increment (float): Maximum allowable increment size.
init_increment (float): Initial increment size.
max_load_steps (int): Maximal number of load steps. Only used in case implicit_diff_mod is not None
target_num_newton_iter (int): Target number of Newton iterations for each load step.
newton_tol (float, optional): Tolerance for Newton solver convergence. Default is 1e-10.
**kwargs: Additional keyword arguments for the solver.
Returns:
- jnp.ndarray: Solution degrees of freedom after load stepping.
"""
verbose = static_settings["verbose"]
if implicit_diff_mode is not None:
# Set-up the decorator for implicit differentiation
residual_fun = lambda dofs, settings: assembler.assemble_residual(dofs, settings, static_settings)
tangent_fun = lambda dofs, settings: assembler.assemble_tangent(dofs, settings, static_settings)
try:
dirichlet_dofs = settings["dirichlet dofs"]
except KeyError:
# Warning if it was defined in static_settings
assert "dirichlet dofs" not in static_settings, \
"'dirichlet dofs' has been moved to 'settings' in order to reduce compile time. \
Further, you should not transform it to a tuple of tuples anymore."
pass
free_dofs = None
free_dofs_flat = None
if dirichlet_dofs is not None:
free_dofs = utility.mask_op(dirichlet_dofs, utility.dict_ones_like(dirichlet_dofs), mode="apply", ufunc=lambda x: ~x)
free_dofs_flat = utility.dict_flatten(free_dofs)
is_constrained = True if free_dofs is not None else False
solver_backend = static_settings["solver backend"]
solver_subtype = static_settings["solver"]
try:
sensitivity_solver_backend = static_settings["sensitivity solver backend"]
sensitivity_solver_subtype = static_settings["sensitivity solver"]
except KeyError:
sensitivity_solver_backend = solver_backend
sensitivity_solver_subtype = solver_subtype
match sensitivity_solver_backend:
case "petsc":
n_fields = static_settings["number of fields"]
pc_type = static_settings["type of preconditioner"]
lin_solve_fun = lambda mat, rhs, free_dofs_flat: linear_solve_petsc(
mat,
rhs,
n_fields,
sensitivity_solver_subtype,
pc_type,
verbose,
free_dofs_flat,
**kwargs,
)
case "pardiso":
lin_solve_fun = lambda mat, rhs, free_dofs_flat: linear_solve_pardiso(
mat,
rhs,
sensitivity_solver_subtype,
verbose,
free_dofs_flat,
**kwargs,
)
case "pyamg":
pc_type = static_settings["type of preconditioner"]
lin_solve_fun = lambda mat, rhs, free_dofs_flat: linear_solve_pyamg(
mat,
rhs,
sensitivity_solver_subtype,
pc_type,
verbose,
free_dofs_flat,
**kwargs,
)
case "scipy":
lin_solve_fun = lambda mat, rhs, free_dofs_flat: linear_solve_scipy(
mat,
rhs,
sensitivity_solver_subtype,
verbose,
free_dofs_flat,
**kwargs,
)
case _:
raise ValueError(
"Specified sensitivity solver backend not available. Choose 'pardiso', 'petsc', 'pyamg' or 'scipy'."
)
def lin_solve_callback_fun(mat, rhs, free_dofs_flat):
rhs_flat = utility.dict_flatten(rhs)
sol = jax.pure_callback(lin_solve_fun, jnp.zeros(rhs_flat.shape, rhs_flat.dtype), mat, rhs_flat, free_dofs_flat, vmap_method='sequential')
return utility.reshape_as(sol, rhs)
# def lin_solve_callback_fun(mat, rhs, free_dofs_flat):
# rhs_flat = utility.dict_flatten(rhs)
# # linear_solver = lambda mat, rhs: jax.pure_callback(lin_solve_fun, jnp.zeros(rhs_flat.shape, rhs_flat.dtype), mat, rhs_flat, None, vmap_method='sequential')
# linear_solver = lambda mat, rhs: jax.pure_callback(lin_solve_fun, jnp.zeros(rhs_flat.shape, rhs_flat.dtype), mat, rhs_flat, free_dofs_flat, vmap_method='sequential')
# solve = lambda matvec, x: linear_solver(mat, x)
# transpose_solve = lambda vecmat, x: linear_solver(mat.T, x)
# matvec = lambda x: utility.reshape_as(mat @ utility.dict_flatten(x), x)
# sol = jax.lax.custom_linear_solve(matvec, rhs_flat, solve, transpose_solve)
# return utility.reshape_as(sol, rhs)
# Set up functions for adaptive load stepping loop
def continue_check(carry):
_, multiplier, increment, _, _, _ = carry
_continue_1 = jnp.logical_and(
multiplier < max_multiplier, increment > min_increment
)
jax.lax.cond(
jnp.logical_and(
multiplier < max_multiplier - min_increment, increment < min_increment
),
lambda: jax.debug.print(
"Adaptive load stepping could not converge; increment size below min_increment!"
),
lambda: None,
)
return _continue_1
def step(carry):
dofs0, multiplier, increment, load_step, settings, _ = carry
# Update multiplier
multiplier += increment
# Update boundary conditions
if verbose > -1:
if verbose > 0:
jax.debug.print("")
jax.debug.print("Multiplier: {x}", x=multiplier)
settings = multiplier_settings(settings, multiplier)
# Call newton solver
if (
path_dependent and implicit_diff_mode is not None
): # Add implicit differentiation for each load step
@implicit_diff.custom_root(
residual_fun,
tangent_fun,
lin_solve_callback_fun,
is_constrained,
True,
implicit_diff_mode,
)
def diffable_solve(dofs0, settings):
dofs, (needed_steps, res_norm_free_dofs, divergence) = solver(
dofs0, settings, static_settings, newton_tol=newton_tol, **kwargs
)
return dofs, (needed_steps.astype(float), res_norm_free_dofs, divergence.astype(float))
dofs, infos = diffable_solve(dofs0, settings)
needed_steps, res_norm_free_dofs, divergence = infos
divergence = divergence.astype(bool)
else: # Add implicit diff wrapper on the adaptive load stepping level
dofs, infos = solver(
dofs0, settings, static_settings, newton_tol=newton_tol, **kwargs
)
needed_steps, res_norm_free_dofs, divergence = infos
# Adaptive incrementation
multiplier = jnp.where(divergence, multiplier - increment, multiplier)
increment = jnp.where(
divergence,
0.5 * increment,
(1 + 0.5 * (target_num_newton_iter - needed_steps) / target_num_newton_iter)
* increment,
)
increment = jnp.where(increment > max_increment, max_increment, increment)
if isinstance(dofs, dict):
dofs = {
key: jnp.where(divergence, dofs0[key], dofs[key])
for key in dofs0.keys()
}
else:
dofs = jnp.where(divergence, dofs0, dofs)
# Limit to max_multiplier
increment = jnp.where(
multiplier + increment > max_multiplier,
max_multiplier - multiplier,
increment,
)
return (dofs, multiplier, increment, 1.0 * load_step, settings, res_norm_free_dofs)
# Use implicit diff wrappers to make it differentiable
if implicit_diff_mode is not None:
if not path_dependent: # Conservative problem
# Set Dirichlet conditions for derivative w.r.t. them
settings = multiplier_settings(settings, max_multiplier)
@implicit_diff.custom_root(
residual_fun,
tangent_fun,
lin_solve_callback_fun,
is_constrained,
True,
implicit_diff_mode,
)
def diffable_adaptive_load_stepping(dofs, settings):
(dofs, multiplier, increment, load_step, settings, res_norm_free_dofs) = \
jax.lax.while_loop(
continue_check, step, (dofs, 0.0, init_increment, 0, settings, 0.0)
)
return dofs, (multiplier, increment, load_step, res_norm_free_dofs)
dofs, (multiplier, increment, load_step, res_norm_free_dofs) = \
diffable_adaptive_load_stepping(dofs, settings)
return dofs, (multiplier, increment, load_step, settings, res_norm_free_dofs)
else: # Pathdependent problem; uses fori_loop with static limits for supporting reverse mode differentiation
# Set Dirichlet conditions for derivative w.r.t. them
settings = multiplier_settings(settings, max_multiplier)
def body_fn(i, carry):
def continue_check(carry):
_, multiplier, increment, _, _, divergence, stop = carry
_continue_1 = jnp.logical_and(
multiplier < max_multiplier, increment > min_increment
)
jax.lax.cond(
jnp.logical_and(
jnp.logical_and(
multiplier < max_multiplier, increment < min_increment
),
jnp.logical_not(stop),
),
lambda: jax.debug.print(
"Adaptive load stepping could not converge; increment size below min_increment!"
),
lambda: None,
)
return _continue_1
def step_extended(carry):
(dofs, multiplier, increment, load_step, settings, res_norm_free_dofs, stop) = carry
args = (dofs, multiplier, increment, load_step, settings, res_norm_free_dofs)
return step(args) + (False,)
def finish(carry):
(dofs, multiplier, increment, load_step, settings, res_norm_free_dofs, stop) = carry
return (dofs, multiplier, increment, load_step, settings, res_norm_free_dofs, True)
carry = jax.lax.cond(
continue_check(carry),
lambda x: step_extended(x),
lambda x: finish(x),
carry,
)
# ToDo: Verify accuracy of derivatives with finite differences.
return carry
init_state = (dofs, 0.0, init_increment, 0., settings, 0.0, False)
return jax.lax.fori_loop(0, max_load_steps, body_fn, init_state)
else: # No definition of implicit derivatives
return jax.lax.while_loop(
continue_check, step, (dofs, 0.0, init_increment, 0., settings, 0.0)
)
### Minimizers
[docs]@utility.jit_with_docstring(static_argnames=["static_settings", "**kwargs"])
def solve_nonlinear_minimization(dofs, settings, static_settings, **kwargs):
"""
Solves a nonlinear minimization problem using specified optimization methods.
This function wraps nonlinear minimization solvers provided by `jaxopt` to minimize a functional
and solve the given problem.
Args:
dofs (jnp.ndarray or dict): Degrees of freedom or initial guess for the solution.
settings (dict): Dictionary containing various settings and parameters required for assembling the problem.
static_settings (dict): Dictionary containing static settings such as solver type, verbose level, and variational schemes.
**kwargs (dict): Additional keyword arguments passed to the specific solver functions.
Returns:
jnp.ndarray: The optimized solution obtained from the selected solver.
Solver Types
- 'gradient descent' : Uses gradient descent for optimization.
- 'lbfgs' : Uses Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) algorithm for optimization.
- 'bfgs' : Uses Broyden-Fletcher-Goldfarb-Shanno (BFGS) algorithm for optimization.
- 'nonlinear cg' : Uses nonlinear conjugate gradient method for optimization.
- 'gauss newton' : Uses Gauss-Newton method for optimization.
- 'levenberg marquart' : Uses Levenberg-Marquardt algorithm for optimization.
- Default : If solver name is not set or not available, uses 'lbfgs' as the default solver.
Notes:
- This function should just be called, if the variational scheme involves the definition of a functional that is to be minimized,
e.g. 'least square pde loss'. The modes 'gauss newton' and 'levenberg marquart' are an exeption, since they utilize the residual.
- The function conducts the assembler.integrate_functional and assembler.assemble_residual functions in order
to set up suitable optimization functions or residual functions, depending on what the solver needs.
- The current implementation does not support nodal imposition of DOFs.
"""
nodal_imposition = "nodal imposition" in static_settings["solution structure"]
assert (
not nodal_imposition
), "solver type 'minimize' does currently not support nodal imposition of DOFs."
# ToDo: impose boundary conditions and freeze dirichlet dofs
def functional(params):
return assembler.integrate_functional(params, settings, static_settings)
def residual_function(params):
return assembler.assemble_residual(params, settings, static_settings)
# Select minimizer
solver_name = static_settings["solver"]
match solver_name:
case "gradient descent":
solver = jaxopt.GradientDescent(functional, **kwargs)
sol = solver.run(dofs).params
case "lbfgs":
solver = jaxopt.LBFGS(functional, **kwargs)
sol = solver.run(dofs).params
case "bfgs":
solver = jaxopt.BFGS(functional, **kwargs)
sol = solver.run(dofs).params
case "nonlinear cg":
solver = jaxopt.NonlinearCG(functional, **kwargs)
sol = solver.run(dofs).params
case "gauss newton":
solver = jaxopt.GaussNewton(residual_function, **kwargs)
sol = solver.run(dofs).params
case "levenberg marquart":
assert not isinstance(
dofs, dict
), "solver 'levenberg marquart' does currently not support dicts as dofs"
def residual_function_flat(params):
return assembler.assemble_residual(
jnp.reshape(params, dofs.shape), settings, static_settings
).flatten()
solver = jaxopt.LevenbergMarquardt(residual_function_flat, **kwargs)
sol = jnp.reshape(solver.run(dofs.flatten()).params, dofs.shape)
case _:
solver = jaxopt.LBFGS(functional, **kwargs)
sol = solver.run(dofs).params
print(
"Solver name not set or not available in combination with this solver type. Using static_settings['solver name'] = 'lbfgs' as default."
)
return sol
### Root finders
[docs]@utility.jit_with_docstring(static_argnames=["static_settings", "**kwargs"])
def solve_linear(dofs, settings, static_settings, **kwargs):
"""
Solves a linear system using the specified backend and solver settings.
This function determines the appropriate linear solver based on the provided settings and forwards the
call to the selected solver function. It supports both JAX matrix-free solvers and external solvers
like PETSc, PARDISO, PyAMG, and Scipy using jax.pure_callback.
Args:
dofs (jnp.ndarray or dict): Degrees of freedom or initial guess for the solution.
settings (dict): Dictionary containing various settings and parameters required for assembling the problem.
static_settings (dict): Dictionary containing static settings such as solver backend, type of solver, and preconditioner.
**kwargs (dict): Additional keyword arguments passed to the specific solver functions.
Returns:
jnp.ndarray: The solution obtained from the selected linear solver.
Solver Backends:
- 'jax' : Uses JAX's matrix-free solvers.
- 'petsc' : Uses PETSc for solving linear systems.
- 'pardiso' : Uses PARDISO for solving linear systems.
- 'pyamg' : Uses PyAMG for solving linear systems.
- 'scipy' : Uses Scipy's sparse solvers for solving linear systems.
Notes:
- If `nodal imposition` is detected in the `static_settings`, the function imposes Dirichlet boundary conditions
and adjusts the degrees of freedom accordingly.
- The function assembles the residual and tangent matrix before solving the system in case an external solver is used.
- If the tangent matrix is dense and an external solver is used, it is converted to a sparse format.
- The function uses JAX's `pure_callback` to call external solvers and handle the solution.
"""
solver_backend = static_settings["solver backend"]
### Jax matrix-free solver
if solver_backend == "jax":
return linear_solve_jax(dofs, settings, static_settings, **kwargs)
### External linear solver
nodal_imposition = "nodal imposition" in static_settings["solution structure"]
# Impose nodal dofs
if nodal_imposition:
dirichlet_conditions = settings["dirichlet conditions"]
if isinstance(settings["dirichlet dofs"], (dict, FrozenDict)):
dirichlet_dofs_dict_flat = {
key: jnp.asarray(val).flatten()
for key, val in settings["dirichlet dofs"].items()
}
dirichlet_dofs_flat = jnp.concatenate(
list(dirichlet_dofs_dict_flat.values())
)
dofs = utility.mask_op(dofs, dirichlet_dofs_dict_flat, dirichlet_conditions)
else:
dirichlet_dofs_flat = jnp.asarray(
settings["dirichlet dofs"]
).flatten()
dofs = utility.mask_op(dofs, dirichlet_dofs_flat, dirichlet_conditions)
free_dofs_flat = jnp.invert(dirichlet_dofs_flat)
# Assembling
verbose = static_settings["verbose"]
solver = static_settings["solver"]
rhs = assembler.assemble_residual(dofs, settings, static_settings)
rhs = -utility.dict_flatten(rhs)
mat = assembler.assemble_tangent(dofs, settings, static_settings)
# If dense matrix and external solver, convert to sparse
if (
solver_backend in ("petsc", "pardiso", "pyamg", "scipy")
and type(mat) == jnp.ndarray
):
mat = sparse.bcoo_fromdense()
match solver_backend:
case "petsc":
n_fields = static_settings["number of fields"]
pc_type = static_settings["type of preconditioner"]
solve_fun = lambda a, b, c: linear_solve_petsc(
a, b, n_fields, solver, pc_type, verbose, free_dofs=c, **kwargs
)
case "pardiso":
solve_fun = lambda a, b, c: linear_solve_pardiso(
a, b, solver, verbose, free_dofs=c
)
case "pyamg":
pc_type = static_settings["type of preconditioner"]
solve_fun = lambda a, b, c: linear_solve_pyamg(
a, b, solver, pc_type, verbose, free_dofs=c, **kwargs
)
case "scipy":
solve_fun = lambda a, b, c: linear_solve_scipy(
a, b, solver, verbose, free_dofs=c
)
# Prepare callback
result_shape_dtype = jax.ShapeDtypeStruct(shape=rhs.shape, dtype=rhs.dtype)
# Compose solution dofs
if nodal_imposition:
sol = jax.pure_callback(solve_fun, result_shape_dtype, mat, rhs, free_dofs_flat, vmap_method='sequential')
if isinstance(settings["dirichlet dofs"], (dict, FrozenDict)):
free_dofs_dict = {
key: jnp.invert(jnp.asarray(val))
for key, val in settings["dirichlet dofs"].items()
}
sol = utility.reshape_as(sol, dofs)
return utility.mask_op(dofs, free_dofs_dict, sol)
else:
return utility.mask_op(dofs, free_dofs_flat, sol)
else:
sol = jax.pure_callback(solve_fun, result_shape_dtype, mat, rhs, None, vmap_method='sequential')
return utility.reshape_as(sol, dofs)
@utility.jit_with_docstring(static_argnames=["static_settings", "**kwargs"])
def solve_diagonal_linear(dofs, settings, static_settings, **kwargs):
    """
    Solves a linear system assuming the tangent matrix is diagonal.

    This function solves the linear system by leveraging the assumption that the tangent matrix is diagonal,
    which simplifies the solution process. It supports nodal imposition of Dirichlet boundary conditions and
    handles the assembly of the residual and diagonal tangent matrix.

    If the tangent matrix is not diagonal, it will produce a wrong diagonal of the tangent!

    Args:
      dofs (jnp.ndarray or dict): Degrees of freedom or initial guess for the solution.
      settings (dict): Dictionary containing various settings and parameters required for assembling the problem.
      static_settings (dict): Dictionary containing static settings such as solution structure and solver backend.
      **kwargs (dict): Additional keyword arguments; currently not used by this function.

    Returns:
      sol (jnp.ndarray): The solution obtained by solving the linear system assuming a diagonal tangent matrix.

    Notes:
      - If `nodal imposition` is detected in the `static_settings`, the function imposes Dirichlet boundary conditions
        and adjusts the degrees of freedom accordingly.
      - The function assembles the residual and diagonal tangent matrix before solving the system.
      - The solution process involves element-wise division of the residual by the diagonal elements of the tangent matrix.
    """
    nodal_imposition = "nodal imposition" in static_settings["solution structure"]

    # Impose nodal dofs
    if nodal_imposition:
        dirichlet_conditions = utility.dict_flatten(settings["dirichlet conditions"])
        if isinstance(settings["dirichlet dofs"], (dict, FrozenDict)):
            dirichlet_dofs_dict_flat = {
                key: jnp.asarray(val).flatten()
                for key, val in settings["dirichlet dofs"].items()
            }
            dirichlet_dofs_flat = jnp.concatenate(
                list(dirichlet_dofs_dict_flat.values())
            )
            dofs = utility.mask_op(
                dofs, dirichlet_dofs_dict_flat, dirichlet_conditions
            )
        else:
            dirichlet_dofs_flat = jnp.asarray(
                settings["dirichlet dofs"]
            ).flatten()
            dofs = utility.mask_op(dofs, dirichlet_dofs_flat, dirichlet_conditions)
        free_dofs_flat = jnp.invert(dirichlet_dofs_flat)

    # Assembling; flatten before negating so that dict-valued residuals are handled as well
    rhs = -utility.dict_flatten(
        assembler.assemble_residual(dofs, settings, static_settings)
    )
    # The stored values of the assembled tangent are taken as its diagonal
    diag_mat = assembler.assemble_tangent(dofs, settings, static_settings).data

    if nodal_imposition:
        # Delete rows
        # NOTE(review): boolean-mask indexing has a data-dependent result shape, so the mask
        # must be concrete (not traced) here - confirm how 'dirichlet dofs' is passed under jit.
        rhs = rhs[free_dofs_flat]
        diag_mat = diag_mat[free_dofs_flat]

    # Solve while assuming a diagonal tangent
    sol = jnp.multiply((1 / diag_mat), rhs)

    # Compose solution dofs
    if nodal_imposition:
        if isinstance(settings["dirichlet dofs"], (dict, FrozenDict)):
            free_dofs_dict = {
                key: jnp.invert(jnp.asarray(val))
                for key, val in settings["dirichlet dofs"].items()
            }
            sol = utility.reshape_as(sol, dofs)
            return utility.mask_op(dofs, free_dofs_dict, sol)
        else:
            return utility.mask_op(dofs, free_dofs_flat, sol)
    else:
        return utility.reshape_as(sol, dofs)
@utility.jit_with_docstring(static_argnames=["static_settings", "**kwargs"])
def solve_newton(
    dofs, settings, static_settings, newton_tol=1e-8, maxiter=30, **kwargs
):
    """
    Solves a nonlinear system using the Newton-Raphson method.

    This function is a wrapper for the damped Newton method with a damping coefficient of 1.0,
    effectively performing standard Newton-Raphson iterations.

    Args:
      dofs (jnp.ndarray or dict): Degrees of freedom or initial guess for the solution.
      settings (dict): Dictionary containing various settings and parameters required for assembling the problem.
      static_settings (dict): Dictionary containing static settings such as solution structure and solver backend.
      newton_tol (float, optional): Tolerance for the Newton method convergence criterion. Default is 1e-8.
      maxiter (int, optional): Maximum number of iterations for the Newton method. Default is 30.
      **kwargs (dict): Additional keyword arguments forwarded to solve_damped_newton (and from there
        to the linear solver). A 'damping_coefficient' entry is ignored.

    Returns:
      tuple: A tuple containing the following elements:
        - sol (jnp.ndarray): The solution obtained by solving the nonlinear system using the Newton method.
        - infos (tuple): Additional information about the solution process, including:
          - num_iterations (int): The number of iterations performed.
          - residual_norm (float): The norm of the residual at the solution.
          - diverged (bool): Flag indicating whether the method diverged.
    """
    # Plain Newton-Raphson always takes the full step. A user-supplied damping coefficient was
    # silently ignored before; drop it so that it cannot clash with the positional 1.0 below.
    kwargs.pop("damping_coefficient", None)

    # Forward the remaining kwargs (previously dropped), consistent with solve_damped_newton
    return solve_damped_newton(
        dofs, settings, static_settings, newton_tol, 1.0, maxiter, **kwargs
    )
@utility.jit_with_docstring(static_argnames=["static_settings", "**kwargs"])
def solve_damped_newton(
    dofs,
    settings,
    static_settings,
    newton_tol=1e-8,
    damping_coefficient=0.8,
    maxiter=30,
    **kwargs,
):
    """
    Solves a nonlinear system using the damped Newton method.

    This function performs damped Newton iterations to solve a nonlinear system,
    with support for nodal imposition of Dirichlet boundary conditions specified as a boolean mask
    (array or dict of arrays) in settings['dirichlet dofs'].
    In this function the information needed for solver.damped_newton is prepared and the function is then called.

    Args:
      dofs (jnp.ndarray or dict): Degrees of freedom or initial guess for the solution.
      settings (dict): Dictionary containing various settings and parameters required for assembling the problem.
      static_settings (dict): Dictionary containing static settings such as solution structure and solver backend.
      newton_tol (float, optional): Tolerance for the Newton method convergence criterion. Default is 1e-8.
      damping_coefficient (float, optional): Damping coefficient for the Newton updates. Default is 0.8.
      maxiter (int, optional): Maximum number of iterations for the Newton method. Default is 30.
      **kwargs (dict): Additional keyword arguments passed to solve_linear for each Newton step.

    Returns:
      tuple: A tuple containing the following elements:
        - sol (jnp.ndarray): The solution obtained by solving the nonlinear system using the damped Newton method.
        - infos (tuple): Additional information about the solution process, including:
          - num_iterations (int): The number of iterations performed.
          - residual_norm (float): The norm of the residual at the solution.
          - diverged (bool): Flag indicating whether the method diverged.
    """

    def residual_fun(x_i):
        # Residual at the current iterate, used as convergence measure
        return assembler.assemble_residual(x_i, settings, static_settings)

    def lin_solve_fun(x_i):
        # Linearized solve that provides the Newton update at the current iterate
        return solve_linear(x_i, settings, static_settings, **kwargs)

    nodal_imposition = "nodal imposition" in static_settings["solution structure"]
    if nodal_imposition:
        # Free dofs are the complement of the Dirichlet mask given in settings['dirichlet dofs']
        dirichlet_dofs = settings["dirichlet dofs"]
        if isinstance(dirichlet_dofs, (dict, FrozenDict)):
            free_dofs_flat = {
                key: jnp.invert(jnp.asarray(val).flatten())
                for key, val in dirichlet_dofs.items()
            }
        else:
            free_dofs_flat = jnp.invert(jnp.asarray(dirichlet_dofs).flatten())
    else:
        # Without nodal imposition all dofs are free
        free_dofs_flat = jnp.ones(utility.dict_flatten(dofs).shape, dtype=jnp.bool_)

    verbose = static_settings["verbose"]

    return damped_newton(
        dofs,
        residual_fun,
        lin_solve_fun,
        free_dofs_flat,
        newton_tol,
        maxiter,
        damping_coefficient,
        verbose=verbose,
    )
def damped_newton(
    dofs_0,
    residual_fun,
    lin_solve_fun,
    free_dofs,
    newton_tol,
    maxiter,
    damping_coefficient,
    verbose=1,
):
    """
    Solve a nonlinear system with damped Newton iterations.

    The iteration runs inside a ``lax.while_loop``. Each pass asks
    ``lin_solve_fun`` for an update, scales it with ``damping_coefficient``,
    applies it to the current iterate and measures the norm of the residual
    (restricted to the free dofs when a mask is given).

    Args:
      dofs_0 (jnp.ndarray or dict): Initial guess for the degrees of freedom.
      residual_fun (callable): Maps the dofs to the residual of the system.
      lin_solve_fun (callable): Maps the dofs to the Newton update.
      free_dofs (jnp.ndarray, dict or None): Boolean mask(s) of the free degrees of freedom.
        If None, the update is applied to all degrees of freedom.
      newton_tol (float): The loop continues while the residual norm exceeds this value.
      maxiter (int): Iteration count at which the loop is aborted with a warning.
      damping_coefficient (float): Factor applied to every Newton update.
      verbose (int): If > 0, the residual norm is printed after every iteration;
        if > 1, an additional empty line is printed.

    Returns:
      tuple: A tuple containing the following elements:
        - sol (jnp.ndarray or dict): The last iterate.
        - infos (tuple): Additional information about the solution process, including:
          - num_iterations (int): The number of iterations performed.
          - residual_norm (float): The residual norm of the last iterate.
          - diverged (bool): True if the residual norm grew by more than a factor
            of 10 (after the second iteration), if the residual contained nan/inf,
            or if the iteration limit was reached.
    """
    use_mask = free_dofs is not None

    # Complement of the free-dof mask; it does not change between iterations.
    if not use_mask:
        constrained_dofs = None
    elif isinstance(free_dofs, dict):
        constrained_dofs = {name: jnp.invert(mask) for name, mask in free_dofs.items()}
    else:
        constrained_dofs = jnp.invert(free_dofs)

    def newton_update(state):
        current, iteration, _, previous_norm, _ = state

        # Newton update from the linear(ised) problem
        increment = lin_solve_fun(current)

        if use_mask:
            # Damp the update on the free dofs ...
            increment = utility.mask_op(
                increment,
                free_dofs,
                mode="apply",
                ufunc=lambda x: damping_coefficient * x,
            )
            # ... add it there ...
            current = utility.mask_op(current, free_dofs, increment, "add")
            # ... and set the remaining (Dirichlet) dofs from the update
            current = utility.mask_op(current, constrained_dofs, increment, "set")
        else:
            # No mask: every dof receives the damped update
            current = current + damping_coefficient * increment

        # Residual of the new iterate serves as convergence measure
        new_residual = residual_fun(current)
        if use_mask:
            flat_residual = utility.dict_flatten(
                utility.mask_select(new_residual, free_dofs)
            )
        else:
            flat_residual = utility.dict_flatten(new_residual)
        current_norm = jnp.linalg.norm(flat_residual)
        above_tolerance = current_norm > newton_tol

        def evaluate_iteration():
            if verbose > 0:
                jax.debug.print(
                    "Residual after Newton iteration {x}: {y}",
                    x=iteration + 1,
                    y=current_norm,
                )
            if verbose > 1:
                jax.debug.print("")

            # Divergence: strong growth of the residual norm or nan/inf entries
            grew_strongly = jnp.logical_and(
                current_norm / previous_norm > 10, iteration > 1
            )
            has_non_finite = jnp.any(jnp.logical_not(jnp.isfinite(flat_residual)))
            diverged = jnp.logical_or(has_non_finite, grew_strongly)
            return jnp.logical_not(diverged), diverged

        def abort_iteration():
            jax.debug.print("Warning: Newton scheme could not converge!")
            return False, True

        may_continue, diverged = jax.lax.cond(
            iteration < maxiter, evaluate_iteration, abort_iteration
        )

        keep_going = jnp.logical_and(above_tolerance, may_continue)
        return (current, iteration + 1, keep_going, current_norm, diverged)

    def keep_iterating(state):
        return state[2]

    # Carry layout: (dofs, iteration counter, continue flag, residual norm, divergence flag)
    initial_state = (dofs_0, 0, True, 0.0, False)
    sol, num_iterations, _, res_norm, divergence = lax.while_loop(
        keep_iterating, newton_update, initial_state
    )
    return (sol, (num_iterations, res_norm, divergence))
### Linear solvers for different backends
[docs]@utility.jit_with_docstring(static_argnames=["static_settings", "**kwargs"])
def linear_solve_jax(dofs, settings, static_settings, **kwargs):
"""
Solves a linear system of equations using JAX's matrix free solvers.
This function performs a linear solve using different itterative solvers,
optionally imposing Dirichlet boundary conditions and preconditioning and
using different matrix or Hessian vector product (HVP) methods.
Args:
dofs (jnp.ndarray): Initial degrees of freedom.
settings (dict): Dictionary of problem settings.
static_settings (dict): Dictionary of static settings that do not change during iterations.
**kwargs: Additional keyword arguments for the solver.
Returns:
jnp.ndarray: Solution degrees of freedom after solving the linear system.
Hessian Vector Product (HVP) Types:
- 'fwdrev': Forward and reverse mode differentiation.
- 'revrev': Reverse mode differentiation (only for symmetric matrices).
- 'assemble': Assembles the tangent matrix explicitly (not supported with nodal imposition, then call e.g. PetSc).
- 'linearize': Uses JAX's linearize function.
- Default: Uses 'fwdrev' if not specified or if an unsupported type is provided.
Solver Types:
- 'cg': Conjugate Gradient.
- 'normal cg': Normal Conjugate Gradient.
- 'gmres': Generalized Minimal Residual Method.
- 'bicgstab': BiConjugate Gradient Stabilized.
- 'lu': LU Decomposition.
- 'cholesky': Cholesky Decomposition (converts tangent to dense mode; not supported with nodal imposition).
- 'qr': QR Decomposition (not supported with nodal imposition).
- 'jacobi': Jacobi Method.
- Default: Uses 'bicgstab' if not specified or if an unsupported type is provided.
Notes:
- Dirichlet boundary conditions are imposed if 'nodal imposition' is specified as the solution structure.
- The itterative solver can be preconditioned with 'jacobi'.
- When using 'assemble' HVP type, the function will explicitly assemble the tangent matrix.
"""
assert not isinstance(
dofs, dict
), "dofs as dict are currently not implemented in linear_solve_jax."
solver = static_settings["solver"]
hvp_type = static_settings["hvp type"]
nodal_imposition = "nodal imposition" in static_settings["solution structure"]
if nodal_imposition:
# Impose Dirichlet boundaries. Dirichlet dofs has to be concrete, therefore it is passed in static_settings as tuple of tuples
dirichlet_dofs_flat = jnp.asarray(settings["dirichlet dofs"]).flatten()
dirichlet_conditions = settings["dirichlet conditions"].flatten()
free_dofs_flat = jnp.invert(dirichlet_dofs_flat)
# Impose nodal dofs
flat_dofs = dofs.flatten()
idx = jnp.arange(flat_dofs.shape[0])[dirichlet_dofs_flat]
flat_dofs = flat_dofs.at[idx].set(dirichlet_conditions[idx])
dofs = flat_dofs.reshape(dofs.shape)
free_dofs = dofs.flatten()[free_dofs_flat]
def residual_fun(x):
flat_dofs = dofs.flatten()
idx = jnp.arange(flat_dofs.shape[0])[free_dofs_flat]
flat_dofs = flat_dofs.at[idx].set(x)
return assembler.assemble_residual(
flat_dofs.reshape(dofs.shape), settings, static_settings
).flatten()[free_dofs_flat]
def diag_assemble_fun(x):
flat_dofs = dofs.flatten()
idx = jnp.arange(flat_dofs.shape[0])[free_dofs_flat]
flat_dofs = flat_dofs.at[idx].set(x)
return assembler.assemble_tangent_diagonal(
flat_dofs.reshape(dofs.shape), settings, static_settings
).flatten()[free_dofs_flat]
rhs = -residual_fun(dofs.flatten()[free_dofs_flat])
else:
free_dofs = dofs
residual_fun = lambda x: assembler.assemble_residual(
x, settings, static_settings
)
mat_assemble_fun = lambda x: assembler.assemble_tangent(
x, settings, static_settings
)
diag_assemble_fun = lambda x: assembler.assemble_tangent_diagonal(
x, settings, static_settings
)
rhs = -residual_fun(dofs)
# Type of Hessian vector product
match hvp_type:
case "fwdrev":
def hvp_fwdrev(v):
return jax.jvp(residual_fun, (free_dofs,), (v,))[1]
hvp = hvp_fwdrev
case "revrev": # only works for symmetric matrices
def hvp_revrev(v):
return jax.jacrev(lambda x: jnp.vdot(residual_fun(x), v))(free_dofs)
hvp = hvp_revrev
case "assemble":
assert (
not nodal_imposition
), "hvp type 'assemble' is currently not supported for nodal imposition of DOFs."
tangent = mat_assemble_fun(free_dofs)
def hvp_assembled(v):
mapped = tangent @ v.flatten()
return jnp.reshape(mapped, v.shape)
hvp = hvp_assembled
case "linearize":
def hvp_linearized(v):
(_, linearized) = jax.linearize(residual_fun, free_dofs)
return linearized(v)
hvp = hvp_linearized
case _:
hvp = hvp_fwdrev
print(
"Type of hessian vector product has not been set or is not available. Using static_settings['matrix-free'] = 'revrev' as default."
)
# Preconditioning for itterative solvers
if (
solver == "cg"
or solver == "bicgstab"
or solver == "normal cg"
or solver == "gmres"
):
try:
precond_type = static_settings["type of preconditioner"]
match precond_type:
case "jacobi":
# Inversion of diagonal part of tangent matrix as preconditioner
M = 1 / diag_assemble_fun(free_dofs)
def preconditioner(v):
preconditioned = jnp.multiply(M, v.flatten())
return jnp.reshape(preconditioned, free_dofs.shape)
case _:
def preconditioner(v):
return v
print(
"Wrong preconditioner keyword. Continue without preconditioner."
)
except KeyError:
preconditioner = None
pass
# Select linear solver (itterative or direct)
match solver:
case "cg":
(sol, _) = jax.scipy.sparse.linalg.cg(hvp, rhs, M=preconditioner, **kwargs)
case "normal cg":
sol = jaxopt.linear_solve.solve_normal_cg(hvp, rhs, **kwargs)
case "gmres":
(sol, _) = jax.scipy.sparse.linalg.gmres(
hvp, rhs, M=preconditioner, **kwargs
)
case "bicgstab":
(sol, _) = jax.scipy.sparse.linalg.bicgstab(
hvp, rhs, M=preconditioner, **kwargs
)
case "lu":
sol = jaxopt.linear_solve.solve_lu(hvp, rhs)
case "cholesky":
assert (
not nodal_imposition
), "solver 'cholesky' is currently not supported for nodal imposition of DOFs."
chol, lower = jax.scipy.linalg.cho_factor(mat_assemble_fun(dofs).todense())
sol = jnp.reshape(
jax.scipy.linalg.cho_solve((chol, lower), rhs.flatten()), dofs.shape
)
case "qr":
assert (
not nodal_imposition
), "solver 'qr' is currently not supported for nodal imposition of DOFs."
bcoo_tangent = mat_assemble_fun(dofs)
bcsr_tangent = sparse.BCSR.from_bcoo(bcoo_tangent).sum_duplicates(
nse=bcoo_tangent.nse
)
sol = sparse.linalg.spsolve(
bcsr_tangent.data,
bcsr_tangent.indices,
bcsr_tangent.indptr,
rhs.flatten(),
).reshape(dofs.shape)
case "jacobi":
diag = diag_assemble_fun(free_dofs)
sol = jacobi_method(hvp, diag, free_dofs, rhs, **kwargs)
case _:
(sol, _) = jax.scipy.sparse.linalg.bicgstab(
hvp, rhs, M=preconditioner, **kwargs
)
print(
"Solver name not set or not available in combination with this solver type. Using static_settings['solver'] = 'bicgstab' as default."
)
if static_settings["verbose"] >= 1:
residual = rhs - hvp(sol)
jax.debug.print(
"The relative residual is: {value}",
value=jnp.linalg.norm(residual) / jnp.linalg.norm(rhs),
)
# Compose solution dofs
if nodal_imposition:
flat_dofs = dofs.flatten()
flat_dofs = flat_dofs.at[idx].set(dirichlet_conditions[idx])
idx = jnp.arange(flat_dofs.shape[0])[free_dofs_flat]
flat_dofs = flat_dofs.at[idx].set(sol)
dofs = flat_dofs.reshape(dofs.shape)
return dofs
else:
return sol
def scipy_assembling(tangent_with_duplicates, verbose, free_dofs):
    """
    Build a SciPy CSR matrix from a JAX BCOO matrix, summing duplicate entries.

    The coordinate data of the input is handed to SciPy's COO format and then
    converted to CSR, which accumulates entries that share the same position.
    If a mask of free degrees of freedom is given, only the corresponding
    rows and columns are kept.

    Args:
      tangent_with_duplicates (jax.experimental.sparse.BCOO): Matrix providing
        ``data``, ``indices`` (one (row, col) pair per entry) and ``shape``.
      verbose (int): Verbosity level. If >= 2, the elapsed time is printed.
      free_dofs (array or None): Boolean array selecting the rows and columns to keep.
        If None, the full matrix is returned.

    Returns:
      scipy.sparse.csr_matrix: The assembled and possibly reduced CSR matrix.
    """
    report_timing = verbose >= 2
    if report_timing:
        t_begin = time.time()

    coordinates = tangent_with_duplicates.indices
    row_idx, col_idx = coordinates[:, 0], coordinates[:, 1]

    # This is currently done on CPU and seems to be one of the computational bottlenecks when using GPU
    tangent_csr = scp.sparse.coo_matrix(
        (tangent_with_duplicates.data, (row_idx, col_idx)),
        shape=tangent_with_duplicates.shape,
    ).tocsr()

    # Drop columns first, then rows, of the Dirichlet dofs
    if free_dofs is not None:
        tangent_csr = tangent_csr[:, free_dofs][free_dofs]

    if report_timing:
        print("Time for summing duplicates: ", time.time() - t_begin)
    return tangent_csr
def linear_solve_petsc(mat, rhs, n_fields, solver, pc_type, verbose, free_dofs, tol=1e-8, **kwargs):
    """
    Solve a linear system using the PETSc solver (requires PETSc and petsc4py to be installed).

    This function solves a linear system using PETSc, with options for different
    solvers and preconditioners. The input matrix is first converted to a SciPy
    CSR matrix, and rows and columns corresponding to Dirichlet DOFs are optionally
    removed.

    Args:
      mat (jax.experimental.sparse.BCOO): The input matrix in JAX BCOO format.
      rhs (jnp.ndarray): The full right-hand side vector; it is restricted to
        free_dofs inside this function when a mask is given.
      n_fields (int): The number of fields, used as block size of the PETSc matrix.
      solver (str): The KSP type to use.
      pc_type (str): The type of preconditioner to use.
      verbose (int): Verbosity level. If >= 2, timing and solver information is printed.
      free_dofs (array or None): Boolean array indicating which degrees of freedom are free.
      tol (float): The relative tolerance for the solver.
      **kwargs: Additional keyword arguments passed to ``ksp.setTolerances``.

    Returns:
      jax.numpy.array: The solution vector with the shape of rhs; entries outside
      free_dofs are zero.

    Raises:
      ModuleNotFoundError: If petsc4py cannot be imported.

    Notes:
      - The solver settings can also be set from the command line. See PETSc and petsc4py documentation.
    """
    try:
        import petsc4py

        petsc4py.init(sys.argv)
        from petsc4py import PETSc
    except ModuleNotFoundError:
        # Keep the hint, but fail here instead of with a NameError on PETSc below
        print("Linear solver requires petsc and petsc4py")
        raise

    if free_dofs is not None:
        reduced_rhs = rhs[free_dofs]
        n_dofs = reduced_rhs.shape[0]
    else:
        reduced_rhs = rhs
        n_dofs = rhs.shape[0]

    # Transform matrix to csr format and sum duplicates
    tangent_csr = scipy_assembling(mat, verbose, free_dofs)

    if verbose >= 2:
        start = time.time()

    # Load to petsc
    mat_petsc = PETSc.Mat().createAIJ(
        size=tangent_csr.shape,
        csr=(
            tangent_csr.indptr.astype(PETSc.IntType),
            tangent_csr.indices.astype(PETSc.IntType),
            tangent_csr.data,
        ),
    )
    mat_petsc.setFromOptions()
    mat_petsc.setBlockSize(n_fields)

    if verbose >= 2:
        print("To PETSc transformation time: ", time.time() - start)
        print("Matrix infos: ")
        print()
        print(mat_petsc.getInfo())
        start = time.time()

    # Initialization of right hand side and solution vector
    b = PETSc.Vec().createSeq(n_dofs)
    b.setFromOptions()
    b.setArray(np.array(reduced_rhs))
    x = PETSc.Vec().createSeq(n_dofs)
    x.setFromOptions()

    # Solver settings
    rtol = tol
    ksp = PETSc.KSP().create()
    ksp.setTolerances(rtol=rtol, **kwargs)
    ksp.setOperators(mat_petsc)
    ksp.setType(solver)
    ksp.setFromOptions()
    ksp.setConvergenceHistory()
    ksp.getPC().setType(pc_type)
    ksp.getPC().setFromOptions()

    # Monitoring
    if verbose >= 2:
        print("Iteration Residual")

        def monitor(ksp, its, rnorm):
            print("%5d %20.15g" % (its, rnorm))

        ksp.setMonitor(monitor)

    # Solving
    ksp.solve(b, x)

    # Fill solution vector with computed values
    if free_dofs is not None:
        sol = jnp.zeros(rhs.shape)
        sol = sol.at[free_dofs].set(x.getArray())
    else:
        sol = jnp.array(x.getArray())

    if verbose >= 2:
        residual = mat_petsc * x - b
        print("Itterative linear solver time: ", time.time() - start)
        print("Number of iterations: ", ksp.getIterationNumber())
        print("Type: ", ksp.getType())
        print("Tolerances: ", ksp.getTolerances())
        print(f"The relative residual is: {residual.norm() / (b.norm() + 1e-12)}.")
    return sol
def linear_solve_pardiso(mat, rhs, solver, verbose, free_dofs):
    """
    Solve a linear system using the PARDISO solver (requires Intel MKL and pypardiso('lu') or sparse_dot_mkl('qr') to be installed).

    This function solves a linear system using PARDISO, with options for different
    solver types. The input matrix is first converted to a SciPy CSR matrix, and
    rows and columns corresponding to Dirichlet DOFs are optionally removed.

    Args:
      mat (jax.experimental.sparse.BCOO): The input matrix in JAX BCOO format.
      rhs (jnp.ndarray): The full right-hand side vector; it is restricted to
        free_dofs inside this function when a mask is given.
      solver (str): The type of solver to use ('lu' or 'qr').
      verbose (int): Verbosity level. If >= 2, timing information is printed.
      free_dofs (array or None): Boolean array indicating which degrees of freedom are free.

    Returns:
      jax.numpy.array: The solution vector with the shape of rhs; entries outside
      free_dofs are zero.

    Raises:
      AssertionError: If solver is neither 'lu' nor 'qr'.
      ModuleNotFoundError: If the backend package of the chosen solver is missing.
    """
    # Validate before assembling; raised explicitly so the check survives `python -O`
    if solver not in ("lu", "qr"):
        raise AssertionError("Type of solver not supported. Choose 'lu' or 'qr'")

    # Import the backend first so that a missing package fails fast
    if solver == "lu":
        try:
            import pypardiso
        except ModuleNotFoundError:
            print("Linear solver requires the installation of pypardiso.")
            raise
    else:
        try:
            import sparse_dot_mkl
        except ModuleNotFoundError:
            print("Linear solver requires the installation of sparse_dot_mkl.")
            raise

    # Transform matrix to csr format and sum duplicates
    tangent_csr = scipy_assembling(mat, verbose, free_dofs)

    if verbose >= 2:
        start = time.time()

    # Prepare right hand side
    if free_dofs is not None:
        b = np.asarray(rhs[free_dofs])
    else:
        b = np.asarray(rhs)

    if solver == "lu":
        # ToDo: make use of symmetries and other settings available#, msglvl=verbose, iparm=iparm
        # pypardiso_solver = pypardiso.PyPardisoSolver(mtype=11) # spd: 2
        # x = pypardiso.spsolve(tangent_csr, b, solver=pypardiso_solver)
        x = pypardiso.spsolve(tangent_csr, b)
    else:
        x = sparse_dot_mkl.sparse_qr_solve_mkl(tangent_csr, b)

    # Fill solution vector with computed values
    if free_dofs is not None:
        sol = jnp.zeros(rhs.shape)
        sol = sol.at[free_dofs].set(x)
    else:
        sol = jnp.array(x)

    if verbose >= 2:
        residual = b - tangent_csr * x
        print(
            f"The relative residual after linear solve is: {np.linalg.norm(residual) / (np.linalg.norm(b) + 1e-12)}."
        )
        print("Linear solver time: ", time.time() - start)
    return sol
def linear_solve_pyamg(mat, rhs, solver, pc_type, verbose, free_dofs, **kwargs):
    """
    Solve a linear system using the PyAMG solver (requires pyamg to be installed).

    This function solves a linear system using PyAMG, an algebraic multi-grid solver with options
    for different solvers and preconditioners. The input matrix is first converted to a SciPy
    CSR matrix, and rows and columns corresponding to Dirichlet DOFs are optionally
    removed. The multigrid solve uses W-cycles and a tolerance of 1e-8.

    Args:
      mat (jax.experimental.sparse.BCOO): The input matrix in JAX BCOO format.
      rhs (jnp.ndarray): The full right-hand side vector; it is restricted to
        free_dofs inside this function when a mask is given.
      solver (str or None): Krylov accelerator ('cg', 'bcgs' or 'gmres'), or None
        for the plain multigrid iteration.
      pc_type (str): The multigrid hierarchy to build ('ruge stuben',
        'smoothed aggregation', 'root node' or 'pairwise').
      verbose (int): Verbosity level. If >= 2, timing and solver information is printed;
        if >= 3, the normalized residual history is plotted with matplotlib.
      free_dofs (array or None): Boolean array indicating which degrees of freedom are free.
      **kwargs: Accepted for interface compatibility; currently not forwarded to PyAMG.

    Returns:
      jax.numpy.array: The solution vector with the shape of rhs; entries outside
      free_dofs are zero.

    Raises:
      AssertionError: If pc_type or solver is not one of the supported values.
      ModuleNotFoundError: If pyamg cannot be imported.
    """
    # Validate before assembling; raised explicitly so the checks survive `python -O`
    if pc_type not in ("ruge stuben", "smoothed aggregation", "root node", "pairwise"):
        raise AssertionError(
            "Type of preconditioner not supported. Choose 'ruge stuben' or 'smoothed aggregation', 'root node' or 'pairwise'"
        )
    if solver not in (None, "cg", "bcgs", "gmres"):
        raise AssertionError(
            "Type of solver not supported. Choose None, 'cg', 'bcgs' or 'gmres'"
        )

    try:
        import pyamg
    except ModuleNotFoundError:
        # Keep the hint, but fail here instead of with a NameError on pyamg below
        print("Linear solver requires the installation of pyamg.")
        raise

    # Transform matrix to csr format and sum duplicates
    pyamg_tangent = scipy_assembling(mat, verbose, free_dofs)

    if verbose >= 2:
        start = time.time()

    # Set up solver
    hierarchy_builders = {
        "ruge stuben": pyamg.ruge_stuben_solver,
        "smoothed aggregation": pyamg.smoothed_aggregation_solver,
        "root node": pyamg.rootnode_solver,
        "pairwise": pyamg.pairwise_solver,
    }
    ml = hierarchy_builders[pc_type](A=pyamg_tangent)

    if verbose >= 2:
        print(ml)
        print("Time for setting up multigrid preconditioner: ", time.time() - start)
        start = time.time()

    # Prepare right hand side
    if free_dofs is not None:
        b = np.asarray(rhs[free_dofs])
    else:
        b = np.asarray(rhs)

    residuals = []

    # Solving
    if solver is None:
        x = ml.solve(b, tol=1e-8, residuals=residuals, cycle='W')
    else:
        accelerators = {
            "cg": scp.sparse.linalg.cg,
            "bcgs": scp.sparse.linalg.bicgstab,
            "gmres": scp.sparse.linalg.gmres,
        }
        x = ml.solve(
            b, accel=accelerators[solver], tol=1e-8, residuals=residuals, cycle='W'
        )

    # Fill solution vector with computed values
    if free_dofs is not None:
        sol = jnp.zeros(rhs.shape)
        sol = sol.at[free_dofs].set(x)
    else:
        sol = jnp.array(x)

    if verbose >= 2:
        residual = b - pyamg_tangent * x
        print(
            f"The relative residual is: {np.linalg.norm(residual) / (np.linalg.norm(b) + 1e-12)}."
        )
        print("Itterative linear solver time: ", time.time() - start)

    if verbose >= 3 and residuals:
        import matplotlib.pyplot as plt

        # residuals is a Python list; convert it before normalizing element-wise
        plt.semilogy(np.asarray(residuals) / residuals[0], 'o-')
        plt.xlabel('iterations')
        plt.ylabel('normalized residual')
        plt.show()
    return sol
def linear_solve_scipy(mat, rhs, solver, verbose, free_dofs):
    """
    Solves a linear system using a specified SciPy direct solver.

    Args:
      mat (bcoo): JAX BCOO matrix representing the system's tangent matrix.
      rhs (jnp.ndarray): The full right-hand side vector; it is restricted to
        free_dofs inside this function when a mask is given.
      solver (str): Type of solver to use. 'lapack' selects SciPy's built-in direct
        sparse solver (``spsolve`` with UMFPACK disabled), 'umfpack' requests UMFPACK
        (only used by SciPy if scikit-umfpack is installed).
      verbose (int): Verbosity level. If >= 2, the relative residual and timing are printed.
      free_dofs (jnp.ndarray or None): Boolean array indicating free degrees of freedom
        for Dirichlet boundary conditions.

    Returns:
      sol (jnp.ndarray): Solution vector with the shape of rhs; entries outside
      free_dofs are zero.

    Raises:
      AssertionError: If solver is neither 'lapack' nor 'umfpack'.
    """
    # Validate before assembling; raised explicitly so the check survives `python -O`
    if solver not in ("lapack", "umfpack"):
        raise AssertionError(
            "Type of solver not supported. Choose 'lapack' or 'umfpack'"
        )

    # Transform matrix to csr format and sum duplicates
    tangent_csr = scipy_assembling(mat, verbose, free_dofs)

    if verbose >= 2:
        start = time.time()

    # Prepare right hand side as NumPy array (as in the other SciPy-based wrappers)
    if free_dofs is not None:
        b = np.asarray(rhs[free_dofs])
    else:
        b = np.asarray(rhs)

    # spsolve defaults to use_umfpack=True, so the flag has to be passed explicitly
    # for the two options to differ.
    x = scp.sparse.linalg.spsolve(tangent_csr, b, use_umfpack=(solver == "umfpack"))

    # Fill solution vector with computed values
    if free_dofs is not None:
        sol = jnp.zeros(rhs.shape)
        sol = sol.at[free_dofs].set(x)
    else:
        sol = jnp.array(x)

    if verbose >= 2:
        residual = b - tangent_csr * x
        print(
            f"The relative residual is: {np.linalg.norm(residual) / (np.linalg.norm(b) + 1e-12)}."
        )
        print("Direct solver time: ", time.time() - start)
    return sol
### Iterative solvers/smoothers
def jacobi_method(hvp_fun, diag, x_0, rhs, tol=1e-6, atol=1e-6, maxiter=1000):
    """
    Solve Ax = b using Jacobi iterations (experimental).

    The iteration ``x <- x + diag^-1 (rhs - A x)`` is repeated until the squared
    residual norm no longer exceeds ``max(tol**2 * |rhs|**2, atol**2)`` or until
    ``maxiter`` iterations have been performed.

    Args:
      hvp_fun (function): Hessian vector product function, i.e. the action x -> A x.
      diag (jnp.ndarray): Diagonal of the Hessian matrix.
      x_0 (jnp.ndarray): Initial guess for the solution.
      rhs (jnp.ndarray): Right-hand side vector of the linear system.
      tol (float): Relative tolerance for convergence.
      atol (float): Absolute tolerance for convergence.
      maxiter (int): Maximum number of iterations.

    Returns:
      array: Solution vector x.
    """
    assert not isinstance(
        x_0, dict
    ), "dofs as dict are currently not implemented in jacobi_method."

    # Initialization
    inverse_diag = 1 / diag
    rhs_squared = jnp.real(jnp.vdot(rhs, rhs))
    threshold = jnp.maximum(tol**2 * rhs_squared, atol**2)

    def residual_of(x):
        return rhs - hvp_fun(x)

    # The residual travels with the loop state, so that the stopping test and the
    # update share a single application of hvp_fun per iteration.
    def cond_fun(state):
        _, r_k, k = state
        r_k_squared = jnp.real(jnp.vdot(r_k, r_k))
        # Traceable replacement for `k < maxiter and r_k_squared > threshold`
        return jnp.logical_and(k < maxiter, r_k_squared > threshold)

    # Iterations
    def body_fun(state):
        x_k, r_k, k = state
        x_k1 = x_k + jnp.multiply(inverse_diag, r_k)
        return (x_k1, residual_of(x_k1), k + 1)

    x_final, *_ = lax.while_loop(cond_fun, body_fun, (x_0, residual_of(x_0), 0))
    return x_final
def damped_jacobi_relaxation(
    hvp_fun, diag, x_0, rhs, damping_factor=0.3333333, iterations=20, **kwargs
):
    """
    Damped Jacobi smoother (experimental).

    Performs a fixed number of sweeps ``x <- x + damping_factor * diag^-1 (rhs - A x)``.

    Args:
      hvp_fun (callable): Hessian vector product function, i.e. the action x -> A x.
      diag (jnp.ndarray): Diagonal of the Hessian matrix.
      x_0 (jnp.ndarray): Initial guess for the solution.
      rhs (jnp.ndarray): Right-hand side vector of the linear system.
      damping_factor (float): Damping factor for the iterations (<=0.5 guarantees a good smoother).
      iterations (int): Number of sweeps. It is used as the static length of a
        ``lax.scan`` and therefore has to be a concrete Python integer.
      **kwargs (dict): Additional keyword arguments; accepted but not used.

    Returns:
      array: Solution vector x.
    """
    assert not isinstance(
        x_0, dict
    ), "dofs as dict are currently not implemented in damped_jacobi_relaxation."

    # Initialization
    inverse_diag = 1 / diag
    scaled_rhs = jnp.multiply(inverse_diag, rhs)

    # Iterations
    def body_fun(x_k, _):
        dx_k1 = scaled_rhs - jnp.multiply(inverse_diag, hvp_fun(x_k))
        x_k1 = x_k + damping_factor * dx_k1
        return x_k1, None

    x_final, _ = lax.scan(body_fun, init=x_0, xs=None, length=iterations)
    return x_final