# Source code for autopdex.mesher

# mesher.py
# Copyright (C) 2024 Tobias Bode
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

"""
Module for generation of meshes.

Currently, only first-order structured meshes on quadrilateral and hexahedral domains
can be generated directly; linear triangle and brick meshes can afterwards be elevated
to second order with `elevate_mesh_order`. For complex meshes, consider using e.g. GMSH.
The only information autopdex needs is the connectivity of the mesh, i.e. a list of the
node IDs that form each element, and the coordinates of the nodes in the order of their IDs.
"""

import jax
import jax.numpy as jnp
import numpy as np

from autopdex.utility import jit_with_docstring

@jit_with_docstring(static_argnames=["n_elements", "element_type", "order"])
def structured_mesh(n_elements, vertices, element_type, order=1):
    """
    Generate a structured 2D or 3D mesh over a quadrilateral or hexahedral domain,
    with an optional subdivision into simplex (triangular/tetrahedral) elements.

    Parameters
    ----------
    n_elements : tuple of int
        For 2D, a tuple (nx, ny) specifying the number of elements along the first
        and second reference directions. For 3D, a tuple (nx, ny, nz). It is a
        static argument of the jitted function, so it must be hashable (a tuple,
        not a list).
    vertices : array-like
        For 2D, a 4x2 array; for 3D, an 8x3 array defining the coordinates of the
        domain's corner points.
    element_type : str
        For 2D, either "quad" or "tri"; for 3D, either "brick" or "tet".
    order : int, optional
        The polynomial order of the elements. Currently, only linear elements
        (order == 1) are supported. Default is 1.

    Returns
    -------
    coords : jnp.ndarray
        Node coordinates of shape ((nx+1)*(ny+1), 2) in 2D or
        ((nx+1)*(ny+1)*(nz+1), 3) in 3D. Grid node (i, j) has index
        i*(ny+1) + j; grid node (i, j, k) has index i*(ny+1)*(nz+1) + j*(nz+1) + k.
    elements : jnp.ndarray
        Element connectivity, one element per row:

        - "quad":  shape (nx*ny, 4), nodes (i,j), (i+1,j), (i+1,j+1), (i,j+1).
        - "tri":   shape (2*nx*ny, 3); each quad (n0, n1, n2, n3) is split along its
          (n0, n2) diagonal. The first nx*ny rows are the triangles (n0, n1, n2),
          the remaining rows the triangles (n0, n2, n3).
        - "brick": shape (nx*ny*nz, 8), nodes (i,j,k), (i+1,j,k), (i+1,j+1,k),
          (i,j+1,k), then the same pattern at k+1.
        - "tet":   shape (6*nx*ny*nz, 4); six consecutive rows per brick, all sharing
          the brick's (n0, n6) diagonal.

    Raises
    ------
    NotImplementedError
        If ``order != 1`` or ``element_type`` does not fit the dimension.
    ValueError
        If ``vertices`` does not have 2 or 3 columns, or has the wrong number of rows.

    Notes
    -----
    For 2D:
        - The vertices should be provided as a 4x2 array in anti-clockwise order,
          starting with the corner that is mapped to (-1, -1).
        - The mapping from a reference square ([-1, 1] x [-1, 1]) to the physical
          domain is performed using bilinear interpolation.
        - Quadrilaterals and triangles are numbered counter-clockwise (the same
          convention as the bottom face of the brick elements). For anti-clockwise
          vertices, their Jacobian determinant with respect to a counter-clockwise
          reference element is therefore positive.

    For 3D:
        - The vertices should be provided as an 8x3 array, with the ordering
          corresponding to a standard hexahedron (starting with (-1, -1, -1) and
          proceeding anti-clockwise on the bottom face, then the top face in the
          same order).
        - The mapping from a reference cube ([-1, 1]^3) to the physical domain is
          done via trilinear interpolation.
        - All six tetrahedra of a brick have positive signed volume in reference
          coordinates.

    Currently, only linear (order == 1) elements are supported.
    """
    vertices = jnp.asarray(vertices)
    dim = vertices.shape[1]

    if order != 1:
        raise NotImplementedError("Only order==1 is implemented at this moment.")

    # ----- 2D Mesh Generation -----
    if dim == 2:
        if element_type not in ["quad", "tri"]:
            raise NotImplementedError("For 2D, element_type must be either 'quad' or 'tri'.")
        if vertices.shape[0] != 4:
            raise ValueError("For 2D, vertices must be a 4x2 array.")
        nx, ny = n_elements

        # Reference grid in [-1, 1] x [-1, 1]; "ij" indexing + ravel gives grid
        # node (i, j) the index i*(ny+1) + j.
        s = jnp.linspace(-1, 1, nx + 1)
        t = jnp.linspace(-1, 1, ny + 1)
        S, T = jnp.meshgrid(s, t, indexing="ij")
        ref_coords = jnp.column_stack([S.ravel(), T.ravel()])

        # Bilinear mapping from reference coordinates to physical coordinates.
        def bilinear_interpolate(pt):
            s, t = pt
            return ((1 - s) * (1 - t) * vertices[0]
                    + (1 + s) * (1 - t) * vertices[1]
                    + (1 + s) * (1 + t) * vertices[2]
                    + (1 - s) * (1 + t) * vertices[3]) / 4

        coords = jax.vmap(bilinear_interpolate)(ref_coords)

        # node_ids[i, j] == i*(ny+1) + j, matching the ordering of `coords`.
        node_ids = jnp.arange((nx + 1) * (ny + 1)).reshape(nx + 1, ny + 1)
        # Counter-clockwise local ordering (i,j) -> (i+1,j) -> (i+1,j+1) -> (i,j+1).
        # Elements are ordered with i as the outer and j as the inner index.
        quads = jnp.stack(
            [node_ids[:-1, :-1], node_ids[1:, :-1], node_ids[1:, 1:], node_ids[:-1, 1:]],
            axis=-1,
        ).reshape(-1, 4)

        if element_type == "quad":
            elements = quads
        else:  # element_type == "tri"
            # Split each quad along the diagonal from its first to its third node;
            # both triangles inherit the counter-clockwise orientation.
            tri1 = quads[:, [0, 1, 2]]
            tri2 = quads[:, [0, 2, 3]]
            elements = jnp.concatenate([tri1, tri2], axis=0)

        return coords, elements

    # ----- 3D Mesh Generation -----
    elif dim == 3:
        if element_type not in ["brick", "tet"]:
            raise NotImplementedError("For 3D, element_type must be either 'brick' or 'tet'.")
        if vertices.shape[0] != 8:
            raise ValueError("For 3D, vertices must be an 8x3 array.")
        nx, ny, nz = n_elements

        # Reference grid in [-1, 1]^3; grid node (i, j, k) gets the index
        # i*(ny+1)*(nz+1) + j*(nz+1) + k.
        s = jnp.linspace(-1, 1, nx + 1)
        t = jnp.linspace(-1, 1, ny + 1)
        u = jnp.linspace(-1, 1, nz + 1)
        S, T, U = jnp.meshgrid(s, t, u, indexing="ij")
        ref_coords = jnp.column_stack([S.ravel(), T.ravel(), U.ravel()])

        # Trilinear mapping from reference coordinates to physical coordinates.
        def trilinear_interpolate(pt):
            s, t, u = pt
            return ((1 - s) * (1 - t) * (1 - u) * vertices[0]
                    + (1 + s) * (1 - t) * (1 - u) * vertices[1]
                    + (1 + s) * (1 + t) * (1 - u) * vertices[2]
                    + (1 - s) * (1 + t) * (1 - u) * vertices[3]
                    + (1 - s) * (1 - t) * (1 + u) * vertices[4]
                    + (1 + s) * (1 - t) * (1 + u) * vertices[5]
                    + (1 + s) * (1 + t) * (1 + u) * vertices[6]
                    + (1 - s) * (1 + t) * (1 + u) * vertices[7]) / 8

        coords = jax.vmap(trilinear_interpolate)(ref_coords)

        # node_ids[i, j, k] == i*(ny+1)*(nz+1) + j*(nz+1) + k.
        node_ids = jnp.arange((nx + 1) * (ny + 1) * (nz + 1)).reshape(nx + 1, ny + 1, nz + 1)
        # Local brick ordering:
        #   n0: (i, j, k)      n1: (i+1, j, k)      n2: (i+1, j+1, k)      n3: (i, j+1, k)
        #   n4: (i, j, k+1)    n5: (i+1, j, k+1)    n6: (i+1, j+1, k+1)    n7: (i, j+1, k+1)
        bricks = jnp.stack(
            [
                node_ids[:-1, :-1, :-1],
                node_ids[1:, :-1, :-1],
                node_ids[1:, 1:, :-1],
                node_ids[:-1, 1:, :-1],
                node_ids[:-1, :-1, 1:],
                node_ids[1:, :-1, 1:],
                node_ids[1:, 1:, 1:],
                node_ids[:-1, 1:, 1:],
            ],
            axis=-1,
        ).reshape(-1, 8)

        if element_type == "brick":
            elements = bricks
        else:  # element_type == "tet"
            # Subdivide every brick into 6 tetrahedra around the common diagonal
            # n0-n6. Since all bricks use the same diagonal direction, faces of
            # neighbouring bricks are split identically and the mesh is conforming.
            tet_local = jnp.array([
                [0, 1, 2, 6],
                [0, 2, 3, 6],
                [0, 3, 7, 6],
                [0, 7, 4, 6],
                [0, 4, 5, 6],
                [0, 5, 1, 6],
            ])
            # (n_bricks, 6, 4) -> (6 * n_bricks, 4): six consecutive rows per brick.
            elements = bricks[:, tet_local].reshape(-1, 4)

        return coords, elements

    else:
        raise ValueError("Unsupported dimension: vertices must have 2 or 3 columns.")
