How to integrate new models
TORAX has a modular design that supports easy coupling of new physics models such as sources, transport models, and pedestal models.
TORAX provides a public API for registering custom transport, pedestal, and source models. Once registered, custom models can be configured via TORAX config files or dictionaries just like the built-in models.
Registering a custom transport model
To integrate a custom transport model, you need to:
Define a transport model class that implements the physics.
Define a pydantic config class for your model.
Register the config class with TORAX.
Step 1: Implement the transport model
Create a frozen dataclass that inherits from torax.transport.TransportModel
and implements the call_implementation method. This method receives the
current simulation state and must return a torax.transport.TurbulentTransport
object containing the computed transport coefficients on the face grid.
import dataclasses

import jax.numpy as jnp
import torax
from torax import transport


@dataclasses.dataclass(frozen=True, eq=False)
class MyTransportModel(transport.TransportModel):
  """Custom transport model."""

  def call_implementation(
      self,
      transport_runtime_params: transport.RuntimeParams,
      runtime_params: torax.RuntimeParams,
      geo: torax.Geometry,
      core_profiles: torax.CoreProfiles,
      pedestal_model_outputs: torax.PedestalModelOutput,
  ) -> transport.TurbulentTransport:
    # Implement your transport model here.
    # Must return a TurbulentTransport with at least the four required
    # fields: chi_face_ion, chi_face_el, d_face_el, v_face_el.
    chi_i = jnp.ones_like(geo.rho_face_norm) * 2.0
    chi_e = jnp.ones_like(geo.rho_face_norm) * 1.5
    d_e = jnp.ones_like(geo.rho_face_norm) * 0.5
    v_e = jnp.zeros_like(geo.rho_face_norm)
    return transport.TurbulentTransport(
        chi_face_ion=chi_i,
        chi_face_el=chi_e,
        d_face_el=d_e,
        v_face_el=v_e,
    )
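The constant coefficients above are placeholders. As a minimal sketch of a radially varying profile, the NumPy helper below evaluates a hypothetical parabolic law, chi(rho) = chi_0 * (1 + alpha * rho**2), on the face grid. It assumes a uniform grid of n_rho cells whose faces run from rho_norm = 0 to 1, giving n_rho + 1 face values. The names parabolic_chi_face, chi_0, and alpha are illustrative and are not part of TORAX.

```python
import numpy as np


def parabolic_chi_face(
    rho_face_norm: np.ndarray, chi_0: float, alpha: float
) -> np.ndarray:
  """Hypothetical profile chi(rho) = chi_0 * (1 + alpha * rho**2) on faces."""
  return chi_0 * (1.0 + alpha * rho_face_norm**2)


# Assumed uniform face grid: n_rho cells -> n_rho + 1 faces spanning [0, 1].
n_rho = 25
rho_face_norm = np.linspace(0.0, 1.0, n_rho + 1)

chi_i = parabolic_chi_face(rho_face_norm, chi_0=2.0, alpha=3.0)
chi_e = parabolic_chi_face(rho_face_norm, chi_0=1.5, alpha=3.0)
```

The same expression also works on jax.numpy arrays. Inside call_implementation you could apply it to geo.rho_face_norm, using jnp in place of np.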
Step 2: Define the pydantic config
Create a pydantic config class that inherits from
torax.transport.TransportBase and implements the build_transport_model
method. The config class must have a model_name field with a unique
Literal type that identifies your model.
from typing import Annotated, Literal


class MyTransportConfig(transport.TransportBase):
  """Pydantic config for MyTransportModel."""

  model_name: Annotated[
      Literal['my_transport'], torax.JAX_STATIC
  ] = 'my_transport'

  def build_transport_model(self) -> MyTransportModel:
    return MyTransportModel()
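The model_name field has a default value and a single-valued Literal type, so the name can be recovered from the annotation itself. The stdlib-only sketch below does this with typing.get_type_hints(..., include_extras=True) and typing.get_args. So that the snippet runs without TORAX, torax.JAX_STATIC is replaced by a hypothetical placeholder marker. The sketch only shows why each model needs its own distinct Literal. It is not a description of how TORAX resolves names internally.

```python
import typing
from typing import Annotated, Literal

# Hypothetical stand-in for torax.JAX_STATIC, used only so this runs standalone.
STATIC_MARKER = object()


class MyTransportConfigSketch:
  """Hypothetical plain-class stand-in for MyTransportConfig."""

  model_name: Annotated[Literal['my_transport'], STATIC_MARKER] = (
      'my_transport'
  )


def literal_model_name(cls: type) -> str:
  """Returns the single value of the Literal inside a model_name annotation."""
  hint = typing.get_type_hints(cls, include_extras=True)['model_name']
  literal_type, *_metadata = typing.get_args(hint)  # Strip Annotated metadata.
  (name,) = typing.get_args(literal_type)  # The Literal allows one value only.
  return name
```

The Literal admits exactly one value. A config section whose model_name is 'my_transport' can therefore be matched to this class unambiguously, which is why two registered models must not share a name.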
Step 3: Register the model
Call torax.transport.register_transport_model with your pydantic config
class. This must be done at module level, before any TORAX config is built.
transport.register_transport_model(MyTransportConfig)
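Conceptually, registration adds your config class to a lookup keyed by its model name. This is why it has to happen before any config that refers to that name is parsed. The toy registry below is a hypothetical, stdlib-only sketch of that idea. The names register and build_from_dict are illustrative and are not TORAX functions. The registry rejects duplicate names and fails loudly when a config references a name that was never registered.

```python
import typing
from typing import Literal

# Hypothetical registry mapping model_name -> config class.
_REGISTRY: dict[str, type] = {}


def _model_name(cls: type) -> str:
  # Assumes the class annotates model_name with a single-value Literal.
  (name,) = typing.get_args(typing.get_type_hints(cls)['model_name'])
  return name


def register(config_cls: type) -> None:
  """Hypothetical registration: map the class's model_name to the class."""
  name = _model_name(config_cls)
  if name in _REGISTRY:
    raise ValueError(f'Model name {name!r} is already registered.')
  _REGISTRY[name] = config_cls


def build_from_dict(section: dict) -> object:
  """Hypothetical lookup: pick the class named by section['model_name']."""
  name = section['model_name']
  if name not in _REGISTRY:
    raise KeyError(f'Unknown model_name {name!r}; was it registered?')
  return _REGISTRY[name]()


class MyTransportConfigSketch:
  model_name: Literal['my_transport'] = 'my_transport'


register(MyTransportConfigSketch)  # Module level, before any config lookup.
config = build_from_dict({'model_name': 'my_transport'})
```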
Using the registered model
Once registered, the model can be used in a TORAX config by setting the
transport.model_name field to the model name you defined:
config = {
    ...
    'transport': {
        'model_name': 'my_transport',
    },
    ...
}
torax_config = torax.ToraxConfig.from_dict(config)
torax.run_simulation(torax_config)
Registering a custom pedestal model
To integrate a custom pedestal model, you need to:
Define a pedestal model class that implements the physics.
Define a pydantic config class for your model.
Register the config class with TORAX.
Step 1: Implement the pedestal model
Create a frozen dataclass that inherits from torax.pedestal.PedestalModel
and implements the _call_implementation method. This method must return a
torax.pedestal.PedestalModelOutput with the pedestal properties.
The PedestalModel base class requires formation_model and
saturation_model fields (used in ADAPTIVE_TRANSPORT mode). These should
be passed through from the pydantic config’s build_pedestal_model method.
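The pass-through requirement follows from ordinary dataclass inheritance: fields declared on a frozen base dataclass become constructor arguments of every subclass. The stdlib sketch below mimics this with hypothetical stand-in classes. BasePedestalModelSketch and its placeholder fields are not TORAX classes. The sketch also shows two standard dataclass behaviours. With frozen=True, fields cannot be reassigned after construction. With eq=False, the default identity-based equality and hashing are kept.

```python
import dataclasses


@dataclasses.dataclass(frozen=True, eq=False)
class BasePedestalModelSketch:
  """Hypothetical stand-in for a base class with two required fields."""

  # Placeholders for the formation_model / saturation_model fields.
  formation_model: object
  saturation_model: object


@dataclasses.dataclass(frozen=True, eq=False)
class MyPedestalModelSketch(BasePedestalModelSketch):
  """Adds no fields, but still inherits both required ones."""


model = MyPedestalModelSketch(formation_model='f', saturation_model='s')
twin = MyPedestalModelSketch(formation_model='f', saturation_model='s')
```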
import dataclasses

import jax.numpy as jnp
import torax
from torax import pedestal


@dataclasses.dataclass(frozen=True, eq=False)
class MyPedestalModel(pedestal.PedestalModel):
  """Custom pedestal model."""

  def _call_implementation(
      self,
      runtime_params: torax.RuntimeParams,
      geo: torax.Geometry,
      core_profiles: torax.CoreProfiles,
  ) -> pedestal.PedestalModelOutput:
    # Implement your pedestal model here.
    return pedestal.PedestalModelOutput(
        rho_norm_ped_top=jnp.array(0.9),
        rho_norm_ped_top_idx=jnp.abs(geo.rho_norm - 0.9).argmin(),
        T_i_ped=jnp.array(5.0),
        T_e_ped=jnp.array(5.0),
        n_e_ped=jnp.array(0.7e20),
    )
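In the example above, rho_norm_ped_top_idx is the index of the cell-center grid point closest to the pedestal top. The minimal NumPy sketch below applies the same argmin expression. It assumes a uniform grid of n_rho cells whose centers sit at (i + 0.5) / n_rho; this grid layout is an assumption about the geometry and is not stated on this page.

```python
import numpy as np

# Assumed uniform cell-center grid with 25 cells on [0, 1].
n_rho = 25
rho_norm = (np.arange(n_rho) + 0.5) / n_rho

rho_norm_ped_top = 0.9
# Same expression as in MyPedestalModel: nearest cell center to the top.
rho_norm_ped_top_idx = int(np.abs(rho_norm - rho_norm_ped_top).argmin())
```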
Step 2: Define the pydantic config
Create a pydantic config class that inherits from
torax.pedestal.BasePedestal and implements the build_pedestal_model
method. The config class must have a model_name field with a unique
Literal type.
You may also override build_runtime_params if your model requires
additional runtime parameters beyond the base pedestal parameters.
from typing import Annotated, Literal


class MyPedestalConfig(pedestal.BasePedestal):
  """Pydantic config for MyPedestalModel."""

  model_name: Annotated[
      Literal['my_pedestal'], torax.JAX_STATIC
  ] = 'my_pedestal'

  def build_pedestal_model(self) -> MyPedestalModel:
    return MyPedestalModel(
        formation_model=self.formation_model.build_formation_model(),
        saturation_model=self.saturation_model.build_saturation_model(),
    )

  def build_runtime_params(
      self, t,
  ) -> pedestal.RuntimeParams:
    return pedestal.RuntimeParams(
        set_pedestal=self.set_pedestal.get_value(t),
        mode=self.mode,
        formation=self.formation_model.build_runtime_params(t),
        saturation=self.saturation_model.build_runtime_params(t),
        chi_max=self.chi_max.get_value(t),
        D_e_max=self.D_e_max.get_value(t),
        V_e_max=self.V_e_max.get_value(t),
        V_e_min=self.V_e_min.get_value(t),
        pedestal_top_smoothing_width=(
            self.pedestal_top_smoothing_width.get_value(t)
        ),
    )
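In the code above, each time-dependent parameter is evaluated at the current time t through get_value(t). The sketch below illustrates one possible behaviour only: a piecewise-linear time-series lookup built with numpy.interp. It assumes linear interpolation between time knots, with end values held constant outside the given range. The helper name and example values are hypothetical, and TORAX's own time-varying parameter classes may support other interpolation modes.

```python
import numpy as np


def piecewise_linear_value(times, values, t):
  """Hypothetical stand-in for a time-varying parameter's get_value(t).

  Linearly interpolates between (time, value) knots and holds the end
  values constant outside the given time range.
  """
  return float(np.interp(t, times, values))


# Example: a hypothetical parameter ramping from 3.0 at t=0 to 5.0 at t=10.
times = [0.0, 10.0]
values = [3.0, 5.0]
```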
Step 3: Register the model
Call torax.pedestal.register_pedestal_model with your pydantic config
class. This must be done at module level, before any TORAX config is built.
pedestal.register_pedestal_model(MyPedestalConfig)
Using the registered model
Once registered, the model can be used in a TORAX config:
config = {
    ...
    'pedestal': {
        'model_name': 'my_pedestal',
        'set_pedestal': True,
    },
    ...
}
torax_config = torax.ToraxConfig.from_dict(config)
torax.run_simulation(torax_config)
Registering a custom source model
To integrate a custom source model (e.g. a new heat source, particle source, or current source), you need to:
Define a model function that computes the source profile.
Define a pydantic config class for your model.
Register the config class with TORAX against a specific source name.
Unlike transport and pedestal models, source models are registered against a
specific source name (e.g. 'gas_puff', 'fusion', 'generic_heat',
etc.). This allows multiple source model implementations to exist for the same
physical source. The two special sources 'qei' (ion-electron heat exchange)
and 'j_bootstrap' (bootstrap current) do not support custom registration.
Step 1: Implement the model function
Define a function that matches the torax.sources.SourceProfileFunction
protocol. This function receives the simulation state and must return a tuple
of source profile arrays (one per affected core profile). The order of the
profiles in the tuple must match the order of the affected core profiles for the
source being registered against (e.g. for generic_heat, the tuple must be
(ion heat, electron heat)).
import jax.numpy as jnp
import torax
from torax import sources


def my_heat_source(
    runtime_params: torax.RuntimeParams,
    geo: torax.Geometry,
    source_name: str,
    core_profiles: torax.CoreProfiles,
    calculated_source_profiles: sources.SourceProfiles | None,
    unused_conductivity,
) -> tuple[jnp.ndarray, ...]:
  """Custom heat source model."""
  # Return a tuple with one element per affected core profile.
  # For a source affecting TEMP_ION and TEMP_EL, return two profiles.
  ion_heat = jnp.ones_like(geo.rho_norm) * 1e6
  el_heat = jnp.ones_like(geo.rho_norm) * 0.5e6
  return (ion_heat, el_heat)
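The uniform profiles above are placeholders. The minimal NumPy sketch below shows a more shaped alternative: a Gaussian deposition split between ions and electrons. The helper name, the Gaussian form, and the parameters (peak amplitude, electron fraction, center, width) are all hypothetical. The ion profile is returned first, matching the (ion heat, electron heat) order required for generic_heat.

```python
import numpy as np


def gaussian_ion_el_heat(
    rho_norm: np.ndarray,
    peak: float,
    el_fraction: float,
    rho_center: float = 0.0,
    width: float = 0.25,
) -> tuple[np.ndarray, np.ndarray]:
  """Hypothetical Gaussian heat deposition, returned as (ion, electron)."""
  shape = np.exp(-((rho_norm - rho_center) ** 2) / (2.0 * width**2))
  total = peak * shape
  ion_heat = (1.0 - el_fraction) * total
  el_heat = el_fraction * total
  return (ion_heat, el_heat)  # Order matters: ion heat first, then electron.


rho_norm = (np.arange(25) + 0.5) / 25  # Assumed uniform cell-center grid.
ion_heat, el_heat = gaussian_ion_el_heat(
    rho_norm, peak=1.5e6, el_fraction=1 / 3
)
```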
Step 2: Define the pydantic config
Create a pydantic config class that inherits from
torax.sources.SourceModelBase and implements three required methods:
model_func (a property): returns the model function.
build_source: returns the Source instance.
build_runtime_params: returns the source-specific RuntimeParams.
The config class must have a model_name field with a unique Literal
type that identifies your model. This name must be different from the default
model name for the source you are registering against.
import dataclasses
from typing import Literal

import chex
import jax
from torax import sources
from torax._src.sources import generic_ion_el_heat_source as heat_source_lib


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class MyRuntimeParams(sources.RuntimeParams):
  """Custom runtime params with an extra parameter."""

  scaling_factor: float


class MyHeatSourceConfig(sources.SourceModelBase):
  """Pydantic config for my custom heat source."""

  model_name: Literal['my_heat_model'] = 'my_heat_model'
  scaling_factor: float = 1.0

  @property
  def model_func(self) -> sources.SourceProfileFunction:
    return my_heat_source

  def build_source(self) -> sources.Source:
    return heat_source_lib.GenericIonElectronHeatSource(
        model_func=self.model_func
    )

  def build_runtime_params(
      self, t: chex.Numeric,
  ) -> MyRuntimeParams:
    return MyRuntimeParams(
        scaling_factor=self.scaling_factor,
        prescribed_values=tuple(
            [v.get_value(t) for v in self.prescribed_values]
        ),
        mode=self.mode,
        is_explicit=self.is_explicit,
    )
Step 3: Register the model
Call torax.sources.register_source_model_config with your pydantic config
class and the name of the source to register against. This must be done at
module level, before any TORAX config is built.
sources.register_source_model_config(MyHeatSourceConfig, 'generic_heat')
The source_name must be one of the fields in the Sources pydantic model:
bremsstrahlung, cyclotron_radiation, ecrh, fusion, gas_puff,
generic_current, generic_heat, generic_particle, icrh,
impurity_radiation, ohmic, or pellet.
If you want to register a custom implementation for a source that isn’t in this list (for example “nbi”), please reach out to the TORAX team and we will help.
Using the registered model
Once registered, the model can be used in a TORAX config by setting the
model_name field within the corresponding source:
config = {
    ...
    'sources': {
        'generic_heat': {
            'model_name': 'my_heat_model',
            'scaling_factor': 2.0,
        },
    },
    ...
}
torax_config = torax.ToraxConfig.from_dict(config)
torax.run_simulation(torax_config)