# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Surrogate model for ion-cyclotron resonance heating (ICRH) model."""
import dataclasses
import functools
import json
import logging
import os
from typing import Any, ClassVar, Final, Literal, Sequence
import chex
import flax.linen as nn
import jax
from jax import numpy as jnp
import jaxtyping as jt
import pydantic
from torax import array_typing
from torax import jax_utils
from torax import math_utils
from torax import state
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.physics import collisions
from torax.sources import base
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_profiles
from torax.torax_pydantic import torax_pydantic
import typing_extensions
# Internal import.
# Default value for the model function to be used for the ion cyclotron
# source. This is also used as an identifier for the model function in
# the default source config for Pydantic to "discriminate" against.
DEFAULT_MODEL_FUNCTION_NAME: str = 'icrh_model_func'
# Environment variable for the TORIC NN model. Used if the model path
# is not set in the config.
_MODEL_PATH_ENV_VAR: Final[str] = 'TORIC_NN_MODEL_PATH'
# If no path is set in either the config or the environment variable, use
# this path.
_DEFAULT_MODEL_PATH = '~/toric_surrogate/TORIC_MLP_v1/toricnn.json'
_TORIC_GRID_SIZE = 297
_HELIUM3_ID = 'He3'
_TRITIUM_SECOND_HARMONIC_ID = '2T'
_ELECTRON_ID = 'e'
def _get_default_model_path() -> str:
return os.environ.get(_MODEL_PATH_ENV_VAR, _DEFAULT_MODEL_PATH)
def _from_json(json_file) -> dict[str, Any]:
"""Load the model config and weights from a JSON file."""
if not os.path.exists(json_file):
raise FileNotFoundError(f'Model file {json_file} does not exist.')
with open(json_file) as file_:
model_dict = json.load(file_)
return model_dict
# pylint: disable=invalid-name
# Many of the variables below are named to match the physics quantities
# as defined by the TORIC ICRF solver, so we keep their naming for consistency.
@chex.dataclass(frozen=True)
class ToricNNOutputs:
  """Outputs from the ToricNN model.

  Each field is a power deposition profile on the ToricNN radial grid,
  expressed per MW of absorbed power.
  """

  # Power deposition on helium-3 in MW/m^3/MW_{abs}.
  power_deposition_He3: array_typing.ArrayFloat
  # Power deposition on tritium (second harmonic) in MW/m^3/MW_{abs}.
  power_deposition_2T: array_typing.ArrayFloat
  # Power deposition on electrons in MW/m^3/MW_{abs}.
  power_deposition_e: array_typing.ArrayFloat
class _ToricNN(nn.Module):
  """Surrogate heating model trained on TORIC ICRF solver simulation.

  This model takes input parameters from the `ToricNNInputs` class and outputs
  power deposition profiles for helium-3, tritium (second harmonic) and
  electrons on a radial grid.

  This Flax module is not intended to be used directly but rather through the
  `ToricNNWrapper` class.

  The modelling approach is described in:
  https://iopscience.iop.org/article/10.1088/1741-4326/ad645d/pdf. The model
  is trained on regression outputs from the TORIC ICRF solver. PCA is applied
  to the outputs of the solver to reduce the dimensionality of the model.

  The structure of the model consists of:
  - Scaling and normalisation of the input parameters.
  - An MLP transforming the scaled inputs.
  - A projection back to true values using the PCA coefficients.

  The parameter names used here (`scaler_mean`, `scaler_scale`,
  `pca_components`, `pca_mean` and the auto-named `Dense_0` ...
  `Dense_{len(hidden_sizes)}`) must match the keys of the parameter tree built
  by `ToricNNWrapper._load_params`.
  """

  # Hidden layer sizes for the MLP.
  hidden_sizes: Sequence[int]
  # Number of PCA coefficients used by ToricNN.
  pca_coeffs: int
  # Input dimensionality of the ToricNN model.
  input_dim: int
  # Number of radial nodes in output of the ToricNN model.
  radial_nodes: int

  def setup(self):
    """Setup the parameters of the ToricNN model.

    The `jax.random.normal` initialisers only matter when the module is
    initialised from scratch; `ToricNNWrapper` supplies the actual values
    loaded from the JSON model file when calling `apply`.
    """
    # Per-feature input shift: inputs are mapped to
    # (x - scaler_mean) / scaler_scale.
    self.scaler_mean = self.param(
        'scaler_mean',
        jax.random.normal,
        (self.input_dim,),
    )
    # Per-feature input scale.
    self.scaler_scale = self.param(
        'scaler_scale',
        jax.random.normal,
        (self.input_dim,),
    )
    # PCA basis mapping `pca_coeffs` coefficients to `radial_nodes` outputs.
    self.pca_components = self.param(
        'pca_components',
        jax.random.normal,
        (
            self.pca_coeffs,
            self.radial_nodes,
        ),
    )
    # Mean profile added back after the PCA projection.
    self.pca_mean = self.param(
        'pca_mean',
        jax.random.normal,
        (self.radial_nodes,),
    )

  @nn.compact
  def __call__(
      self,
      x: jt.Float32[jt.Array, 'B* {self.input_dim}'],
  ) -> jt.Float32[jt.Array, 'B* {self.radial_nodes}']:
    """Run a forward pass of the ToricNN model.

    Args:
      x: Unscaled input features whose trailing axis has size `input_dim`.

    Returns:
      Non-negative power deposition profile(s) with trailing axis of size
      `radial_nodes`.
    """
    # Scale and normalise inputs.
    x = (x - self.scaler_mean) / self.scaler_scale
    # MLP. Dense layers created inline are auto-named Dense_0, Dense_1, ... in
    # creation order, which is what the loaded parameter tree relies on.
    for hidden_size in self.hidden_sizes:
      x = nn.Dense(
          hidden_size,
      )(x)
      x = nn.relu(x)
    # Output layer: predicts the PCA coefficients of the profile.
    x = nn.Dense(
        self.pca_coeffs,
    )(x)
    x = x @ self.pca_components + self.pca_mean  # Project back to true values.
    x = x * (x > 0)  # Eliminate non-physical values for power deposition.
    return x