def _connectivity_array(rows, nodes_per_element):
    """
    Stack connectivity rows into an (M, nodes_per_element) integer array.

    ``np.array([])`` would be a float array of shape (0,), so the empty case is built
    explicitly to keep the row width and an integer dtype.
    """
    if rows:
        return np.array(rows)
    return np.empty((0, nodes_per_element), dtype=int)


def _elevate_order_triangle(coords, elements):
    """
    Elevates a linear (3-node) triangular mesh to quadratic (6-node) triangles.

    Original nodes keep their indices; one midpoint node per unique edge is appended
    after them, so edges shared by neighbouring triangles get a single midpoint.

    Args:
        coords (array-like): (N,2) array of node coordinates.
        elements (array-like): (M,3) array of triangle connectivity (node indices).

    Returns:
        new_coords (np.ndarray): Original node coordinates followed by the edge midpoints.
        new_elements (np.ndarray): (M,6) connectivity [n0, n1, n2, m01, m12, m20] for
            quadratic triangles, where mXY is the midpoint of edge (nX, nY).
            Shape (0,6) if there are no elements.
    """
    coords = np.array(coords)
    elements = np.array(elements)

    new_coords = list(coords)  # start with the original nodes
    edge_dict = {}  # sorted (i, j) node pair -> index of its midpoint node
    new_elements = []

    def get_midpoint(i, j):
        # The key is orientation-independent, so both triangles sharing an edge
        # reuse the same midpoint node.
        key = tuple(sorted((i, j)))
        if key not in edge_dict:
            edge_dict[key] = len(new_coords)
            new_coords.append(0.5 * (coords[i] + coords[j]))
        return edge_dict[key]

    for n0, n1, n2 in elements:
        m01 = get_midpoint(n0, n1)
        m12 = get_midpoint(n1, n2)
        m20 = get_midpoint(n2, n0)
        new_elements.append([n0, n1, n2, m01, m12, m20])

    # `coords` is already a fresh copy; returning it keeps its shape when empty.
    coords_out = np.array(new_coords) if new_coords else coords
    return coords_out, _connectivity_array(new_elements, 6)


