# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pydantic config for Stepper."""
import abc
import functools
from typing import Literal
import pydantic
from torax.fvm import enums
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.stepper import linear_theta_method
from torax.stepper import nonlinear_theta_method
from torax.stepper import runtime_params
from torax.stepper import stepper as stepper_lib
from torax.torax_pydantic import torax_pydantic
from torax.transport_model import transport_model as transport_model_lib
# pylint: disable=invalid-name
class BaseStepper(torax_pydantic.BaseModelFrozen, abc.ABC):
  """Base class for stepper configs.

  Concrete subclasses must provide `build_dynamic_params`, `build_stepper` and
  `linear_solver`; `build_static_params` is shared.

  Attributes:
    theta_imp: The theta value in the theta method: 0 = explicit, 1 = fully
      implicit, 0.5 = Crank-Nicolson.
    predictor_corrector: Enables predictor-corrector iterations with the linear
      solver. If False, compilation is faster.
    corrector_steps: The number of corrector steps for the predictor-corrector
      linear solver. Validated as a positive integer (>= 1), so 0 is rejected;
      to run the linear solver without corrector iterations, set
      `predictor_corrector=False` instead.
    convection_dirichlet_mode: See `fvm.convection_terms` docstring,
      `dirichlet_mode` argument.
    convection_neumann_mode: See `fvm.convection_terms` docstring,
      `neumann_mode` argument.
    use_pereverzev: Use Pereverzev terms for the linear solver. In the
      nonlinear solvers it is only applied to the optional initial guess
      computed with the linear solver.
    chi_per: (Deliberately) large heat conductivity for the Pereverzev rule.
    d_per: (Deliberately) large particle diffusion for the Pereverzev rule.
  """

  theta_imp: torax_pydantic.UnitInterval = 1.0
  predictor_corrector: bool = False
  corrector_steps: pydantic.PositiveInt = 1
  convection_dirichlet_mode: Literal['ghost', 'direct', 'semi-implicit'] = (
      'ghost'
  )
  convection_neumann_mode: Literal['ghost', 'semi-implicit'] = 'ghost'
  use_pereverzev: bool = False
  chi_per: pydantic.PositiveFloat = 20.0
  d_per: pydantic.NonNegativeFloat = 10.0

  @property
  @abc.abstractmethod
  def build_dynamic_params(self) -> runtime_params.DynamicRuntimeParams:
    """Dynamic runtime params built from this config.

    Note this is accessed as a property (no call parentheses).
    """

  def build_static_params(self) -> runtime_params.StaticRuntimeParams:
    """Builds static runtime params from the config.

    Returns:
      A `StaticRuntimeParams` carrying `theta_imp`, the convection Dirichlet
      and Neumann modes, `use_pereverzev` and `predictor_corrector`.
    """
    return runtime_params.StaticRuntimeParams(
        theta_imp=self.theta_imp,
        convection_dirichlet_mode=self.convection_dirichlet_mode,
        convection_neumann_mode=self.convection_neumann_mode,
        use_pereverzev=self.use_pereverzev,
        predictor_corrector=self.predictor_corrector,
    )

  @abc.abstractmethod
  def build_stepper(
      self,
      transport_model: transport_model_lib.TransportModel,
      source_models: source_models_lib.SourceModels,
      pedestal_model: pedestal_model_lib.PedestalModel,
  ) -> stepper_lib.Stepper:
    """Builds a stepper from the config.

    Args:
      transport_model: Transport model handed to the stepper.
      source_models: Source models handed to the stepper.
      pedestal_model: Pedestal model handed to the stepper.

    Returns:
      The stepper instance described by this config.
    """

  @property
  @abc.abstractmethod
  def linear_solver(self) -> bool:
    """Whether this stepper uses the linear solver.

    For nonlinear steppers this may be True when the linear solver is used to
    compute the initial guess.
    """
class LinearThetaMethod(BaseStepper):
  """Model for the linear stepper.

  Attributes:
    stepper_type: The type of stepper to use, hardcoded to 'linear'.
  """

  stepper_type: Literal['linear'] = 'linear'

  @functools.cached_property
  def build_dynamic_params(self) -> runtime_params.DynamicRuntimeParams:
    """Dynamic runtime params for the linear stepper.

    Built on first access and then cached on the (frozen) config instance.
    Carries the Pereverzev coefficients and the corrector step count.
    """
    return runtime_params.DynamicRuntimeParams(
        chi_per=self.chi_per,
        d_per=self.d_per,
        corrector_steps=self.corrector_steps,
    )

  def build_stepper(
      self,
      transport_model: transport_model_lib.TransportModel,
      source_models: source_models_lib.SourceModels,
      pedestal_model: pedestal_model_lib.PedestalModel,
  ) -> stepper_lib.Stepper:
    """Builds a `linear_theta_method.LinearThetaMethod` stepper.

    Args:
      transport_model: Transport model handed to the stepper.
      source_models: Source models handed to the stepper.
      pedestal_model: Pedestal model handed to the stepper.

    Returns:
      A new `linear_theta_method.LinearThetaMethod` instance.
    """
    return linear_theta_method.LinearThetaMethod(
        transport_model=transport_model,
        source_models=source_models,
        pedestal_model=pedestal_model,
    )

  @property
  def linear_solver(self) -> bool:
    """Always True: this config selects the linear stepper."""
    return True