class ToricNNWrapper:
  """Wrapper for the ToricNN model.

  This wrapper is currently for a SPARC-specific ion cyclotron resonance
  heating scheme.

  TODO(b/378072116): Make the wrapper more general to work with other ICRH
  schemes and surrogate models.

  This wrapper is the main interface for interacting with the ToricNN model
  when making predictions of heating power deposition profiles given
  `ToricNNInputs`.

  A single `_ToricNN` network definition is shared by three parameter sets,
  one for each simulated output (helium-3, 2nd-harmonic tritium and
  electrons).

  Two wrappers compare equal (and hash equally) iff they were loaded from the
  same path, so they can be used as JAX static arguments and cache keys.
  """

  def __init__(self, path: str | None = None):
    """Loads the ToricNN model.

    Args:
      path: Path to the JSON model file. If None, the path is taken from the
        `TORIC_NN_MODEL_PATH` environment variable or the built-in default.
    """
    if path is None:
      path = _get_default_model_path()
    self._path = path
    logging.info('Loading ToricNN model from %s', path)
    self.model_config = _from_json(path)
    self.power_deposition_network = self._load_network()
    self.power_deposition_He3_params = self._load_params(_HELIUM3_ID)
    self.power_deposition_2T_params = self._load_params(
        _TRITIUM_SECOND_HARMONIC_ID
    )
    self.power_deposition_e_params = self._load_params(_ELECTRON_ID)
    logging.info('Loaded ToricNN model from %s', path)

  def _load_network(self) -> _ToricNN:
    """Builds the (parameter-free) `_ToricNN` module from the model config."""
    return _ToricNN(
        hidden_sizes=self.model_config['hidden_sizes'],
        pca_coeffs=self.model_config['pca_coeffs'],
        input_dim=self.model_config['input_dim'],
        radial_nodes=self.model_config['radial_nodes'],
    )

  def _load_params(self, network_name: str) -> dict[str, Any]:
    """Returns the Flax parameter tree for one ToricNN network.

    The JSON stores each network's weights as nested lists; they are converted
    to JAX arrays here. The conversion is done in place, so
    `self.model_config[network_name]` ends up holding the arrays and is shared
    with the returned tree.

    Args:
      network_name: Key of the network in the model config (e.g. 'He3').

    Returns:
      A `{'params': ...}` dict suitable for `_ToricNN.apply`.
    """
    network_params = self.model_config[network_name]
    # One Dense layer per hidden size, plus the output layer.
    num_dense_layers = len(self.model_config['hidden_sizes']) + 1
    for i in range(num_dense_layers):
      layer = network_params[f'Dense_{i}']
      layer['kernel'] = jnp.array(layer['kernel'])
      layer['bias'] = jnp.array(layer['bias'])
    for name in ('pca_components', 'pca_mean', 'scaler_mean', 'scaler_scale'):
      network_params[name] = jnp.array(network_params[name])
    return {'params': network_params}

  def __hash__(self) -> int:
    return hash(self._path)

  def __eq__(self, other: object) -> bool:
    # Must agree with __hash__: objects that compare equal must hash equally.
    # Wrappers loaded from different files hold different weights and must not
    # be interchangeable as jit static arguments or lru_cache keys.
    return isinstance(other, ToricNNWrapper) and self._path == other._path