def _elevate_order_brick(coords, elements):
    """
    Elevates a linear (8-node) brick mesh to quadratic bricks (27 nodes).

    Original nodes keep their indices. New nodes are appended after them in
    creation order: for every brick, its not-yet-existing edge midpoints, then its
    not-yet-existing face centres, then its interior node. Edges and faces shared by
    neighbouring bricks are identified by their sorted corner indices and created only once.

    Local layout of the 27-node connectivity (corner labels from the assumed
    reference ordering c0=(-1,-1,-1), c1=(1,-1,-1), c2=(1,1,-1), c3=(-1,1,-1),
    c4..c7 the same at +1 in the third coordinate):
        0-7:   corners c0..c7 (unchanged)
        8-19:  midpoints of edges (c0,c1), (c1,c2), (c2,c3), (c3,c0), (c4,c5), (c5,c6),
               (c6,c7), (c7,c4), (c0,c4), (c1,c5), (c2,c6), (c3,c7)
        20-25: centres of faces left (c3,c0,c4,c7), right (c1,c2,c6,c5),
               front (c0,c1,c5,c4), back (c2,c3,c7,c6), bottom (c0,c1,c2,c3),
               top (c4,c5,c6,c7)
        26:    interior node (mean of the 8 corners)

    NOTE(review): earlier comments called this "Gmsh ordering", but Gmsh's 27-node
    hexahedron lists its edge and face nodes in a different order. This layout
    appears to match VTK's triquadratic hexahedron instead; confirm against the
    element definition autopdex uses for 27-node bricks.

    Args:
        coords (array-like): (N,3) array of node coordinates.
        elements (array-like): (M,8) array of brick connectivity (node indices).

    Returns:
        new_coords (np.ndarray): Updated node coordinates including extra nodes.
        new_elements (np.ndarray): (M,27) connectivity for quadratic bricks.
            Shape (0,27) if there are no elements.
    """
    coords = np.array(coords)
    elements = np.array(elements)

    # Local corner pairs of the 12 edges, in connectivity order (entries 8-19).
    brick_edges = ((0, 1), (1, 2), (2, 3), (3, 0),
                   (4, 5), (5, 6), (6, 7), (7, 4),
                   (0, 4), (1, 5), (2, 6), (3, 7))
    # Local corners of the 6 faces, in the order the face nodes are created.
    face_corners = {
        "bottom": (0, 1, 2, 3),
        "top": (4, 5, 6, 7),
        "front": (0, 1, 5, 4),
        "right": (1, 2, 6, 5),
        "back": (2, 3, 7, 6),
        "left": (3, 0, 4, 7),
    }
    # Order of the face nodes within the connectivity (entries 20-25).
    face_connectivity_order = ("left", "right", "front", "back", "bottom", "top")

    new_coords = list(coords)  # original nodes
    edge_dict = {}  # global dictionary for edge midpoints
    face_dict = {}  # global dictionary for face nodes
    new_elements = []

    def get_edge(a, b):
        key = tuple(sorted((a, b)))
        if key not in edge_dict:
            edge_dict[key] = len(new_coords)
            new_coords.append(0.5 * (coords[a] + coords[b]))
        return edge_dict[key]

    def get_face(face_nodes):
        key = tuple(sorted(face_nodes))
        if key not in face_dict:
            face_dict[key] = len(new_coords)
            new_coords.append(np.mean([coords[idx] for idx in face_nodes], axis=0))
        return face_dict[key]

    for brick in elements:
        corners = [brick[k] for k in range(8)]
        edge_nodes = [get_edge(brick[a], brick[b]) for a, b in brick_edges]
        face_nodes = {
            name: get_face(tuple(brick[k] for k in local))
            for name, local in face_corners.items()
        }
        interior_idx = len(new_coords)
        new_coords.append(np.mean([coords[idx] for idx in brick], axis=0))

        new_elements.append(
            corners
            + edge_nodes
            + [face_nodes[name] for name in face_connectivity_order]
            + [interior_idx]
        )

    # `coords` is already a fresh copy; returning it keeps its shape when empty.
    coords_out = np.array(new_coords) if new_coords else coords
    return coords_out, _connectivity_array(new_elements, 27)
def elevate_mesh_order(coords, elements):
    """
    Convert a first-order mesh into a second-order mesh.

    The element family is recognised from the spatial dimension (columns of
    ``coords``) together with the number of nodes per element (columns of
    ``elements``):

    - 2D, 3 nodes per element: linear triangles become 6-node triangles.
    - 3D, 8 nodes per element: linear bricks become 27-node bricks.

    Args:
        coords (array-like): Array of node coordinates, one row per node.
        elements (array-like): Element connectivity array, one row per element.

    Returns:
        new_coords, new_elements: The upgraded mesh.

    Raises:
        NotImplementedError: If the (dimension, nodes per element) combination is
            not one of the supported cases above.
    """
    coords = np.array(coords)
    elements = np.array(elements)

    mesh_kind = (coords.shape[1], elements.shape[1])

    if mesh_kind == (2, 3):
        return _elevate_order_triangle(coords, elements)
    if mesh_kind == (3, 8):
        return _elevate_order_brick(coords, elements)

    raise NotImplementedError("Mesh elevation to order 2 not implemented for this element type.")