class NewtonRaphsonThetaMethod(BaseStepper):
  """Model for nonlinear Newton-Raphson stepper.

  Attributes:
    stepper_type: The type of stepper to use, hardcoded to 'newton_raphson'.
    log_iterations: If True, log internal iterations in Newton-Raphson solver.
    initial_guess_mode: The initial guess mode for the Newton-Raphson solver.
    maxiter: The maximum number of iterations for the Newton-Raphson solver.
    tol: The tolerance for the Newton-Raphson solver.
    coarse_tol: The coarse tolerance for the Newton-Raphson solver.
    delta_reduction_factor: The delta reduction factor for the Newton-Raphson
      solver.
    tau_min: The minimum value of tau for the Newton-Raphson solver.
  """

  stepper_type: Literal['newton_raphson'] = 'newton_raphson'
  log_iterations: bool = False
  initial_guess_mode: enums.InitialGuessMode = enums.InitialGuessMode.LINEAR
  maxiter: pydantic.NonNegativeInt = 30
  tol: float = 1e-5
  coarse_tol: float = 1e-2
  delta_reduction_factor: float = 0.5
  tau_min: float = 0.01

  @property
  def linear_solver(self) -> bool:
    """True when the initial guess mode is `InitialGuessMode.LINEAR`."""
    return self.initial_guess_mode == enums.InitialGuessMode.LINEAR

  @functools.cached_property
  def build_dynamic_params(
      self,
  ) -> nonlinear_theta_method.DynamicNewtonRaphsonRuntimeParams:
    """Dynamic runtime params for the Newton-Raphson stepper.

    Built on first access and then cached on the (frozen) config instance.
    The initial guess mode is passed as its enum `.value`.
    """
    return nonlinear_theta_method.DynamicNewtonRaphsonRuntimeParams(
        chi_per=self.chi_per,
        d_per=self.d_per,
        log_iterations=self.log_iterations,
        initial_guess_mode=self.initial_guess_mode.value,
        maxiter=self.maxiter,
        tol=self.tol,
        coarse_tol=self.coarse_tol,
        corrector_steps=self.corrector_steps,
        delta_reduction_factor=self.delta_reduction_factor,
        tau_min=self.tau_min,
    )

  def build_stepper(
      self,
      transport_model: transport_model_lib.TransportModel,
      source_models: source_models_lib.SourceModels,
      pedestal_model: pedestal_model_lib.PedestalModel,
  ) -> nonlinear_theta_method.NewtonRaphsonThetaMethod:
    """Builds a `nonlinear_theta_method.NewtonRaphsonThetaMethod` stepper.

    Args:
      transport_model: Transport model handed to the stepper.
      source_models: Source models handed to the stepper.
      pedestal_model: Pedestal model handed to the stepper.

    Returns:
      A new `nonlinear_theta_method.NewtonRaphsonThetaMethod` instance.
    """
    return nonlinear_theta_method.NewtonRaphsonThetaMethod(
        transport_model=transport_model,
        source_models=source_models,
        pedestal_model=pedestal_model,
    )
class OptimizerThetaMethod(BaseStepper):
  """Model for nonlinear OptimizerThetaMethod stepper.

  Attributes:
    stepper_type: The type of stepper to use, hardcoded to 'optimizer'.
    initial_guess_mode: The initial guess mode for the optimizer.
    maxiter: The maximum number of iterations for the optimizer.
    tol: The tolerance for the optimizer.
  """

  stepper_type: Literal['optimizer'] = 'optimizer'
  initial_guess_mode: enums.InitialGuessMode = enums.InitialGuessMode.LINEAR
  maxiter: pydantic.NonNegativeInt = 100
  tol: float = 1e-12

  @property
  def linear_solver(self) -> bool:
    """True when the initial guess mode is `InitialGuessMode.LINEAR`."""
    return self.initial_guess_mode == enums.InitialGuessMode.LINEAR

  @functools.cached_property
  def build_dynamic_params(
      self,
  ) -> nonlinear_theta_method.DynamicOptimizerRuntimeParams:
    """Dynamic runtime params for the optimizer stepper.

    Built on first access and then cached on the (frozen) config instance.
    The initial guess mode is passed as its enum `.value`.
    """
    return nonlinear_theta_method.DynamicOptimizerRuntimeParams(
        chi_per=self.chi_per,
        d_per=self.d_per,
        initial_guess_mode=self.initial_guess_mode.value,
        maxiter=self.maxiter,
        tol=self.tol,
        corrector_steps=self.corrector_steps,
    )

  def build_stepper(
      self,
      transport_model: transport_model_lib.TransportModel,
      source_models: source_models_lib.SourceModels,
      pedestal_model: pedestal_model_lib.PedestalModel,
  ) -> nonlinear_theta_method.OptimizerThetaMethod:
    """Builds a `nonlinear_theta_method.OptimizerThetaMethod` stepper.

    Args:
      transport_model: Transport model handed to the stepper.
      source_models: Source models handed to the stepper.
      pedestal_model: Pedestal model handed to the stepper.

    Returns:
      A new `nonlinear_theta_method.OptimizerThetaMethod` instance.
    """
    return nonlinear_theta_method.OptimizerThetaMethod(
        transport_model=transport_model,
        source_models=source_models,
        pedestal_model=pedestal_model,
    )
# Union of all concrete stepper configs. Each member pins a distinct
# `stepper_type` literal. The member order is kept as-is, since pydantic may
# use it when validating a union.
StepperConfig = (
    LinearThetaMethod
    | NewtonRaphsonThetaMethod
    | OptimizerThetaMethod
)