@functools.partial(jax_utils.jit, static_argnames='toric_nn')
def _toric_nn_predict(
    toric_nn: ToricNNWrapper,
    inputs: ToricNNInputs,
) -> ToricNNOutputs:
  """Runs the three ToricNN networks on one set of scalar inputs.

  `toric_nn` is a static (hashable) argument, so a separate compilation is
  kept per distinct wrapper.

  NOTE(review): `ToricNNInputs` is referenced here but its definition is not
  visible in this file as shown; confirm it is defined before this function.

  Args:
    toric_nn: Loaded ToricNN wrapper holding the network and its parameters.
    inputs: Scalar plasma/antenna parameters fed to the surrogate.

  Returns:
    The predicted He3, 2nd-harmonic T and electron deposition profiles.
  """
  # Feature order must match the per-feature scaler_mean / scaler_scale of the
  # network, i.e. the order used when the surrogate was trained.
  feature_names = (
      'frequency',
      'volume_average_temperature',
      'volume_average_density',
      'minority_concentration',
      'gap_inner',
      'gap_outer',
      'z0',
      'temperature_peaking_factor',
      'density_peaking_factor',
      'B0',
  )
  features = jnp.array(
      [getattr(inputs, name) for name in feature_names],
      dtype=jax_utils.get_dtype(),
  )
  network = toric_nn.power_deposition_network

  def run_network(network_params):
    return network.apply(network_params, features)

  # Keyword arguments are evaluated left to right: He3, then 2T, then e.
  return ToricNNOutputs(
      power_deposition_He3=run_network(toric_nn.power_deposition_He3_params),
      power_deposition_2T=run_network(toric_nn.power_deposition_2T_params),
      power_deposition_e=run_network(toric_nn.power_deposition_e_params),
  )
@chex.dataclass(frozen=True)
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
  """Time-dependent runtime parameters for the ion cyclotron source."""

  # ICRF wave frequency [Hz].
  frequency: array_typing.ScalarFloat
  # He3 minority concentration relative to the electron density [%].
  minority_concentration: array_typing.ScalarFloat
  # Total heating power [W].
  Ptot: array_typing.ScalarFloat
  # Fraction of Ptot that is absorbed.
  absorption_fraction: array_typing.ScalarFloat
  # Inner radial location of the first wall at the plasma midplane [m].
  wall_inner: float
  # Outer radial location of the first wall at the plasma midplane [m].
  wall_outer: float
def _helium3_tail_temperature(
    power_deposition_he3: jax.Array,
    core_profiles: state.CoreProfiles,
    minority_concentration: float,
    Ptot: float,
) -> jax.Array:
  """Estimates the He3 tail temperature from a "Stix distribution".

  Uses the analytic Fokker-Planck solution of [Stix, Nuc. Fus. 1975] for the
  non-thermal minority distribution:

    T_tail = T_e * (1 + xi),
    xi = 0.24 * sqrt(T_e) * A * p_abs / (ne20**2 * Z**2 * f_min),

  with A and Z the He3 mass number and charge, p_abs the power density
  deposited on He3 and f_min the He3 fraction.

  Args:
    power_deposition_he3: He3 power deposition profile per unit of `Ptot`.
    core_profiles: Core profiles providing electron temperature and density.
    minority_concentration: He3 concentration, in percent.
    Ptot: Power used to scale the deposition profile (the caller passes MW).

  Returns:
    The tail temperature profile, in the same units as
    `core_profiles.temp_el.value`.
  """
  mass_number = 3.016
  charge_number = 2
  electron_temp = core_profiles.temp_el.value
  # The concentration is provided in percent.
  minority_fraction = minority_concentration / 100
  he3_power_density = power_deposition_he3 * Ptot
  # Electron density normalised to 1e20 (assumes ne * nref is in m^-3).
  ne20 = core_profiles.ne.value * core_profiles.nref / 1e20
  numerator = 0.24 * jnp.sqrt(electron_temp) * mass_number * he3_power_density
  denominator = ne20**2 * charge_number**2 * minority_fraction
  stix_xi = numerator / denominator
  return electron_temp * (1 + stix_xi)
