Source code for torax.fvm.implicit_solve_block

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The `implicit_solve_block` function.

See function docstring for details.
"""
import dataclasses
import functools
import jax
from jax import numpy as jnp
from torax import jax_utils
from torax.fvm import block_1d_coeffs
from torax.fvm import cell_variable
from torax.fvm import fvm_conversions
from torax.fvm import residual_and_loss


@functools.partial(
    jax_utils.jit,
    static_argnames=[
        'convection_dirichlet_mode',
        'convection_neumann_mode',
        'theta_imp',
    ],
)
def implicit_solve_block(
    dt: jax.Array,
    x_old: tuple[cell_variable.CellVariable, ...],
    x_new_guess: tuple[cell_variable.CellVariable, ...],
    coeffs_old: block_1d_coeffs.Block1DCoeffs,
    coeffs_new: block_1d_coeffs.Block1DCoeffs,
    theta_imp: float = 1.0,
    convection_dirichlet_mode: str = 'ghost',
    convection_neumann_mode: str = 'ghost',
) -> tuple[cell_variable.CellVariable, ...]:
  # pyformat: disable  # pyformat removes line breaks needed for readability
  """Advances the state by one time step with a single linear theta-method solve.

  The solver is generic: diffusion, convection, etc. enter only through the
  supplied coefficients. All problem-specific physics must be evaluated by the
  caller when building `coeffs_old` and `coeffs_new`.

  Args:
    dt: Discrete time step.
    x_old: Tuple of CellVariables, one per channel, holding the values at the
      start of the time step (time t).
    x_new_guess: Tuple of CellVariables holding the initial guess for x_new.
      The returned CellVariables are copies of these with `value` replaced.
    coeffs_old: Coefficients defining the equation, computed for time t.
    coeffs_new: Coefficients defining the equation, computed for time t+dt.
    theta_imp: Coefficient in [0, 1] selecting the time discretisation. The
      equation solved is
        transient_coeff (x_new - x_old) / dt =
          theta_imp F(t_new) + (1 - theta_imp) F(t_old).
      Named special cases:
        theta_imp = 1: Backward Euler implicit method (default).
        theta_imp = 0.5: Crank-Nicolson.
        theta_imp = 0: Forward Euler explicit method.
    convection_dirichlet_mode: See docstring of the `convection_terms`
      function, `dirichlet_mode` argument.
    convection_neumann_mode: See docstring of the `convection_terms`
      function, `neumann_mode` argument.

  Returns:
    x_new: Tuple, with x_new[i] giving channel i of x at the next time step.
  """
  # pyformat: enable

  # The linear case reuses the matrix formulation of the nonlinear solver, but
  # solves the system directly with linalg.solve. The implicit coefficients
  # are evaluated at an approximation of x_new: e.g. x_old for a single-step
  # linear solve, or an iterate from Picard iterations with
  # predictor-corrector. See residual_and_loss.theta_method_matrix_equation
  # for the full description of how the equation is assembled.
  implicit_mat, implicit_vec, explicit_mat, explicit_vec = (
      residual_and_loss.theta_method_matrix_equation(
          dt=dt,
          x_old=x_old,
          x_new_guess=x_new_guess,
          coeffs_old=coeffs_old,
          coeffs_new=coeffs_new,
          theta_imp=theta_imp,
          convection_dirichlet_mode=convection_dirichlet_mode,
          convection_neumann_mode=convection_neumann_mode,
      )
  )

  # Flatten all channels of the old state into one vector and assemble the
  # right-hand side of `implicit_mat @ x_new = rhs`.
  old_state_vec = fvm_conversions.cell_variable_tuple_to_vec(x_old)
  explicit_contribution = jnp.dot(explicit_mat, old_state_vec)
  rhs = explicit_contribution + explicit_vec - implicit_vec
  new_state_vec = jnp.linalg.solve(implicit_mat, rhs)

  # Cut the flat solution into one equal-sized piece per channel, then build
  # the output from copies of the x_new_guess entries with only `value`
  # replaced; every other field is inherited from the guess.
  per_channel_values = jnp.split(new_state_vec, len(x_old))
  return tuple(
      dataclasses.replace(guess_var, value=channel_value)
      for guess_var, channel_value in zip(x_new_guess, per_channel_values)
  )