# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The `optimizer_solve_block` function.
See function docstring for details.
"""
from typing import TypeAlias
import jax
from torax import state
from torax.config import runtime_params_slice
from torax.fvm import block_1d_coeffs
from torax.fvm import calc_coeffs
from torax.fvm import cell_variable
from torax.fvm import enums
from torax.fvm import fvm_conversions
from torax.fvm import residual_and_loss
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles
from torax.stepper import predictor_corrector_method
from torax.transport_model import transport_model as transport_model_lib
# Module-level alias for `block_1d_coeffs.AuxiliaryOutput`.
AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput
def optimizer_solve_block(
    dt: jax.Array,
    static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
    dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
    dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice,
    geo_t: geometry.Geometry,
    geo_t_plus_dt: geometry.Geometry,
    x_old: tuple[cell_variable.CellVariable, ...],
    core_profiles_t: state.CoreProfiles,
    core_profiles_t_plus_dt: state.CoreProfiles,
    transport_model: transport_model_lib.TransportModel,
    explicit_source_profiles: source_profiles.SourceProfiles,
    source_models: source_models_lib.SourceModels,
    pedestal_model: pedestal_model_lib.PedestalModel,
    coeffs_callback: calc_coeffs.CoeffsCallback,
    evolving_names: tuple[str, ...],
    initial_guess_mode: enums.InitialGuessMode,
    maxiter: int,
    tol: float,
) -> tuple[
    tuple[cell_variable.CellVariable, ...],
    state.StepperNumericOutputs,
    block_1d_coeffs.AuxiliaryOutput,
]:
  """Advances the evolving variables by one step via iterative optimization.

  The solver is generic: diffusion, convection, etc. are modelled abstractly
  and the problem-specific physics enters only through the coefficients
  produced by `coeffs_callback`.

  An iterative optimizer minimizes the norm of the residual between the two
  sides of the theta-method update equation.

  Args:
    dt: Discrete time step.
    static_runtime_params_slice: Static runtime parameters; changing them
      triggers recompilation. A key entry is theta_imp, a coefficient in
      [0, 1] selecting the solution method. The equation solved is
      transient_coeff (x_new - x_old) / dt =
      theta_imp F(t_new) + (1 - theta_imp) F(t_old). Named special cases:
      theta_imp = 1 is the Backward Euler implicit method (default),
      theta_imp = 0.5 is Crank-Nicolson, and theta_imp = 0 is the Forward
      Euler explicit method.
    dynamic_runtime_params_slice_t: Runtime params at time t (start of the
      step). These may change from step to step without triggering
      recompilation.
    dynamic_runtime_params_slice_t_plus_dt: Runtime params at time t + dt.
    geo_t: Geometry at time t; passed to `coeffs_callback` when computing the
      start-of-step coefficients.
    geo_t_plus_dt: Geometry at time t + dt; passed to the solver and to the
      final `coeffs_callback` evaluation.
    x_old: Tuple of CellVariables, one per channel, holding the values at the
      start of the time step.
    core_profiles_t: Core plasma profiles containing all available prescribed
      quantities at the start of the time step, including evolving boundary
      conditions and prescribed time-dependent profiles that are not evolved
      by the PDE system.
    core_profiles_t_plus_dt: Same as `core_profiles_t`, but for the end of the
      time step.
    transport_model: Turbulent transport model callable.
    explicit_source_profiles: Pre-calculated sources implemented as explicit
      sources in the PDE.
    source_models: Collection of source callables used to generate source PDE
      coefficients.
    pedestal_model: Model of the pedestal's behavior.
    coeffs_callback: Computes diffusion, convection, etc. coefficients for a
      given core_profiles. Called repeatedly by the iterative optimizer.
    evolving_names: Names of the variables within the core profiles that
      should evolve.
    initial_guess_mode: Selects the optimizer's starting point: `X_OLD` starts
      from `x_old`, `LINEAR` starts from the result of a linear
      predictor-corrector step. For the linear step the explicit coefficients
      are requested with `allow_pereverzev=True`, so Pereverzev-Corrigan terms
      are included when enabled in the runtime params (needed for stiff
      transport models such as qlknn).
    maxiter: See docstring of `jaxopt.LBFGS`.
    tol: See docstring of `jaxopt.LBFGS`. Also the threshold the final loss is
      compared against to set the returned error state.

  Returns:
    x_new: Tuple, with x_new[i] giving channel i of x at the next time step.
    stepper_numeric_outputs: StepperNumericOutputs carrying the number of
      inner solver iterations and an error state (1 if the final loss exceeds
      `tol`, otherwise 0).
    aux_output: The `auxiliary_outputs` of the coefficients evaluated at
      x_new.

  Raises:
    ValueError: If `initial_guess_mode` is neither `LINEAR` nor `X_OLD`.
  """
  # Coefficients at the start of the step (explicit side of the theta method).
  coeffs_old = coeffs_callback(
      dynamic_runtime_params_slice_t,
      geo_t,
      core_profiles_t,
      x_old,
      explicit_call=True,
  )

  # Keyword arguments shared verbatim by the linear predictor-corrector step
  # and the optimizer call below.
  end_of_step_kwargs = dict(
      dt=dt,
      static_runtime_params_slice=static_runtime_params_slice,
      dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
      geo_t_plus_dt=geo_t_plus_dt,
      x_old=x_old,
      core_profiles_t_plus_dt=core_profiles_t_plus_dt,
  )

  if initial_guess_mode == enums.InitialGuessMode.LINEAR:
    # The LINEAR guess comes from the predictor-corrector method (used if
    # predictor_corrector=True in the stepper runtime params).
    # The explicit coefficients here may carry additional Pereverzev terms if
    # set by runtime_params; these are needed when stiff transport models
    # (e.g. qlknn) are used.
    coeffs_exp_with_pereverzev = coeffs_callback(
        dynamic_runtime_params_slice_t,
        geo_t,
        core_profiles_t,
        x_old,
        allow_pereverzev=True,
        explicit_call=True,
    )
    # See linear_theta_method.py for comments on the predictor_corrector API.
    linear_start = tuple(
        core_profiles_t_plus_dt[name] for name in evolving_names
    )
    x_linear, _ = predictor_corrector_method.predictor_corrector_method(
        x_new_guess=linear_start,
        coeffs_exp=coeffs_exp_with_pereverzev,
        coeffs_callback=coeffs_callback,
        **end_of_step_kwargs,
    )
    init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(x_linear)
  elif initial_guess_mode == enums.InitialGuessMode.X_OLD:
    init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(x_old)
  else:
    raise ValueError(
        f'Unknown option for first guess in iterations: {initial_guess_mode}'
    )

  stepper_numeric_outputs = state.StepperNumericOutputs()

  # Advance jaxopt_solver by one timestep.
  solver_result = residual_and_loss.jaxopt_solver(
      init_x_new_vec=init_x_new_vec,
      transport_model=transport_model,
      explicit_source_profiles=explicit_source_profiles,
      source_models=source_models,
      pedestal_model=pedestal_model,
      coeffs_old=coeffs_old,
      evolving_names=evolving_names,
      maxiter=maxiter,
      tol=tol,
      **end_of_step_kwargs,
  )
  x_new_vec, final_loss, _, num_inner_iterations = solver_result
  stepper_numeric_outputs.inner_solver_iterations = num_inner_iterations

  # Rebuild CellVariable instances on top of core_profiles_t_plus_dt, which
  # has updated boundary conditions and prescribed profiles.
  x_new = fvm_conversions.vec_to_cell_variable_tuple(
      x_new_vec, core_profiles_t_plus_dt, evolving_names
  )

  # Report to the caller, via an extra error output, whether x_new brought the
  # loss below the tolerance.
  # NOTE(review): a NaN `final_loss` compares False against `tol` and would
  # therefore be reported as 0 here -- confirm callers detect NaNs separately.
  loss_exceeds_tol = final_loss > tol
  stepper_numeric_outputs.stepper_error_state = jax.lax.cond(
      loss_exceeds_tol,
      lambda: 1,  # Loss still above tolerance.
      lambda: 0,  # Loss at or below tolerance.
  )

  # Evaluate the coefficients at the solution to obtain the auxiliary outputs.
  coeffs_at_solution = coeffs_callback(
      dynamic_runtime_params_slice_t_plus_dt,
      geo_t_plus_dt,
      core_profiles_t_plus_dt,
      x_new,
      allow_pereverzev=True,
  )
  aux_output = coeffs_at_solution.auxiliary_outputs
  return x_new, stepper_numeric_outputs, aux_output