def icrh_model_func(
    unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
    dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
    geo: geometry.Geometry,
    source_name: str,
    core_profiles: state.CoreProfiles,
    unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
    toric_nn: ToricNNWrapper,
) -> tuple[chex.Array, ...]:
  """Compute ion/electron heat source terms.

  ToricNN predicts deposition profiles for helium-3, second-harmonic tritium
  and electrons. These are interpolated onto the TORAX cell grid and
  renormalised so that their combined volume integral is one. They are then
  scaled by the absorbed power `Ptot * absorption_fraction`. Power absorbed
  by helium-3 is split between ions and electrons using the fast-ion heating
  fraction of a Stix-distributed He3 tail.

  Args:
    unused_static_runtime_params_slice: Unused.
    dynamic_runtime_params_slice: Dynamic runtime parameters. The entry for
      `source_name` must be this module's `DynamicRuntimeParams`.
    geo: Geometry of the torus.
    source_name: Key of this source in `dynamic_runtime_params_slice.sources`.
    core_profiles: Current core plasma profiles.
    unused_calculated_source_profiles: Unused.
    toric_nn: Loaded ToricNN surrogate.

  Returns:
    `(source_ion, source_el)` heating profiles on the cell grid. Given `Ptot`
    in W, these are presumably in W/m^3.
  """
  dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
      source_name
  ]
  assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams)

  # Construct inputs for ToricNN.
  volume_average_temperature = math_utils.volume_average(
      core_profiles.temp_el.value, geo
  )
  volume_average_density = math_utils.volume_average(
      core_profiles.ne.value, geo
  )
  # Peaking factors are core w.r.t volume averages.
  temperature_peaking_factor = (
      core_profiles.temp_el.value[0] / volume_average_temperature
  )
  density_peaking_factor = core_profiles.ne.value[0] / volume_average_density
  Router = geo.Rmaj + geo.Rmin
  Rinner = geo.Rmaj - geo.Rmin
  # Assumption: inner and outer gaps are not functions of z0.
  # This is a good assumption for the inner gap but perhaps less good for the
  # outer gap where there is significant curvature to the outer limiter.
  gap_inner = Rinner - dynamic_source_runtime_params.wall_inner
  gap_outer = dynamic_source_runtime_params.wall_outer - Router
  toric_inputs = ToricNNInputs(
      frequency=dynamic_source_runtime_params.frequency,
      volume_average_temperature=volume_average_temperature,
      volume_average_density=volume_average_density,
      minority_concentration=dynamic_source_runtime_params.minority_concentration,
      gap_inner=gap_inner,
      gap_outer=gap_outer,
      z0=geo.z_magnetic_axis(),
      temperature_peaking_factor=temperature_peaking_factor,
      density_peaking_factor=density_peaking_factor,
      B0=geo.B0,
  )
  toric_nn_outputs = _toric_nn_predict(toric_nn, toric_inputs)

  # ToricNN profiles live on a uniform grid over [0, 1]. Size that grid from
  # the prediction itself so it always matches the network's `radial_nodes`;
  # a hard-coded size would make jnp.interp fail for any other model.
  toric_grid = jnp.linspace(
      0.0, 1.0, toric_nn_outputs.power_deposition_He3.shape[-1]
  )
  interp_to_cells = functools.partial(
      jnp.interp, geo.torax_mesh.cell_centers, toric_grid
  )
  power_deposition_he3 = interp_to_cells(toric_nn_outputs.power_deposition_He3)
  power_deposition_e = interp_to_cells(toric_nn_outputs.power_deposition_e)
  power_deposition_2T = interp_to_cells(toric_nn_outputs.power_deposition_2T)

  # Ideally total ICRH power should equal one but normalise if not.
  # If every prediction was clipped to zero, the integral is zero as well.
  # Divide by one in that case to return zero sources instead of 0 / 0 = NaN.
  total_power_deposition = math_utils.volume_integration(
      power_deposition_2T + power_deposition_e + power_deposition_he3, geo
  )
  total_power_deposition = jnp.where(
      total_power_deposition > 0.0, total_power_deposition, 1.0
  )
  power_deposition_he3 /= total_power_deposition
  power_deposition_e /= total_power_deposition
  power_deposition_2T /= total_power_deposition

  # For helium-3 we use a "Stix distribution" to model the non-thermal He3 tail.
  # NOTE(review): the tail temperature is driven by Ptot, while the source
  # terms below use the absorbed power Ptot * absorption_fraction. Confirm
  # this is intended.
  helium3_birth_energy = _helium3_tail_temperature(
      power_deposition_he3,
      core_profiles,
      dynamic_source_runtime_params.minority_concentration,
      dynamic_source_runtime_params.Ptot / 1e6,  # required in MW.
  )
  helium3_mass = 3.016
  frac_ion_heating = collisions.fast_ion_fractional_heating_formula(
      helium3_birth_energy,
      core_profiles.temp_el.value,
      helium3_mass,
  )
  absorbed_power = (
      dynamic_source_runtime_params.Ptot
      * dynamic_source_runtime_params.absorption_fraction
  )
  source_ion = power_deposition_he3 * frac_ion_heating * absorbed_power
  source_el = power_deposition_he3 * (1 - frac_ion_heating) * absorbed_power

  # Assume that all the power from the electron power profile goes to electrons.
  source_el += power_deposition_e * absorbed_power

  # Assume that all the power from the tritium power profile goes to ions.
  source_ion += power_deposition_2T * absorbed_power

  return (source_ion, source_el)
