Source code for torax.transport_model.pydantic_model

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pydantic config for Transport models."""

import copy
import dataclasses
import os
from typing import Any, Final, Literal, Union

import chex
import pydantic
from torax.torax_pydantic import interpolated_param_1d
from torax.torax_pydantic import torax_pydantic
from torax.transport_model import bohm_gyrobohm
from torax.transport_model import constant
from torax.transport_model import critical_gradient
from torax.transport_model import pydantic_model_base
from torax.transport_model import qlknn_10d
from torax.transport_model import qlknn_transport_model


# Name of the environment variable consulted for the QLKNN model path when the
# config does not provide one (read by QLKNNTransportModel._conform_data).
_MODEL_PATH_ENV_VAR: Final[str] = 'TORAX_QLKNN_MODEL_PATH'


# pylint: disable=invalid-name
class QLKNNTransportModel(pydantic_model_base.TransportBase):
  """Model for the QLKNN transport model.

  To determine which model to load, TORAX uses the following logic:

  * If `model_path` is provided (non-empty), then we load the model from this
    path.
  * Otherwise, if the `TORAX_QLKNN_MODEL_PATH` environment variable is set,
    then we load the model from this path.
  * Otherwise, if `model_name` is provided, we load that model from registered
    models in the `fusion_surrogates` library.
  * If `model_name` is not set either, we load the default QLKNN model from
    `fusion_surrogates` (currently `QLKNN_7_11`).

  It is recommended to not set `model_name`, `TORAX_QLKNN_MODEL_PATH` or
  `model_path` to use the default QLKNN model.

  After validation, `model_path` holds the path that was actually resolved
  (possibly taken from the environment variable, possibly empty) and
  `model_name` holds the name reported by the loaded model. As a result,
  `build_transport_model` loads the same model that was validated.

  Attributes:
    transport_model: The transport model to use. Hardcoded to 'qlknn'.
    model_path: Path to the model. Takes precedence over `model_name` and
      `TORAX_QLKNN_MODEL_PATH`.
    model_name: Name of the model to use. Used to select a model from the
      `fusion_surrogates` library.
    include_ITG: Whether to include ITG modes.
    include_TEM: Whether to include TEM modes.
    include_ETG: Whether to include ETG modes.
    ITG_flux_ratio_correction: Correction factor for ITG electron heat flux.
    ETG_correction_factor: Correction factor for ETG electron heat flux.
      https://gitlab.com/qualikiz-group/QuaLiKiz/-/commit/5bcd3161c1b08e0272ab3c9412fec7f9345a2eef
    clip_inputs: Whether to clip inputs within desired margin of the QLKNN
      training set boundaries.
    clip_margin: Margin to clip inputs within desired margin of the QLKNN
      training set boundaries.
    coll_mult: Collisionality multiplier.
    avoid_big_negative_s: Ensure that smag - alpha > -0.2 always, to
      compensate for no slab modes.
    smag_alpha_correction: Reduce magnetic shear by 0.5*alpha to capture main
      impact of alpha.
    q_sawtooth_proxy: If q < 1, modify input q and smag as if q~1, i.e. as if
      there are sawteeth.
    DVeff: Effective D / effective V approach for particle transport.
    An_min: Minimum |R/Lne| below which effective V is used instead of
      effective D.
  """

  transport_model: Literal['qlknn'] = 'qlknn'
  model_path: str = ''
  model_name: str = ''
  include_ITG: bool = True
  include_TEM: bool = True
  include_ETG: bool = True
  ITG_flux_ratio_correction: float = 1.0
  ETG_correction_factor: float = 1.0 / 3.0
  clip_inputs: bool = False
  clip_margin: float = 0.95
  coll_mult: float = 1.0
  avoid_big_negative_s: bool = True
  smag_alpha_correction: bool = True
  q_sawtooth_proxy: bool = True
  DVeff: bool = False
  An_min: pydantic.PositiveFloat = 0.05

  @pydantic.model_validator(mode='before')
  @classmethod
  def _conform_data(cls, data: Any) -> Any:
    """Resolves the QLKNN model and fills in model-dependent defaults.

    Args:
      data: Raw input to the model. Only dicts are conformed. Anything else
        (e.g. an already-constructed instance) is returned unchanged for
        pydantic to handle.

    Returns:
      A deep copy of `data` with `model_path` set to the resolved path,
      `model_name` set to the loaded model's name, and defaults for keys the
      user did not set: `coll_mult` and `ITG_flux_ratio_correction` for the
      QLKNN10D model, `smoothing_sigma` for any other model.
    """
    if not isinstance(data, dict):
      return data
    data = copy.deepcopy(data)

    # An explicit, non-empty `model_path` wins; otherwise fall back to the
    # environment variable. Write the resolved path back into the config so
    # that `build_transport_model` loads the same model as is loaded here.
    model_path = data.get('model_path') or os.environ.get(
        _MODEL_PATH_ENV_VAR, ''
    )
    data['model_path'] = model_path
    model = qlknn_transport_model.get_model(
        path=model_path, name=data.get('model_name', '')
    )
    # Update name from the loaded model.
    data['model_name'] = model.name

    if data['model_name'] == qlknn_10d.QLKNN10D_NAME:
      # Correction factor to a more recent QLK collision operator.
      data.setdefault('coll_mult', 0.25)
      # The QLK version this specific QLKNN was trained on tends to
      # underpredict ITG electron heat flux in shaped, high-beta scenarios.
      data.setdefault('ITG_flux_ratio_correction', 2.0)
    else:
      # NOTE(review): `smoothing_sigma` is presumably a TransportBase field
      # (not visible in this file) — confirm.
      data.setdefault('smoothing_sigma', 0.1)
    return data

  def build_transport_model(self) -> qlknn_transport_model.QLKNNTransportModel:
    """Builds the QLKNN transport model from the validated path and name."""
    return qlknn_transport_model.QLKNNTransportModel(
        path=self.model_path, name=self.model_name
    )

  def build_dynamic_params(
      self, t: chex.Numeric
  ) -> qlknn_transport_model.DynamicRuntimeParams:
    """Returns the QLKNN dynamic runtime params at time `t`.

    The base-class dynamic params are forwarded unchanged, together with the
    QLKNN-specific settings of this config.
    """
    base_kwargs = dataclasses.asdict(super().build_dynamic_params(t))
    return qlknn_transport_model.DynamicRuntimeParams(
        include_ITG=self.include_ITG,
        include_TEM=self.include_TEM,
        include_ETG=self.include_ETG,
        ITG_flux_ratio_correction=self.ITG_flux_ratio_correction,
        ETG_correction_factor=self.ETG_correction_factor,
        clip_inputs=self.clip_inputs,
        clip_margin=self.clip_margin,
        coll_mult=self.coll_mult,
        avoid_big_negative_s=self.avoid_big_negative_s,
        smag_alpha_correction=self.smag_alpha_correction,
        q_sawtooth_proxy=self.q_sawtooth_proxy,
        DVeff=self.DVeff,
        An_min=self.An_min,
        **base_kwargs,
    )
class ConstantTransportModel(pydantic_model_base.TransportBase):
  """Model for the Constant transport model.

  Attributes:
    transport_model: The transport model to use. Hardcoded to 'constant'.
    chii_const: coefficient in ion heat equation diffusion term in m^2/s.
    chie_const: coefficient in electron heat equation diffusion term in m^2/s.
    De_const: diffusion coefficient in electron density equation in m^2/s.
    Ve_const: convection coefficient in electron density equation in m/s.
  """

  transport_model: Literal['constant'] = 'constant'
  chii_const: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  chie_const: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  De_const: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  # Not constrained to be positive: the default convection is inward (< 0).
  Ve_const: interpolated_param_1d.TimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(-0.33)
  )

  def build_transport_model(self) -> constant.ConstantTransportModel:
    """Builds the constant transport model (it takes no constructor args)."""
    return constant.ConstantTransportModel()

  def build_dynamic_params(
      self, t: chex.Numeric
  ) -> constant.DynamicRuntimeParams:
    """Returns the constant-model dynamic runtime params at time `t`.

    Each time-varying coefficient is sampled at `t`, and the base-class
    dynamic params are forwarded unchanged.
    """
    base_kwargs = dataclasses.asdict(super().build_dynamic_params(t))
    return constant.DynamicRuntimeParams(
        chii_const=self.chii_const.get_value(t),
        chie_const=self.chie_const.get_value(t),
        De_const=self.De_const.get_value(t),
        Ve_const=self.Ve_const.get_value(t),
        **base_kwargs,
    )
class CriticalGradientTransportModel(pydantic_model_base.TransportBase):
  """Pydantic config for the Critical Gradient Model (CGM) of transport.

  Attributes:
    transport_model: Discriminator for this config; always 'CGM'.
    alpha: Power-law exponent in chi ∝ (R/LTi - R/LTi_crit)^alpha.
    chistiff: Stiffness parameter.
    chiei_ratio: Ratio of electron to ion heat transport coefficient (ion
      higher for ITG).
    chi_D_ratio: Ratio of the electron particle transport coefficient to the
      ion heat transport coefficient.
    VR_D_ratio: Ratio of (major radius * electron particle convection) to
      electron diffusion. This sets the electron particle convection used by
      the model.
  """

  transport_model: Literal['CGM'] = 'CGM'
  alpha: float = 2.0
  chistiff: float = 2.0
  chiei_ratio: interpolated_param_1d.TimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(2.0)
  )
  chi_D_ratio: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(5.0)
  )
  VR_D_ratio: interpolated_param_1d.TimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(0.0)
  )

  def build_transport_model(
      self,
  ) -> critical_gradient.CriticalGradientTransportModel:
    """Creates a fresh CriticalGradientTransportModel instance."""
    return critical_gradient.CriticalGradientTransportModel()

  def build_dynamic_params(
      self, t: chex.Numeric
  ) -> critical_gradient.DynamicRuntimeParams:
    """Evaluates the CGM runtime params at time `t`.

    The inherited (base-class) params are unpacked first. The CGM settings
    follow, with the time-varying ratios sampled at `t`.
    """
    inherited = super().build_dynamic_params(t)
    return critical_gradient.DynamicRuntimeParams(
        **dataclasses.asdict(inherited),
        alpha=self.alpha,
        chistiff=self.chistiff,
        chiei_ratio=self.chiei_ratio.get_value(t),
        chi_D_ratio=self.chi_D_ratio.get_value(t),
        VR_D_ratio=self.VR_D_ratio.get_value(t),
    )
class BohmGyroBohmTransportModel(pydantic_model_base.TransportBase):
  """Pydantic config for the combined Bohm + Gyro-Bohm transport model.

  Attributes:
    transport_model: Discriminator for this config; always 'bohm-gyrobohm'.
    chi_e_bohm_coeff: Bohm-term prefactor for electron heat conductivity.
    chi_e_gyrobohm_coeff: GyroBohm-term prefactor for electron heat
      conductivity.
    chi_i_bohm_coeff: Bohm-term prefactor for ion heat conductivity.
    chi_i_gyrobohm_coeff: GyroBohm-term prefactor for ion heat conductivity.
    chi_e_bohm_multiplier: Scales chi_e_bohm_coeff; a convenient knob for
      adjusting the default without editing the coefficient itself.
    chi_e_gyrobohm_multiplier: Scales chi_e_gyrobohm_coeff; a convenient knob
      for adjusting the default.
    chi_i_bohm_multiplier: Scales chi_i_bohm_coeff; a convenient knob for
      adjusting the default.
    chi_i_gyrobohm_multiplier: Scales chi_i_gyrobohm_coeff; a convenient knob
      for adjusting the default.
    d_face_c1: First constant of the electron diffusivity weighting factor.
    d_face_c2: Second constant of the electron diffusivity weighting factor.
    v_face_coeff: Proportionality factor between convectivity and
      diffusivity.
  """

  transport_model: Literal['bohm-gyrobohm'] = 'bohm-gyrobohm'
  chi_e_bohm_coeff: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(8e-5)
  )
  chi_e_gyrobohm_coeff: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(5e-6)
  )
  chi_i_bohm_coeff: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(8e-5)
  )
  chi_i_gyrobohm_coeff: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(5e-6)
  )
  chi_e_bohm_multiplier: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  chi_e_gyrobohm_multiplier: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  chi_i_bohm_multiplier: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  chi_i_gyrobohm_multiplier: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  d_face_c1: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  d_face_c2: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(0.3)
  )
  v_face_coeff: interpolated_param_1d.TimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(-0.1)
  )

  def build_transport_model(
      self,
  ) -> bohm_gyrobohm.BohmGyroBohmTransportModel:
    """Creates a fresh BohmGyroBohmTransportModel instance."""
    return bohm_gyrobohm.BohmGyroBohmTransportModel()

  def build_dynamic_params(
      self, t: chex.Numeric
  ) -> bohm_gyrobohm.DynamicRuntimeParams:
    """Samples every time-varying field of this config at time `t`.

    The base-class dynamic params are forwarded alongside the sampled values.
    """
    inherited = dataclasses.asdict(super().build_dynamic_params(t))
    # Kept as a local tuple: an un-annotated class attribute would be
    # interpreted by pydantic rather than treated as a plain constant.
    field_names = (
        'chi_e_bohm_coeff',
        'chi_e_gyrobohm_coeff',
        'chi_i_bohm_coeff',
        'chi_i_gyrobohm_coeff',
        'chi_e_bohm_multiplier',
        'chi_e_gyrobohm_multiplier',
        'chi_i_bohm_multiplier',
        'chi_i_gyrobohm_multiplier',
        'd_face_c1',
        'd_face_c2',
        'v_face_coeff',
    )
    sampled = {name: getattr(self, name).get_value(t) for name in field_names}
    return bohm_gyrobohm.DynamicRuntimeParams(**sampled, **inherited)
# Union of every transport-model config accepted by TORAX. The QuaLiKiz config
# is included only when its module can be imported.
try:
  # pylint: disable=g-import-not-at-top
  from torax.transport_model import qualikiz_transport_model
  # pylint: enable=g-import-not-at-top
except ImportError:
  TransportConfig = Union[
      QLKNNTransportModel,
      ConstantTransportModel,
      CriticalGradientTransportModel,
      BohmGyroBohmTransportModel,
  ]
else:
  TransportConfig = Union[
      QLKNNTransportModel,
      ConstantTransportModel,
      CriticalGradientTransportModel,
      BohmGyroBohmTransportModel,
      qualikiz_transport_model.QualikizTransportModelConfig,
  ]