# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pydantic config for Torax."""
import copy
import logging
from typing import Any, Mapping
import pydantic
from torax import version
from torax.config import numerics as numerics_lib
from torax.config import plasma_composition as plasma_composition_lib
from torax.config import profile_conditions as profile_conditions_lib
from torax.fvm import enums
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.mhd import pydantic_model as mhd_pydantic_model
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
from torax.sources import pydantic_model as sources_pydantic_model
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.time_step_calculator import pydantic_model as time_step_calculator_pydantic_model
from torax.torax_pydantic import file_restart as file_restart_pydantic_model
from torax.torax_pydantic import torax_pydantic
from torax.transport_model import pydantic_model as transport_model_pydantic_model
import typing_extensions
from typing_extensions import Self
class ToraxConfig(torax_pydantic.BaseModelFrozen):
  """Base config class for Torax.

  Attributes:
    profile_conditions: Config for the profile conditions.
    numerics: Config for the numerics.
    plasma_composition: Config for the plasma composition.
    geometry: Config for the geometry.
    pedestal: Config for the pedestal model. If an empty dictionary is passed
      in, the pedestal model will be set to `no_pedestal`.
    sources: Config for the sources.
    stepper: Config for the stepper. If an empty dictionary is passed in, the
      stepper model will be set to `linear`.
    transport: Config for the transport model. If an empty dictionary is passed
      in, the transport model will be set to `constant`.
    mhd: Config for mhd models. If not provided, a default-constructed `MHD`
      config is used.
    time_step_calculator: Optional config for the time step calculator. If not
      provided the default chi time step calculator is used.
    restart: Optional config for file restart. If None, no file restart is
      performed.
  """

  profile_conditions: profile_conditions_lib.ProfileConditions
  numerics: numerics_lib.Numerics
  plasma_composition: plasma_composition_lib.PlasmaComposition
  geometry: geometry_pydantic_model.Geometry
  sources: sources_pydantic_model.Sources
  stepper: stepper_pydantic_model.StepperConfig = pydantic.Field(
      discriminator='stepper_type'
  )
  transport: transport_model_pydantic_model.TransportConfig = pydantic.Field(
      discriminator='transport_model'
  )
  pedestal: pedestal_pydantic_model.PedestalConfig = pydantic.Field(
      discriminator='pedestal_model'
  )
  mhd: mhd_pydantic_model.MHD = mhd_pydantic_model.MHD()
  time_step_calculator: (
      time_step_calculator_pydantic_model.TimeStepCalculator
  ) = time_step_calculator_pydantic_model.TimeStepCalculator()
  restart: file_restart_pydantic_model.FileRestart | None = pydantic.Field(
      default=None
  )

  @pydantic.model_validator(mode='before')
  @classmethod
  def _unpack_runtime_params(cls, data: Any) -> Any:
    """Flattens a legacy `runtime_params` entry into top-level fields.

    Args:
      data: Raw input passed to the model for validation.

    Returns:
      `data` unchanged if it is not a dict or has no `runtime_params` key.
      Otherwise a deep copy in which `runtime_params` has been removed and its
      `profile_conditions`, `numerics` and `plasma_composition` entries (empty
      dicts when absent) are placed at the top level.
    """
    # TODO(b/401187494): Remove this once the test configs are updated.
    if not isinstance(data, dict) or 'runtime_params' not in data:
      return data
    # Work on a copy so the caller's dictionary is never mutated.
    new_data = copy.deepcopy(data)
    runtime_params = new_data.pop('runtime_params')
    for key in ('profile_conditions', 'numerics', 'plasma_composition'):
      new_data[key] = runtime_params.get(key, {})
    return new_data

  @pydantic.model_validator(mode='before')
  @classmethod
  def _defaults(cls, data: Any) -> Any:
    """Fills in the default discriminator for pedestal, transport and stepper.

    Only dict-valued entries that lack their discriminator key are touched.
    Missing entries and non-dict entries (e.g. already constructed sub-models)
    are passed through untouched, so that regular field validation reports any
    problem instead of this hook failing with a `KeyError` or `TypeError`.

    Args:
      data: Raw input passed to the model for validation.

    Returns:
      `data` unchanged if it is not a dict, otherwise a deep copy with the
      default discriminator values inserted where they were absent.
    """
    if not isinstance(data, dict):
      return data
    # Work on a copy so the caller's dictionary is never mutated.
    configurable_data = copy.deepcopy(data)
    for field, discriminator, default in (
        ('pedestal', 'pedestal_model', 'no_pedestal'),
        ('transport', 'transport_model', 'constant'),
        ('stepper', 'stepper_type', 'linear'),
    ):
      sub_config = configurable_data.get(field)
      if isinstance(sub_config, dict) and discriminator not in sub_config:
        sub_config[discriminator] = default
    return configurable_data

  @pydantic.model_validator(mode='after')
  def _check_fields(self) -> Self:
    """Warns when `use_pereverzev` is off in a setup where it is recommended.

    The warning is logged when a nonlinear transport model (`qlknn` or `CGM`)
    is combined with a linear solver, used either directly or as the initial
    guess of a nonlinear solver, while `use_pereverzev` is False.

    Returns:
      The validated model, unchanged.
    """
    using_nonlinear_transport_model = self.transport.transport_model in (
        'qlknn',
        'CGM',
    )
    using_linear_solver = isinstance(
        self.stepper, stepper_pydantic_model.LinearThetaMethod
    )
    # `initial_guess_mode` is only read for steppers that are not the linear
    # theta method; the short-circuit keeps that access guarded.
    initial_guess_mode_is_linear = (
        not using_linear_solver
        and self.stepper.initial_guess_mode == enums.InitialGuessMode.LINEAR
    )
    if (
        using_nonlinear_transport_model
        and (using_linear_solver or initial_guess_mode_is_linear)
        and not self.stepper.use_pereverzev
    ):
      logging.warning(
          '\n'
          'use_pereverzev=False in a configuration where setting\n'
          'use_pereverzev=True is recommended.\n'
          'A nonlinear transport model is used. However, a linear solver is'
          ' also\n'
          'being used, either directly, or to provide an initial guess for a\n'
          'nonlinear solver.\n'
          'With this configuration, it is strongly recommended to set\n'
          'use_pereverzev=True to avoid numerical instability in the solver.\n'
      )
    return self

  def update_fields(self, x: Mapping[str, Any]) -> None:
    """Safely update fields in the config.

    This works with Frozen models.

    This method will invalidate all `functools.cached_property` caches of
    all ancestral models in the nested tree, as these could have a dependency
    on the updated model. In addition, these nodes will be re-validated.

    Args:
      x: A dictionary whose key is a path `'some.path.to.field_name'` and the
        `value` is the new value for `field_name`. The path can be dictionary
        keys or attribute names, but `field_name` must be an attribute of a
        Pydantic model.

    Raises:
      ValueError: all submodels must be unique object instances. A `ValueError`
        will be raised if this is not the case.
    """
    self._update_fields(x)
    # Read the mesh only after the update so it reflects any geometry change.
    mesh = self.geometry.build_provider.torax_mesh
    if _is_nrho_updated(x):
      # Clear the cached properties of all submodels, as the n_rho may have
      # changed. Also force the grid to be set, as the grid is dependent on the
      # n_rho.
      for model in self.submodels:
        model.clear_cached_properties()
      torax_pydantic.set_grid(self, mesh, mode='force')
    else:
      # Update the grid on any new models which are added and have not had their
      # grid set yet.
      torax_pydantic.set_grid(self, mesh, mode='relaxed')

  @pydantic.model_validator(mode='after')
  def _set_grid(self) -> Self:
    """Sets the geometry mesh as the grid on this config after validation."""
    # Interpolated `TimeVaryingArray` objects require a mesh, only available
    # once the geometry provider is built. This could be done in the before
    # validator, but is harder than setting it after construction.
    mesh = self.geometry.build_provider.torax_mesh
    # Note that the grid could already be set, eg. if the config is serialized
    # and deserialized. In this case, we do not want to overwrite it nor fail
    # when trying to set it, which is why mode='relaxed'.
    torax_pydantic.set_grid(self, mesh, mode='relaxed')
    return self

  # This is primarily used for serialization, so the importer can check which
  # version of Torax was used to generate the serialized config.
  @pydantic.computed_field
  @property
  def torax_version(self) -> str:
    """The Torax version string, taken from `version.TORAX_VERSION`."""
    return version.TORAX_VERSION

  @pydantic.model_validator(mode='before')
  @classmethod
  def _remove_version_field(cls, data: Any) -> Any:
    """Drops a `torax_version` key from dict input before validation.

    Args:
      data: Raw input passed to the model for validation.

    Returns:
      A new dict without the `torax_version` key when `data` is a dict that
      contains it, otherwise `data` unchanged.
    """
    if isinstance(data, dict) and 'torax_version' in data:
      # Build a new dict rather than popping, so the input is not mutated.
      return {k: v for k, v in data.items() if k != 'torax_version'}
    return data
def _is_nrho_updated(x: Mapping[str, Any]) -> bool:
for path in x.keys():
chunks = path.split('.')
if chunks[-1] == 'n_rho' and chunks[0] == 'geometry':
return True
return False