@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class IonCyclotronSource(source.Source):
  """Ion cyclotron source with surrogate model."""

  SOURCE_NAME: ClassVar[str] = 'ion_cyclotron_source'

  @property
  def source_name(self) -> str:
    return self.SOURCE_NAME

  @property
  def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]:
    # Ion temperature first, then electron temperature. This mirrors the
    # (source_ion, source_el) order returned by `icrh_model_func`, presumably
    # paired positionally by the base class (confirm in `source.Source`).
    return (
        source.AffectedCoreProfile.TEMP_ION,
        source.AffectedCoreProfile.TEMP_EL,
    )
# Memoise the partial so that repeated calls with an equal ToricNNWrapper
# return the *same* function object. The model function is part of the JAX
# compilation cache key, so building a fresh partial on every call would miss
# that cache. maxsize=1 is sufficient: the wrapper only changes when a new
# model path is provided, which is not expected to happen very often.
@functools.lru_cache(maxsize=1)
def _icrh_model_func_with_toric_nn(
    toric_nn: ToricNNWrapper,
) -> source.SourceProfileFunction:
  """Binds `toric_nn` into `icrh_model_func` and returns the bound callable."""
  bound_model_func = functools.partial(icrh_model_func, toric_nn=toric_nn)
  return bound_model_func
class IonCyclotronSourceConfig(base.SourceModelBase):
  """Configuration for the IonCyclotronSource.

  Attributes:
    wall_inner: Inner radial location of first wall at plasma midplane level
      [m].
    wall_outer: Outer radial location of first wall at plasma midplane level
      [m].
    frequency: ICRF wave frequency [Hz].
    minority_concentration: He3 minority concentration relative to the electron
      density in %.
    Ptot: Total heating power [W].
    absorption_fraction: Fraction of absorbed power.
  """

  model_function_name: Literal['icrh_model_func'] = 'icrh_model_func'
  wall_inner: torax_pydantic.Meter = 1.24
  wall_outer: torax_pydantic.Meter = 2.43
  frequency: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(
      120e6
  )
  minority_concentration: torax_pydantic.TimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(3.0)
  )
  Ptot: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(10e6)
  absorption_fraction: torax_pydantic.PositiveTimeVaryingScalar = (
      torax_pydantic.ValidatedDefault(1.0)
  )
  mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED

  @pydantic.model_validator(mode='after')
  def _load_toric_nn(self) -> typing_extensions.Self:
    # Load the surrogate eagerly during validation, so a missing model file is
    # reported as soon as the config is built. It uses the path from the
    # environment variable or the built-in default.
    self._toric_nn = ToricNNWrapper()
    return self

  @property
  def model_func(self) -> source.SourceProfileFunction:
    return _icrh_model_func_with_toric_nn(self._toric_nn)

  def build_dynamic_params(
      self,
      t: chex.Numeric,
  ) -> DynamicRuntimeParams:
    """Evaluates the time-varying parameters at time `t`."""
    return DynamicRuntimeParams(
        prescribed_values=tuple(v.get_value(t) for v in self.prescribed_values),
        wall_inner=self.wall_inner,
        wall_outer=self.wall_outer,
        frequency=self.frequency.get_value(t),
        minority_concentration=self.minority_concentration.get_value(t),
        Ptot=self.Ptot.get_value(t),
        absorption_fraction=self.absorption_fraction.get_value(t),
    )

  def build_source(self) -> IonCyclotronSource:
    return IonCyclotronSource(model_func=self.model_func)