Source code for torax.fvm.newton_raphson_solve_block

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The `newton_raphson_solve_block` function.

See function docstring for details.
"""

import functools
from typing import Callable, Final

from absl import logging
import jax
from jax import numpy as jnp
import numpy as np
from torax import jax_utils
from torax import state as state_module
from torax.config import runtime_params_slice
from torax.fvm import block_1d_coeffs
from torax.fvm import calc_coeffs
from torax.fvm import cell_variable
from torax.fvm import enums
from torax.fvm import fvm_conversions
from torax.fvm import residual_and_loss
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles
from torax.stepper import predictor_corrector_method
from torax.transport_model import transport_model as transport_model_lib


# Threshold on the Newton step `delta`, used by the step-reduction loop (see
# `delta_cond`). Delta is a vector. If no entry of delta is above this
# magnitude, we terminate the delta loop. This is to avoid getting stuck in an
# infinite loop in edge cases with bad numerics.
MIN_DELTA: Final[float] = 1e-7


def _log_iterations(
    residual: jax.Array,
    iterations: jax.Array,
    delta_reduction: jax.Array | None = None,
    dt: jax.Array | None = None,
) -> None:
  """Emits one info-level log line describing a Newton-Raphson iteration.

  The iteration count and the scalar residual are always logged. At most one
  extra quantity is appended: `dt` if it is given, otherwise `delta_reduction`
  (reported as tau) if it is given.

  Args:
    residual: Scalar residual.
    iterations: Number of iterations taken so far in the solve block.
    delta_reduction: Current tau used in this iteration.
    dt: Current dt used in this iteration.
  """
  # Pick the message template and the optional trailing argument; `dt` takes
  # precedence over `delta_reduction` when both are supplied.
  if dt is not None:
    template = 'Iteration: %d. Residual: %.16f. dt = %.6f'
    trailing_args = (dt,)
  elif delta_reduction is not None:
    template = 'Iteration: %d. Residual: %.16f. tau = %.6f'
    trailing_args = (delta_reduction,)
  else:
    template = 'Iteration: %d. Residual: %.16f'
    trailing_args = ()

  logging.info(template, iterations, residual, *trailing_args)


[docs] def newton_raphson_solve_block( dt: jax.Array, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo_t: geometry.Geometry, geo_t_plus_dt: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], core_profiles_t: state_module.CoreProfiles, core_profiles_t_plus_dt: state_module.CoreProfiles, transport_model: transport_model_lib.TransportModel, explicit_source_profiles: source_profiles.SourceProfiles, source_models: source_models_lib.SourceModels, pedestal_model: pedestal_model_lib.PedestalModel, coeffs_callback: calc_coeffs.CoeffsCallback, evolving_names: tuple[str, ...], initial_guess_mode: enums.InitialGuessMode, maxiter: int, tol: float, coarse_tol: float, delta_reduction_factor: float, tau_min: float, log_iterations: bool = False, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state_module.StepperNumericOutputs, block_1d_coeffs.AuxiliaryOutput, ]: # pyformat: disable # pyformat removes line breaks needed for reability """Runs one time step of a Newton-Raphson based root-finding on the equation defined by `coeffs`. This solver is relatively generic in that it models diffusion, convection, etc. abstractly. The caller must do the problem-specific physics calculations to obtain the coefficients for a particular problem. This solver uses iterative root finding on the linearized residual between two sides of the equation describing a theta method update. The linearized residual for a trial x_new is: R(x_old) + jacobian(R(x_old))*(x_new - x_old) Setting delta = x_new - x_old, we solve the linear system: A*x_new = b, with A = jacobian(R(x_old)), b = A*x_old - R(x_old) Each successive iteration sets x_new = x_old - delta, until the residual or delta is under a tolerance (tol). 
If either the delta step leads to an unphysical state, represented by NaNs in the residual, or if the residual doesn't shrink following the delta step, then delta is successively reduced by a delta_reduction_factor. If tau = delta_now / delta_original is below a tolerance, then the iterations stop. If residual > tol then the function exits with an error flag, producing either a warning or recalculation with a lower dt. Args: dt: Discrete time step. static_runtime_params_slice: Static runtime parameters. Changes to these runtime params will trigger recompilation. dynamic_runtime_params_slice_t: Runtime parameters for time t (the start time of the step). These config params can change from step to step without triggering a recompilation. dynamic_runtime_params_slice_t_plus_dt: Runtime parameters for time t + dt. geo_t: Geometry at time t. geo_t_plus_dt: Geometry at time t + dt. x_old: Tuple containing CellVariables for each channel with their values at the start of the time step. core_profiles_t: Core plasma profiles which contain all available prescribed quantities at the start of the time step. This includes evolving boundary conditions and prescribed time-dependent profiles that are not being evolved by the PDE system. core_profiles_t_plus_dt: Core plasma profiles which contain all available prescribed quantities at the end of the time step. This includes evolving boundary conditions and prescribed time-dependent profiles that are not being evolved by the PDE system. transport_model: Turbulent transport model callable. explicit_source_profiles: Pre-calculated sources implemented as explicit sources in the PDE. source_models: Collection of source callables to generate source PDE coefficients. pedestal_model: Model of the pedestal's behavior. coeffs_callback: Calculates diffusion, convection etc. coefficients given a core_profiles. Repeatedly called by the iterative optimizer. evolving_names: The names of variables within the core profiles that should evolve. 
initial_guess_mode: chooses the initial_guess for the iterative method, either x_old or linear step. When taking the linear step, it is also recommended to use Pereverzev-Corrigan terms if the transport coefficients are stiff, e.g. from QLKNN. This can be set by setting use_pereverzev = True in the solver config. maxiter: Quit iterating after this many iterations reached. tol: Quit iterating after the average absolute value of the residual is <= tol. coarse_tol: Coarser allowed tolerance for cases when solver develops small steps in the vicinity of the solution. delta_reduction_factor: Multiply by delta_reduction_factor after each failed line search step. tau_min: Minimum delta/delta_original allowed before the newton raphson routine resets at a lower timestep. log_iterations: If true, output diagnostic information from within iteration loop. Returns: x_new: Tuple, with x_new[i] giving channel i of x at the next time step stepper_numeric_outputs: state_module.StepperNumericOutputs. Iteration and error info. For the error, 0 signifies residual < tol at exit, 1 signifies residual > tol, steps became small. aux_output: Extra auxiliary output from calc_coeffs. """ # pyformat: enable coeffs_old = coeffs_callback( dynamic_runtime_params_slice_t, geo_t, core_profiles_t, x_old, explicit_call=True, ) match initial_guess_mode: # LINEAR initial guess will provide the initial guess using the predictor- # corrector method if predictor_corrector=True in the solver config case enums.InitialGuessMode.LINEAR: # returns transport coefficients with additional pereverzev terms # if set by runtime_params, needed if stiff transport models (e.g. qlknn) # are used. 
coeffs_exp_linear = coeffs_callback( dynamic_runtime_params_slice_t, geo_t, core_profiles_t, x_old, allow_pereverzev=True, explicit_call=True, ) # See linear_theta_method.py for comments on the predictor_corrector API x_new_guess = tuple( [core_profiles_t_plus_dt[name] for name in evolving_names] ) init_x_new, _ = predictor_corrector_method.predictor_corrector_method( dt=dt, static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, x_old=x_old, x_new_guess=x_new_guess, core_profiles_t_plus_dt=core_profiles_t_plus_dt, coeffs_exp=coeffs_exp_linear, coeffs_callback=coeffs_callback, ) init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new) case enums.InitialGuessMode.X_OLD: init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(x_old) case _: raise ValueError( f'Unknown option for first guess in iterations: {initial_guess_mode}' ) # Create a residual() function with only one argument: x_new. # The other arguments (dt, x_old, etc.) are fixed. # Note that core_profiles_t_plus_dt only contains the known quantities at # t_plus_dt, e.g. boundary conditions and prescribed profiles. 
residual_fun = functools.partial( residual_and_loss.theta_method_block_residual, dt=dt, static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, x_old=x_old, core_profiles_t_plus_dt=core_profiles_t_plus_dt, transport_model=transport_model, explicit_source_profiles=explicit_source_profiles, source_models=source_models, coeffs_old=coeffs_old, evolving_names=evolving_names, pedestal_model=pedestal_model, ) jacobian_fun = functools.partial( residual_and_loss.theta_method_block_jacobian, dt=dt, static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, x_old=x_old, core_profiles_t_plus_dt=core_profiles_t_plus_dt, evolving_names=evolving_names, transport_model=transport_model, pedestal_model=pedestal_model, explicit_source_profiles=explicit_source_profiles, source_models=source_models, coeffs_old=coeffs_old, ) cond_fun = functools.partial(cond, tol=tol, tau_min=tau_min, maxiter=maxiter) delta_cond_fun = functools.partial( delta_cond, residual_fun=residual_fun, ) body_fun = functools.partial( body, jacobian_fun=jacobian_fun, delta_cond_fun=delta_cond_fun, delta_reduction_factor=delta_reduction_factor, log_iterations=log_iterations, ) # initialize state dict being passed around Newton-Raphson iterations residual_vec_init_x_new, aux_output_init_x_new = residual_fun(init_x_new_vec) initial_state = { 'x': init_x_new_vec, 'iterations': jnp.array(0, dtype=jax_utils.get_int_dtype()), 'residual': residual_vec_init_x_new, 'last_tau': jnp.array(1.0, dtype=jax_utils.get_dtype()), 'aux_output': aux_output_init_x_new, } # log initial state if requested if log_iterations: _log_iterations( residual=residual_scalar(initial_state['residual']), iterations=initial_state['iterations'], dt=dt, ) # carry out iterations. 
jax.lax.while needed for JAX-compliance output_state = jax_utils.py_while(cond_fun, body_fun, initial_state) # Create updated CellVariable instances based on state_plus_dt which has # updated boundary conditions and prescribed profiles. x_new = fvm_conversions.vec_to_cell_variable_tuple( output_state['x'], core_profiles_t_plus_dt, evolving_names ) # Tell the caller whether or not x_new successfully reduces the residual below # the tolerance by providing an extra output, error. # error = 0: residual converged within fine tolerance (tol) # error = 1: not converged. Possibly backtrack to smaller dt and retry # error = 2: residual not strictly converged but is still within reasonable # tolerance (coarse_tol). Can occur when solver exits early due to small steps # in solution vicinity. Proceed but provide a warning to user. error = jax_utils.py_cond( residual_scalar(output_state['residual']) < tol, lambda: 0, # Called when True lambda: jax_utils.py_cond( # Called when False residual_scalar(output_state['residual']) < coarse_tol, lambda: 2, # Called when True lambda: 1, # Called when False ), ) stepper_numeric_outputs = state_module.StepperNumericOutputs( inner_solver_iterations=int(output_state['iterations']), stepper_error_state=error, outer_stepper_iterations=1, ) coeffs_final = coeffs_callback( dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt, core_profiles_t_plus_dt, x_new, allow_pereverzev=True, ) return x_new, stepper_numeric_outputs, coeffs_final.auxiliary_outputs
def residual_scalar(x: jax.Array) -> jax.Array:
  """Reduces a residual vector to a scalar: the mean of its absolute values.

  Implemented with `jax.numpy` (like the rest of this module) rather than
  NumPy, so that it also works on traced values, e.g. under `jax.jit`.

  Args:
    x: Residual array.

  Returns:
    Scalar array holding mean(|x|).
  """
  return jnp.mean(jnp.abs(x))
def cond(
    state: dict[str, jax.Array],
    tau_min: float,
    maxiter: int,
    tol: float,
) -> bool:
  """Decides whether the Newton-Raphson iterations should continue.

  Args:
    state: Iteration state; the entries 'residual', 'iterations' and 'last_tau'
      are read.
    tau_min: Iterations continue only while state['last_tau'] exceeds this.
    maxiter: Iterations continue only while state['iterations'] is below this.
    tol: Iterations continue only while the scalar residual exceeds this.

  Returns:
    True if all three conditions above hold, i.e. another iteration should be
    carried out; False once any exit condition is reached.
  """
  residual_above_tol = residual_scalar(state['residual']) > tol
  iterations_remaining = state['iterations'][...] < maxiter
  tau_above_min = state['last_tau'] > tau_min

  keep_iterating = jnp.logical_and(
      jnp.logical_and(residual_above_tol, iterations_remaining),
      tau_above_min,
  )
  return jnp.bool_(keep_iterating)
def body(
    input_state: dict[str, jax.Array],
    jacobian_fun,
    delta_cond_fun,
    delta_reduction_factor,
    log_iterations,
) -> dict[str, jax.Array]:
  """Performs one Newton-Raphson iteration and returns the updated state.

  Solves the linearized system for the full Newton step, then runs the inner
  step-reduction loop (`delta_cond_fun` / `delta_body`) which may shrink the
  step, and finally applies the resulting step to x.

  Args:
    input_state: Iteration state with entries 'x', 'residual', 'iterations' and
      'aux_output'.
    jacobian_fun: Callable returning (jacobian, aux) for a given x; the aux
      output is ignored here.
    delta_cond_fun: Condition callable of the inner step-reduction loop.
    delta_reduction_factor: Factor applied to the step on each inner-loop pass.
    log_iterations: If true, logs the residual, iteration count and tau.

  Returns:
    New iteration state with entries 'x', 'residual', 'iterations', 'last_tau'
    and 'aux_output'.
  """
  x_current = input_state['x']
  residual_current = input_state['residual']

  jacobian, _ = jacobian_fun(x_current)  # Ignore the aux output here.
  negative_residual = -residual_current
  full_newton_step = jnp.linalg.solve(jacobian, negative_residual)

  # delta = x_new - x_old
  # tau = delta/delta0, where delta0 is the delta that sets the linearized
  # residual to zero. tau < 1 when needed such that x_new meets
  # conditions of reduced residual and valid state quantities.
  # If tau < taumin while residual > tol, then the routine exits with an
  # error flag, leading to either a warning or recalculation at lower dt
  line_search_start = {
      'x': x_current,
      'delta': full_newton_step,
      'residual_old': residual_current,
      'residual_new': residual_current,
      'aux_output_new': input_state['aux_output'],
      'tau': jnp.array(1.0, dtype=jax_utils.get_dtype()),
  }
  shrink_step = functools.partial(
      delta_body,
      delta_reduction_factor=delta_reduction_factor,
  )
  line_search_end = jax_utils.py_while(
      delta_cond_fun, shrink_step, line_search_start
  )

  iterations_so_far = jnp.array(
      input_state['iterations'][...], dtype=jax_utils.get_int_dtype()
  )
  output_state = {
      'x': x_current + line_search_end['delta'],
      'residual': line_search_end['residual_new'],
      'iterations': iterations_so_far + 1,
      'last_tau': line_search_end['tau'],
      'aux_output': line_search_end['aux_output_new'],
  }

  if log_iterations:
    _log_iterations(
        residual=residual_scalar(output_state['residual']),
        iterations=output_state['iterations'],
        delta_reduction=line_search_end['tau'],
    )

  return output_state
def delta_cond(
    delta_state: dict[str, jax.Array],
    residual_fun: Callable[[jax.Array], tuple[jax.Array, object]],
) -> bool:
  """Check if delta obtained from Newton step is valid.

  Evaluates the residual at x + delta and decides whether the step must be
  reduced further.

  Side effect: the residual vector and auxiliary output evaluated at x + delta
  are written into `delta_state` under 'residual_new' and 'aux_output_new', so
  that they are available to the caller once the loop ends.

  Args:
    delta_state: see `delta_body`. The entries 'x', 'delta' and 'residual_old'
      are read.
    residual_fun: Residual function, returning a (residual vector, auxiliary
      output) pair for a given x.

  Returns:
    True if the step should be reduced further: at least one entry of delta is
    larger in magnitude than MIN_DELTA, and the new value of `x` causes any NaNs
    or has increased the residual relative to the old value of `x`.
  """
  x_old = delta_state['x']
  x_new = x_old + delta_state['delta']
  residual_vec_x_old = delta_state['residual_old']
  residual_scalar_x_old = residual_scalar(residual_vec_x_old)
  # Avoid sanity checking inside residual, since we directly
  # afterwards check sanity on the output (NaN checking)
  # TODO(b/312453092) consider instead sanity-checking x_new
  with jax_utils.enable_errors(False):
    residual_vec_x_new, aux_output_x_new = residual_fun(x_new)
    residual_scalar_x_new = residual_scalar(residual_vec_x_new)
    delta_state['residual_new'] = residual_vec_x_new
    delta_state['aux_output_new'] = aux_output_x_new
  # MIN_DELTA bounds the *magnitude* of the step, so compare against |delta|.
  # Without the absolute value, a step whose entries are all negative would end
  # the step-reduction loop immediately, even if the residual is NaN or grew.
  largest_step_magnitude = jnp.max(jnp.abs(delta_state['delta']))
  return jnp.bool_(
      jnp.logical_and(
          largest_step_magnitude > MIN_DELTA,
          jnp.logical_or(
              residual_scalar_x_old < residual_scalar_x_new,
              jnp.isnan(residual_scalar_x_new),
          ),
      ),
  )
def delta_body(
    input_delta_state: dict[str, jax.Array], delta_reduction_factor: float
) -> dict[str, jax.Array]:
  """Shrinks the current Newton step by `delta_reduction_factor`.

  Args:
    input_delta_state: State of the step-reduction loop. The entries 'delta'
      and 'tau' are read; all other entries are passed through unchanged.
    delta_reduction_factor: Factor multiplied into both 'delta' and 'tau'.

  Returns:
    A new dict with the same entries as `input_delta_state`, except that
    'delta' and 'tau' are scaled by `delta_reduction_factor`.
  """
  scaled_delta = input_delta_state['delta'] * delta_reduction_factor
  current_tau = jnp.array(
      input_delta_state['tau'][...], dtype=jax_utils.get_dtype()
  )
  scaled_tau = current_tau * delta_reduction_factor
  return {**input_delta_state, 'delta': scaled_delta, 'tau': scaled_tau}