torax package

Subpackages

Submodules

torax.array_typing module

Common types and helpers for using jaxtyping in TORAX.

torax.array_typing.typed(function)

Helper decorator for using jaxtyping for shape and dtype checking.

Example usage:

  @typed
  def f(x: ScalarFloat) -> ScalarFloat:
    ...

Jaxtyping is enabled by default. To disable it globally, set the environment variable JAXTYPING_DISABLE=True.

Parameters:

function (TypeVar(F, bound=Callable)) – The function to shape-check.

Return type:

TypeVar(F, bound=Callable)

Returns:

The decorated function.

torax.conftest module

Pytest fixture for working around UnparsedFlagAccessError when running tests.

torax.constants module

Physics constants.

This module defines immutable constants used in various calculations.

class torax.constants.Constants(keV2J, mp, qe, me, epsilon0, mu0, eps)

Bases: Mapping

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values

class torax.constants.IonProperties(symbol, name, A, Z)

Bases: Mapping

Properties of an ion.

symbol

The ion’s symbol.

name

The ion’s full name.

A

The ion’s atomic mass [amu].

Z

The ion’s atomic number.

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
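Both classes above list Mapping as a base class, which is where their items(), keys() and values() methods come from. As a minimal stdlib sketch of that pattern (not the TORAX source), a frozen dataclass can expose its fields through the Mapping interface as follows; the class name IonLike is hypothetical.

```python
import dataclasses
from collections.abc import Iterator, Mapping
from typing import Any


@dataclasses.dataclass(frozen=True)
class IonLike(Mapping):
    """Hypothetical stand-in for an ion record with dict-like access."""

    symbol: str
    name: str
    A: float  # Atomic mass [amu].
    Z: float  # Atomic number.

    def _field_names(self) -> tuple[str, ...]:
        return tuple(f.name for f in dataclasses.fields(self))

    def __getitem__(self, key: str) -> Any:
        # Only declared dataclass fields are valid keys.
        if key not in self._field_names():
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._field_names())

    def __len__(self) -> int:
        return len(self._field_names())


deuterium = IonLike(symbol="D", name="Deuterium", A=2.0141, Z=1.0)
print(dict(deuterium.items()))
```

Because every field is reachable by key, such an object can be passed to dict() or unpacked with **, which is convenient when serializing outputs.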

torax.interpolated_param module

Classes and functions for defining interpolated parameters.

class torax.interpolated_param.InterpolatedParamBase

Bases: ABC

Base class for interpolated params.

A subclass of InterpolatedParamBase should implement the interface defined below: given an x-value at which to interpolate, the object returns a value. The x-value can be either a time or a spatial coordinate, depending on what is being interpolated over.

abstract get_value(x)

Returns a value for this parameter interpolated at the given input.

Parameters:

x (Union[Array, ndarray, bool, number, float, int])

Return type:

Union[Array, ndarray, bool, number]

class torax.interpolated_param.InterpolatedVarSingleAxis(value, interpolation_mode=InterpolationMode.PIECEWISE_LINEAR, is_bool_param=False)

Bases: InterpolatedParamBase

Parameter that may vary based on an input coordinate.

This class is useful for defining time-dependent runtime parameters, but can be used to define any parameters that vary across some range.

This class interpolates a 1D array xs against either a 1D or a 2D array ys. For example, xs can be time, and ys either a 1D array of scalars associated with the times in xs, or a 2D array whose index 0 runs over the times in xs and whose index 1 holds a radial profile for each time. The 2D array is interpolated element-wise, accelerated with vmap. A 2D ys is intended for radial slices (index 1) that have already been interpolated onto an appropriate TORAX grid, such as cell_centers, faces, or the hires grid. NOTE: this means that the 2D array must have shape (n, m), where n is the number of elements in xs and m is the size of the spatial grid.

See config.runtime_params.RuntimeParams and its tests for examples of how this is used.

get_value(x)

Returns a single value for this range at the given coordinate.

Parameters:

x (Union[Array, ndarray, bool, number, float, int])

Return type:

Union[Array, ndarray, bool, number]

property interpolation_mode: InterpolationMode

Returns the interpolation mode used by this param.

property is_bool_param: bool

Returns whether this param represents a bool.

property param: InterpolatedParamBase

Returns the JAX-friendly interpolated param used under the hood.
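To make the (n, m) convention for a 2D ys concrete, the stdlib-only sketch below interpolates a table of radial profiles in time. Each radial index j is interpolated independently along xs, which is the element-wise operation the class vectorizes with vmap. The helper names are hypothetical, piecewise-linear mode is assumed, and times outside the xs range are clamped to the end points as described for InterpolationMode.

```python
import bisect


def interp_linear(x, xs, ys):
    """Piecewise-linear interpolation with end-point clamping (like numpy.interp)."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    k = bisect.bisect_right(xs, x) - 1  # xs[k] <= x < xs[k + 1]
    w = (x - xs[k]) / (xs[k + 1] - xs[k])
    return (1.0 - w) * ys[k] + w * ys[k + 1]


def interp_profile_in_time(t, times, profiles):
    """Interpolate an (n, m) table of radial profiles at time t.

    times has n entries; profiles[i] is the length-m radial profile at times[i].
    """
    n_rho = len(profiles[0])
    # Interpolate each radial index j independently along the time axis.
    return [
        interp_linear(t, times, [row[j] for row in profiles])
        for j in range(n_rho)
    ]


times = [0.0, 10.0]
profiles = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]  # shape (n=2, m=3)
print(interp_profile_in_time(5.0, times, profiles))  # [2.0, 3.0, 4.0]
```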

class torax.interpolated_param.InterpolatedVarTimeRho(values, rho_norm, time_interpolation_mode=InterpolationMode.PIECEWISE_LINEAR, rho_interpolation_mode=InterpolationMode.PIECEWISE_LINEAR)

Bases: InterpolatedParamBase

Interpolates on a grid (time, rho).

This class interpolates along time (piecewise-linearly by default) to provide a value at any (time, rho) pair. For time values outside the defined range, the closest defined InterpolatedVarSingleAxis is used.

NOTE: We assume that the rho interpolation is fixed per simulation, so rho is taken at init and only time is taken at get_value.

get_value(x)

Returns the value of this parameter interpolated at x=time.

Parameters:

x (Union[Array, ndarray, bool, number, float, int])

Return type:

Union[Array, ndarray, bool, number]

property rho_interpolation_mode: InterpolationMode

Returns the rho interpolation mode used by this param.

property time_interpolation_mode: InterpolationMode

Returns the time interpolation mode used by this param.

class torax.interpolated_param.InterpolationMode(value)

Bases: Enum

Defines how to do the interpolation.

InterpolatedParams have many values to interpolate between, and this enum defines how exactly that interpolation is computed.

Assuming inputs [x_0, …, x_n] and [y_0, …, y_n], for all modes, the interpolated param outputs y_0 for any input less than x_0 and y_n for any input greater than x_n.

Options:
  • PIECEWISE_LINEAR: Piecewise-linear interpolation between the provided values. See numpy.interp for a longer description of how it works. (This uses JAX, but the behavior is the same.)

  • STEP: Step-function interpolation. For any input value x in the range [x_k, x_{k+1}), the output is y_k.
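Both modes, together with the end-point clamping described above, can be sketched in plain Python. This mirrors the documented behaviour rather than the JAX code TORAX uses; the Mode enum and the interpolate function are hypothetical stand-ins.

```python
import bisect
import enum


class Mode(enum.Enum):
    PIECEWISE_LINEAR = "piecewise_linear"
    STEP = "step"


def interpolate(x, xs, ys, mode):
    """Evaluate the interpolant at x; xs must be sorted in increasing order."""
    # Both modes clamp: y_0 below x_0 and y_n above x_n.
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    k = bisect.bisect_right(xs, x) - 1  # index with xs[k] <= x < xs[k + 1]
    if mode is Mode.STEP:
        return ys[k]  # hold y_k over [x_k, x_{k+1})
    w = (x - xs[k]) / (xs[k + 1] - xs[k])
    return ys[k] + w * (ys[k + 1] - ys[k])


xs, ys = [0.0, 1.0, 2.0], [0.0, 10.0, 30.0]
print(interpolate(1.5, xs, ys, Mode.PIECEWISE_LINEAR))  # 20.0
print(interpolate(1.5, xs, ys, Mode.STEP))  # 10.0
```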

class torax.interpolated_param.PiecewiseLinearInterpolatedParam(xs, ys)

Bases: InterpolatedParamBase

Parameter using piecewise-linear interpolation to compute its value.

get_value(x)

Returns a value for this parameter interpolated at the given input.

Parameters:

x (Union[Array, ndarray, bool, number, float, int])

Return type:

Union[Array, ndarray, bool, number]

class torax.interpolated_param.StepInterpolatedParam(xs, ys)

Bases: InterpolatedParamBase

Parameter using step interpolation to compute its value.

get_value(x)

Returns a value for this parameter interpolated at the given input.

Parameters:

x (Union[Array, ndarray, bool, number, float, int])

Return type:

Union[Array, ndarray, bool, number]

torax.interpolated_param.convert_input_to_xs_ys(interp_input)

Converts config inputs into inputs suitable for constructors.

Parameters:

interp_input (float | dict[float, float] | bool | dict[float, bool] | tuple[Union[Array, ndarray, bool, number, list[float]], Union[Array, ndarray, bool, number, list[float]]] | DataArray | tuple[float | dict[float, float] | bool | dict[float, bool] | tuple[Union[Array, ndarray, bool, number, list[float]], Union[Array, ndarray, bool, number, list[float]]] | DataArray, Literal['step', 'STEP', 'piecewise_linear', 'PIECEWISE_LINEAR']]) – The input to convert.

Return type:

tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], InterpolationMode, bool]

Returns:

A tuple of (xs, ys, interpolation_mode, is_bool_param) where xs and ys are the arrays to be used in the constructor, interpolation_mode is the interpolation mode to be used, and is_bool_param is True if the input is a bool and False otherwise.
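As a rough illustration of the dict-shaped inputs accepted above, the hypothetical helper below turns a {x: y} dict (or a bare scalar) into sorted xs and ys lists and reports whether the values are booleans. It is a simplified sketch under assumptions: it ignores the array-tuple, xr.DataArray and (input, mode) forms, and placing a scalar at x = 0.0 is an illustrative choice, not documented TORAX behaviour.

```python
def dict_to_xs_ys(interp_input):
    """Hypothetical simplification of convert_input_to_xs_ys for dicts and scalars."""
    if not isinstance(interp_input, dict):
        # Illustrative choice: a constant becomes a single point at x = 0.0,
        # which end-point clamping then extends to every x.
        interp_input = {0.0: interp_input}
    xs = sorted(interp_input)  # dict keys may arrive in any order
    ys = [interp_input[x] for x in xs]
    # Check bool explicitly: bool is a subclass of int, so numeric checks would accept it.
    is_bool_param = all(isinstance(y, bool) for y in ys)
    return xs, ys, is_bool_param


print(dict_to_xs_ys({10.0: 2.0, 0.0: 1.0}))  # ([0.0, 10.0], [1.0, 2.0], False)
print(dict_to_xs_ys({0.0: False, 5.0: True}))  # ([0.0, 5.0], [False, True], True)
```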

torax.jax_utils module

Commonly repeated JAX expressions.

torax.jax_utils.assert_rank(inputs, rank)

Wrapper around chex.assert_rank that supports jax.stages.ArgInfo.

Return type:

None

torax.jax_utils.enable_errors(value)

Enables or disables error_if inside a code block.

Example:

  with enable_errors(False):
    my_sim.run()  # NaNs, etc., will be ignored.

Parameters:

value (bool) – Sets errors_enabled to this value

Yields:

Cleanup function restoring previous value

torax.jax_utils.env_bool(name, default)

Get a bool from an environment variable.

Parameters:
  • name (str) – The name of the environment variable.

  • default (bool) – The default value of the bool.

Returns:

The value of the bool.

Return type:

bool
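Parsing a boolean from the environment requires a convention for which strings count as true. The sketch below assumes a case-insensitive set of spellings ("true", "1", "yes" and their false counterparts) and raises on anything else; the spellings TORAX actually accepts may differ, and MY_FLAG is a hypothetical variable name.

```python
import os


def env_bool(name: str, default: bool) -> bool:
    """Read a boolean environment variable (accepted spellings are an assumption)."""
    raw = os.environ.get(name)
    if raw is None:
        return default  # variable not set
    lowered = raw.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise ValueError(f"Cannot interpret {name}={raw!r} as a bool.")


os.environ["MY_FLAG"] = "False"
print(env_bool("MY_FLAG", True))  # False
```

Raising on unrecognized strings, rather than silently returning the default, makes typos in a flag such as TORAX_COMPILATION_ENABLED visible immediately.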

torax.jax_utils.error_if(var, cond, msg)

Raises an error if cond is true and errors_enabled is True.

This is just a wrapper around equinox.error_if, gated by errors_enabled.

Parameters:
  • var (Array) – The variable to pass through.

  • cond (Array) – Boolean array, error if cond is true.

  • msg (str) – Message to print on error.

Returns:

Identity wrapper that must be used for the check to be included.

Return type:

Array

torax.jax_utils.error_if_negative(var, name, to_wrap=None)

Check that a variable is non-negative.

Similar to error_if_not_positive, but 0 is allowed in this function.

Parameters:
  • var (Array) – The variable to check.

  • name (str) – Name of the variable.

  • to_wrap (Optional[Array]) – If var won’t be used in your jax function, specify another variable that will be.

Returns:

Identity wrapper that must be used for the check to be included.

Return type:

Array

torax.jax_utils.get_number_of_compiles(jitted_function)

Helper function for debugging JAX compilation.

This counts the number of times the function has been JIT compiled. This does not include any uses of the AOT compile workflow.

Parameters:

jitted_function (Callable[..., Any]) – A function that has been wrapped with jax.jit.

Return type:

int

Returns:

The number of times the function has been compiled.

Raises:

RuntimeError – If the function does not have a _cache_size attribute.

torax.jax_utils.jit(*args, **kwargs)

Calls jax.jit if TORAX_COMPILATION_ENABLED is True, otherwise no-op.

Return type:

Callable[..., Any]

torax.jax_utils.py_cond(cond, true_fun, false_fun)

Pure Python implementation of jax.lax.cond.

This gives us a way to write code that could easily be changed to be JAX-compatible in the future, if we want to expand the scope of the jit compilation.

Parameters:
  • cond (bool) – The condition

  • true_fun (Callable) – Function to be called if cond==True.

  • false_fun (Callable) – Function to be called if cond==False.

Return type:

Any

Returns:

The output from either true_fun or false_fun.
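Since the documented signature passes no operands, a minimal sketch of py_cond reduces to a Python if statement, assuming true_fun and false_fun take no arguments. Unlike jax.lax.cond, which traces both branches under jax.jit, only the selected branch is ever executed here. This is an illustration, not the TORAX source.

```python
from typing import Any, Callable


def py_cond(cond: bool, true_fun: Callable[[], Any], false_fun: Callable[[], Any]) -> Any:
    """Evaluate exactly one zero-argument branch, chosen eagerly in Python."""
    if cond:
        return true_fun()
    return false_fun()


print(py_cond(3 > 2, lambda: "big", lambda: "small"))  # big
```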

torax.jax_utils.py_fori_loop(lower, upper, body_fun, init_val)

Pure Python implementation of jax.lax.fori_loop.

This gives us a way to write code that could easily be changed to be JAX-compatible in the future, if we want to expand the scope of the jit compilation.

Parameters:
  • lower (int) – Lower bound of the loop (inclusive).

  • upper (int) – Upper bound of the loop (exclusive). upper <= lower produces no iterations.

  • body_fun (Callable[[int, TypeVar(T)], TypeVar(T)]) – Function of type (int, a) -> a, called with the loop index and the current carry value.

  • init_val (TypeVar(T)) – value of type a, a type that can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value.

Return type:

TypeVar(T)

Returns:

The output from the final iteration of body_fun, of type a.
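A minimal sketch consistent with the documented signature: the carry value is threaded through a Python for loop over range(lower, upper), so the index runs from lower (inclusive) to upper (exclusive), and an empty range returns init_val unchanged. This is an illustration, not the TORAX source.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def py_fori_loop(lower: int, upper: int, body_fun: Callable[[int, T], T], init_val: T) -> T:
    """Run body_fun for i in [lower, upper), threading the carry value through."""
    val = init_val
    # range() is empty when upper <= lower, so init_val is returned unchanged.
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


print(py_fori_loop(0, 5, lambda i, acc: acc + i, 0))  # 10
```

The carry can be any pytree-like value; for example, a (count, total) tuple works the same way as the integer accumulator above.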

torax.jax_utils.py_while(cond_fun, body_fun, init_val)

Pure Python implementation of jax.lax.while_loop.

This gives us a way to write code that could easily be changed to be JAX-compatible in the future (if we want to compute its gradient or compile it, etc.) without having to pay the high compile-time cost of jax.lax.while_loop.

Parameters:
  • cond_fun (Callable[[TypeVar(T)], Any]) – function of type a -> Bool.

  • body_fun (Callable[[TypeVar(T)], TypeVar(T)]) – function of type a -> a.

  • init_val (TypeVar(T)) – value of type a, a type that can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value.

Return type:

TypeVar(T)

Returns:

The output from the final iteration of body_fun, of type a.
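A minimal sketch of py_while consistent with the documented signature threads the carry value through a Python while statement. As with jax.lax.while_loop, if cond_fun(init_val) is already false, body_fun never runs and init_val is returned unchanged. This is an illustration, not the TORAX source.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def py_while(cond_fun: Callable[[T], Any], body_fun: Callable[[T], T], init_val: T) -> T:
    """Apply body_fun while cond_fun(val) is truthy; return the final carry."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


# Halve a value until it drops below 1: 10 -> 5 -> 2.5 -> 1.25 -> 0.625.
print(py_while(lambda x: x >= 1.0, lambda x: x / 2.0, 10.0))  # 0.625
```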

torax.math_utils module

Math operations.

Math operations that are needed for Torax, but are not specific to plasma physics or differential equation solvers.

class torax.math_utils.IntegralPreservationQuantity(value)

Bases: Enum

The quantity to preserve the integral of when converting to face values.

torax.math_utils.area_integration(value, geo)

Calculates integral of value using an area metric.

Parameters:
  • value (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • geo (Geometry)

Return type:

Union[Float[Array, ''], Float[ndarray, ''], number, float]

torax.math_utils.cell_integration(x, geo)

Integrate a value x over the rhon grid.

Cell variables in TORAX are defined as the average of the neighbouring face values. This method therefore integrates the underlying face values over the rhon grid, implicitly applying the trapezium rule by summing the cell (face-averaged) values multiplied by the face grid spacing.

Parameters:
  • x (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']]) – The cell averaged value to integrate.

  • geo (Geometry) – The geometry instance.

Returns:

The integral over the rhon grid, \(\int_0^1 x_{face}\, d\hat{\rho}\).
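The trapezium-rule remark can be checked with a small worked example. On a uniform face grid with spacing d, the trapezium rule sums d (f_i + f_{i+1}) / 2 over neighbouring faces, and each of those pair averages is exactly a cell value, so the integral equals the sum of the cell values times d. The stdlib sketch below compares the two forms on a hypothetical four-cell grid; it does not use the TORAX Geometry object.

```python
def trapezoid_over_faces(face_values, d):
    """Trapezium rule over face values with uniform spacing d."""
    return sum(0.5 * (a + b) * d for a, b in zip(face_values[:-1], face_values[1:]))


def cell_integration(cell_values, d):
    """Sum of cell (face-averaged) values multiplied by the face spacing d."""
    return sum(c * d for c in cell_values)


n_cells = 4
d = 1.0 / n_cells  # uniform spacing on rho_norm in [0, 1]
faces = [(i * d) ** 2 for i in range(n_cells + 1)]  # sample x_face = rho_norm**2
cells = [0.5 * (a + b) for a, b in zip(faces[:-1], faces[1:])]  # cell = face average

print(trapezoid_over_faces(faces, d))  # 0.34375
print(cell_integration(cells, d))  # 0.34375
```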

torax.math_utils.cell_to_face(cell_values, geo, preserved_quantity=IntegralPreservationQuantity.VALUE)

Convert cell values to face values.

We make four assumptions:
  1) Inner face values are the average of neighbouring cells.

  2) The leftmost face value is linearly extrapolated from the leftmost cell values.

  3) The transformation from cell to face is integration preserving.

  4) The cell spacing is constant.

Parameters:
  • cell_values (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']]) – Values defined on the TORAX cell grid.

  • geo (Geometry) – A geometry object.

  • preserved_quantity (IntegralPreservationQuantity) – The quantity to preserve the integral of when converting to face values.

Return type:

Union[Float[Array, 'rhon+1'], Float[ndarray, 'rhon+1']]

Returns:

Values defined on the TORAX face grid.
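For the default VALUE mode on a uniform grid, the four assumptions above determine every face value. Inner faces average their neighbours, the left face is extrapolated as f_0 = 1.5 c_0 - 0.5 c_1, and requiring the trapezium integral over the faces to equal the sum of the cells (both multiplied by the spacing) forces f_0 + f_n = c_0 + c_{n-1}, which fixes the right face. The stdlib sketch below is derived from those stated assumptions only; it is not the TORAX source and it ignores the other IntegralPreservationQuantity options.

```python
def cell_to_face_uniform(cells):
    """Sketch of cell-to-face conversion on a uniform grid (VALUE preservation)."""
    n = len(cells)
    if n < 2:
        raise ValueError("Need at least two cells to extrapolate the left face.")
    left = 1.5 * cells[0] - 0.5 * cells[1]  # linear extrapolation from the first two cells
    inner = [0.5 * (cells[i - 1] + cells[i]) for i in range(1, n)]
    # The trapezoid sum over faces equals the cell sum only if f_0 + f_n = c_0 + c_{n-1}.
    right = cells[0] + cells[-1] - left
    return [left] + inner + [right]


print(cell_to_face_uniform([1.0, 2.0, 3.0]))  # [0.5, 1.5, 2.5, 3.5]
```

For a linear cell profile such as [1, 2, 3], the resulting faces lie on the same straight line, as expected from assumptions 1 and 2.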

torax.math_utils.cumulative_trapezoid(y, x=None, dx=1.0, axis=-1, initial=None)

Cumulatively integrate y = f(x) using the trapezoid rule.

JAX equivalent of scipy.integrate.cumulative_trapezoid.

Parameters:
  • y (Array) – array of data to integrate.

  • x (Optional[Array]) – optional array of sample points corresponding to the y values. If not provided, x defaults to equally spaced with spacing given by dx.

  • dx (float) – the spacing between sample points when x is None (default: 1.0).

  • axis (int) – the axis along which to integrate (default: -1)

  • initial (Optional[float]) – A scalar value to prepend to the result: either None (default) or 0.0. If initial=0.0, the result is an array with the same shape as y. If initial=None, the result has one element fewer than y along the axis dimension.

Return type:

Array

Returns:

The cumulative definite integral approximated by the trapezoidal rule.
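A stdlib sketch of the cumulative trapezoid rule for a 1D y with uniform spacing dx is shown below; the x and axis arguments of the documented function are omitted for brevity, and the function is an illustration rather than the TORAX code. As described above, passing initial prepends that value so the output has the same length as y.

```python
def cumulative_trapezoid(y, dx=1.0, initial=None):
    """1D cumulative trapezoid rule with uniform spacing (subset of the full API)."""
    out = []
    total = 0.0
    for a, b in zip(y[:-1], y[1:]):
        total += 0.5 * (a + b) * dx  # area of one trapezoid
        out.append(total)
    if initial is not None:
        out.insert(0, initial)  # result now has len(y) entries
    return out


print(cumulative_trapezoid([0.0, 1.0, 2.0, 3.0]))  # [0.5, 2.0, 4.5]
print(cumulative_trapezoid([0.0, 1.0, 2.0, 3.0], initial=0.0))  # [0.0, 0.5, 2.0, 4.5]
```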

torax.math_utils.line_average(value, geo)

Calculates line-averaged value from input profile.

Parameters:
  • value (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • geo (Geometry)

Return type:

Union[Float[Array, ''], Float[ndarray, ''], number, float]

torax.math_utils.tridiag(diag, above, below)

Builds a tridiagonal matrix.

Parameters:
  • diag (Array) – The main diagonal.

  • above (Array) – The +1 diagonal.

  • below (Array) – The -1 diagonal.

Return type:

Array

Returns:

The tridiagonal matrix.
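To show which diagonal each argument fills, here is a dense stdlib sketch that assumes above and below each have one element fewer than diag. In JAX, one equivalent construction is jnp.diag(diag) + jnp.diag(above, 1) + jnp.diag(below, -1), though this reference does not say how TORAX builds the matrix.

```python
def tridiag(diag, above, below):
    """Dense n x n tridiagonal matrix as nested lists."""
    n = len(diag)
    if len(above) != n - 1 or len(below) != n - 1:
        raise ValueError("above and below must have length len(diag) - 1")
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        matrix[i][i] = diag[i]  # main diagonal
        if i + 1 < n:
            matrix[i][i + 1] = above[i]  # +1 diagonal
            matrix[i + 1][i] = below[i]  # -1 diagonal
    return matrix


for row in tridiag([2.0, 2.0, 2.0], [-1.0, -1.0], [-1.0, -1.0]):
    print(row)
```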

torax.math_utils.volume_average(value, geo)

Calculates volume-averaged value from input profile.

Parameters:
  • value (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • geo (Geometry)

Return type:

Union[Float[Array, ''], Float[ndarray, ''], number, float]

torax.math_utils.volume_integration(value, geo)

Calculates integral of value using a volume metric.

Parameters:
  • value (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • geo (Geometry)

Return type:

Union[Float[Array, ''], Float[ndarray, ''], number, float]
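Conceptually, a volume integral on the cell grid is a weighted sum, and a volume average divides that sum by the total volume. The sketch below assumes a hypothetical list dV of cell volumes; TORAX derives the equivalent weights from its Geometry object, which is not reproduced here.

```python
def volume_integration(value, dV):
    """Approximate the volume integral of a cell profile as sum(value_i * dV_i)."""
    return sum(v * w for v, w in zip(value, dV))


def volume_average(value, dV):
    """Volume integral divided by the total volume."""
    return volume_integration(value, dV) / sum(dV)


dV = [1.0, 3.0]  # hypothetical cell volumes [m^3]
temperature = [4.0, 2.0]  # hypothetical cell values [keV]
print(volume_average(temperature, dV))  # 2.5
```

Outer cells typically enclose more volume than inner ones, so a volume average weights the edge of the profile more heavily than a plain mean of the cell values would.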

torax.output module

Module containing functions for saving and loading simulation output.

class torax.output.StateHistory(state_history, post_processed_outputs_history, sim_error, source_models, torax_config)

Bases: object

A history of the state of the simulation and its error state.

simulation_output_to_xr(file_restart=None)

Build an xr.DataTree of the simulation output.

Parameters:

file_restart (Optional[FileRestart]) – If provided, contains information on a file this sim was restarted from. This is useful for stitching that file's output to the beginning of this sim's output.

Return type:

DataTree

Returns:

An xr.DataTree containing a single top-level xr.Dataset and a set of child datasets. The top-level dataset contains the following variables:

  • time: The time of the simulation.

  • rho_face_norm: The normalized toroidal coordinate on the face grid.

  • rho_cell_norm: The normalized toroidal coordinate on the cell grid.

  • rho_face: The toroidal coordinate on the face grid.

  • rho_cell: The toroidal coordinate on the cell grid.

  • sawtooth_crash: Time-series boolean indicating whether the state corresponds to a post-sawtooth-crash state.

  • sim_error: The simulation error state.

  • config: The ToraxConfig used to run the simulation, serialized to JSON.

The child datasets contain the following variables:
  • core_profiles: Contains data variables for quantities in the CoreProfiles.

  • core_transport: Contains data variables for quantities in the CoreTransport.

  • core_sources: Contains data variables for quantities in the CoreSources.

  • post_processed_outputs: Contains data variables for quantities in the PostProcessedOutputs.

  • geometry: Contains data variables for quantities in the Geometry.

torax.output.concat_datatrees(tree1, tree2)

Concatenates two xr.DataTrees along the time dimension.

For any duplicate time steps, the values from the first dataset are kept.

Parameters:
  • tree1 (DataTree) – The first xr.DataTree to concatenate.

  • tree2 (DataTree) – The second xr.DataTree to concatenate.

Return type:

DataTree

Returns:

An xr.DataTree containing the concatenation of the two datasets.
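The documented rule for duplicate time steps (keep the values from the first dataset) can be illustrated without xarray. The sketch below merges two hypothetical {time: value} series; it is an analogy for the behaviour of concat_datatrees, not its implementation.

```python
def concat_keep_first(series1, series2):
    """Merge two {time: value} series; series1 wins on duplicate times."""
    merged = dict(series2)
    merged.update(series1)  # overwrite duplicates with values from series1
    return dict(sorted(merged.items()))  # order by time


first = {0.0: 1.0, 1.0: 2.0}
second = {1.0: 99.0, 2.0: 3.0}
print(concat_keep_first(first, second))  # {0.0: 1.0, 1.0: 2.0, 2.0: 3.0}
```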

torax.output.load_state_file(filepath)

Loads a state file from a filepath.

Parameters:

filepath (str)

Return type:

DataTree

torax.output.stitch_state_files(file_restart, datatree)

Stitch a datatree to the end of a previous state file.

Parameters:
  • file_restart (FileRestart) – Contains information on a file this sim was restarted from.

  • datatree (DataTree) – The xr.DataTree to stitch to the end of the previous state file.

Return type:

DataTree

Returns:

An xr.DataTree containing the stitched dataset.

torax.post_processing module

Functions for adding post-processed outputs to the simulation state.

torax.post_processing.make_post_processed_outputs(sim_state, dynamic_runtime_params_slice, previous_post_processed_outputs=None)

Calculates post-processed outputs based on the latest state.

Called at the beginning and end of each sim.run_simulation step.

Parameters:
  • sim_state (ToraxSimState) – The state to add outputs to.

  • dynamic_runtime_params_slice (DynamicRuntimeParamsSlice) – Runtime parameters slice for the current time step, needed for calculating integrated power.

  • previous_post_processed_outputs (Optional[PostProcessedOutputs]) – The previous outputs, used to calculate cumulative quantities. Optional input. If None, cumulative quantities are set to the initialized values in sim_state itself. This is used for the first time step of the simulation. The initialized values are zero for a clean simulation, or the last value of the previous simulation for a restarted simulation.

Returns:

The post_processed_outputs for the given state.

Return type:

PostProcessedOutputs

torax.run_simulation_main module

Main entrypoint for running transport simulation.

Example command with a configuration defined in Python:

  python3 run_simulation_main.py --config='torax.tests.test_data.default_config' --log_progress

torax.run_simulation_main.use_jax_profiler_if_enabled(f)

Decorator that runs f with profiling if the profiling flag is enabled.

torax.sim module

Functionality for running simulations.

This includes the run_simulation main loop, logging functionality, and functionality for translating between our particular physics simulation and generic fluid dynamics PDE solvers.

Use the TORAX_COMPILATION_ENABLED environment variable to turn JAX compilation on and off. Compilation is on by default. Turning compilation off can sometimes help with debugging (e.g. by making it easier to print error messages in context).

torax.state module

Classes defining the TORAX state that evolves over time.

class torax.state.CoreProfiles(temp_ion, temp_el, psi, psidot, ne, ni, nimp, currents, q_face, s_face, nref, vloop_lcfs, Zi, Zi_face, Ai, Zimp, Zimp_face, Aimp)

Bases: Mapping

Dataclass for holding the evolving core plasma profiles.

This dataclass is inspired by the IMAS core_profiles IDS.

Many of the profiles in this class are evolved by the PDE system in TORAX, and therefore are stored as CellVariables. Other profiles are computed outside the internal PDE system, and are simple JAX arrays.

temp_ion

Ion temperature [keV].

temp_el

Electron temperature [keV].

psi

Poloidal flux [Wb].

psidot

Time derivative of poloidal flux (loop voltage) [V].

ne

Electron density [nref m^-3].

ni

Main ion density [nref m^-3].

nimp

Impurity density [nref m^-3].

currents

Instance of the Currents dataclass.

q_face

Safety factor.

s_face

Magnetic shear.

nref

Reference density [m^-3].

vloop_lcfs

Loop voltage at the LCFS [V].

Zi

Main ion charge on cell grid [dimensionless].

Zi_face

Main ion charge on face grid [dimensionless].

Ai

Main ion mass [amu].

Zimp

Impurity charge on cell grid [dimensionless].

Zimp_face

Impurity charge on face grid [dimensionless].

Aimp

Impurity mass [amu].

index(i)

If the CoreProfiles is a history, returns the i-th CoreProfiles.

Parameters:

i (int)

Return type:

Self

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys

negative_temperature_or_density()

Checks if any temperature or density is negative.

Return type:

bool

quasineutrality_satisfied()

Checks if quasineutrality is satisfied.

Return type:

bool

values() an object providing a view on D's values
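quasineutrality_satisfied checks charge balance. Assuming the standard condition n_e = Z_i n_i + Z_imp n_imp must hold point-wise on the grid, a stdlib sketch might look like the following. The scalar charges, the relative tolerance and the function name are assumptions for illustration; CoreProfiles itself stores Zi and Zimp as profiles.

```python
import math


def quasineutral(ne, ni, nimp, zi, zimp, rel_tol=1e-6):
    """Check ne == zi*ni + zimp*nimp at every grid point (tolerance is an assumption)."""
    return all(
        math.isclose(e, zi * i + zimp * z, rel_tol=rel_tol)
        for e, i, z in zip(ne, ni, nimp)
    )


# Hypothetical profiles in units of nref, with Zi = 1 and Zimp = 10.
ne = [1.0, 0.8]
nimp = [0.01, 0.01]
ni = [e - 10.0 * z for e, z in zip(ne, nimp)]  # main ions fill the remaining charge
print(quasineutral(ne, ni, nimp, zi=1.0, zimp=10.0))  # True
```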

class torax.state.CoreTransport(chi_face_ion, chi_face_el, d_face_el, v_face_el, chi_face_el_bohm=None, chi_face_el_gyrobohm=None, chi_face_ion_bohm=None, chi_face_ion_gyrobohm=None)

Bases: Mapping

Coefficients for the plasma transport.

These coefficients are computed by TORAX transport models. See the transport_model/ folder for more info.

NOTE: The naming of this class is inspired by the IMAS core_transport IDS, but its schema is not a 1:1 mapping to that IDS.

chi_face_ion

Ion heat conductivity, on the face grid.

chi_face_el

Electron heat conductivity, on the face grid.

d_face_el

Diffusivity of electron density, on the face grid.

v_face_el

Convection strength of electron density, on the face grid.

chi_face_el_bohm

(Optional) Bohm contribution for electron heat conductivity.

chi_face_el_gyrobohm

(Optional) GyroBohm contribution for electron heat conductivity.

chi_face_ion_bohm

(Optional) Bohm contribution for ion heat conductivity.

chi_face_ion_gyrobohm

(Optional) GyroBohm contribution for ion heat conductivity.

Parameters:
  • chi_face_ion (Array)

  • chi_face_el (Array)

  • d_face_el (Array)

  • v_face_el (Array)

  • chi_face_el_bohm (Optional[Array])

  • chi_face_el_gyrobohm (Optional[Array])

  • chi_face_ion_bohm (Optional[Array])

  • chi_face_ion_gyrobohm (Optional[Array])

chi_max(geo)

Calculates the maximum value of chi.

Parameters:

geo (Geometry) – Geometry of the torus.

Returns:

Maximum value of chi.

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values

classmethod zeros(geo)

Returns a CoreTransport with all zeros. Useful for initializing.

Parameters:

geo (Geometry)

Return type:

Self

class torax.state.Currents(jtot, jtot_face, johm, external_current_source, j_bootstrap, j_bootstrap_face, I_bootstrap, Ip_profile_face, sigma, jtot_hires=None)

Bases: Mapping

Dataclass to group currents and related variables (e.g. conductivity).

Not all fields are actually used by the library. For example, j_bootstrap and I_bootstrap are updated during the sim loop, but not read from. These fields are an output of the library that may be interesting for the end user to plot, etc.

Parameters:
  • jtot (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • jtot_face (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • johm (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • external_current_source (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • j_bootstrap (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • j_bootstrap_face (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • I_bootstrap (Union[Float[Array, ''], Float[ndarray, ''], number, float])

  • Ip_profile_face (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • sigma (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon']])

  • jtot_hires (Union[Float[Array, 'rhon'], Float[ndarray, 'rhon'], None])

property Ip_total: Float[Array, ''] | Float[ndarray, ''] | number | float

Returns the total plasma current [A].

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values

classmethod zeros(geo)

Returns a Currents with all zeros.

Parameters:

geo (Geometry)

Return type:

Currents

class torax.state.PostProcessedOutputs(pressure_thermal_ion_face, pressure_thermal_el_face, pressure_thermal_tot_face, pprime_face, W_thermal_ion, W_thermal_el, W_thermal_tot, tauE, H89P, H98, H97L, H20, FFprime_face, psi_norm_face, psi_face, P_sol_ion, P_sol_el, P_sol_tot, P_external_ion, P_external_el, P_external_tot, P_external_injected, P_ei_exchange_ion, P_ei_exchange_el, P_generic_ion, P_generic_el, P_generic_tot, P_alpha_ion, P_alpha_el, P_alpha_tot, P_ohmic, P_brems, P_cycl, P_ecrh, P_rad, I_ecrh, I_generic, Q_fusion, P_icrh_el, P_icrh_ion, P_icrh_tot, P_LH_hi_dens, P_LH_min, P_LH, ne_min_P_LH, E_cumulative_fusion, E_cumulative_external, te_volume_avg, ti_volume_avg, ne_volume_avg, ni_volume_avg, ne_line_avg, ni_line_avg, fgw_ne_volume_avg, fgw_ne_line_avg, q95, Wpol, li3, dW_th_dt)

Bases: Mapping

Collection of outputs calculated after each simulation step.

These variables are not used internally, but are useful as outputs or intermediate observations for overarching workflows.

pressure_thermal_ion_face

Ion thermal pressure on the face grid [Pa]

pressure_thermal_el_face

Electron thermal pressure on the face grid [Pa]

pressure_thermal_tot_face

Total thermal pressure on the face grid [Pa]

pprime_face

Derivative of total pressure with respect to poloidal flux on the face grid [Pa/Wb]

W_thermal_ion

Ion thermal stored energy [J]

W_thermal_el

Electron thermal stored energy [J]

W_thermal_tot

Total thermal stored energy [J]

tauE

Thermal energy confinement time [s]

H89P

L-mode confinement quality factor with respect to the ITER89P scaling law derived from the ITER L-mode confinement database

H98

H-mode confinement quality factor with respect to the ITER98y2 scaling law derived from the ITER H-mode confinement database

H97L

L-mode confinement quality factor with respect to the ITER97L scaling law derived from the ITER L-mode confinement database

H20

H-mode confinement quality factor with respect to the ITER20 scaling law derived from the updated (2020) ITER H-mode confinement database

FFprime_face

FF’ on the face grid, where F = R B_φ is the toroidal field flux function

psi_norm_face

Normalized poloidal flux on the face grid [dimensionless]

psi_face

Poloidal flux on the face grid [Wb]

P_sol_ion

Total ion heating power exiting the plasma with all sources: auxiliary heating + ion-electron exchange + fusion [W]

P_sol_el

Total electron heating power exiting the plasma with all sources and sinks: auxiliary heating + ion-electron exchange + Ohmic + fusion + radiation sinks [W]

P_sol_tot

Total heating power exiting the plasma with all sources and sinks [W]

P_external_ion

Total external ion heating power: auxiliary heating + Ohmic [W]

P_external_el

Total external electron heating power: auxiliary heating + Ohmic [W]

P_external_tot

Total external heating power: auxiliary heating + Ohmic [W]

P_external_injected

Total external injected power before absorption [W]

P_ei_exchange_ion

Electron-ion heat exchange power to ions [W]

P_ei_exchange_el

Electron-ion heat exchange power to electrons [W]

P_generic_ion

Total generic_ion_el_heat_source power to ions [W]

P_generic_el

Total generic_ion_el_heat_source power to electrons [W]

P_generic_tot

Total generic_ion_el_heat power [W]

P_alpha_ion

Total fusion power to ions [W]

P_alpha_el

Total fusion power to electrons [W]

P_alpha_tot

Total fusion power to plasma [W]

P_ohmic

Ohmic heating power to electrons [W]

P_brems

Bremsstrahlung electron heat sink [W]

P_cycl

Cyclotron radiation electron heat sink [W]

P_ecrh

Total electron cyclotron source power [W]

P_rad

Impurity radiation heat sink [W]

I_ecrh

Total electron cyclotron source current [A]

I_generic

Total generic source current [A]

Q_fusion

Fusion power gain

P_icrh_el

Ion cyclotron resonance heating to electrons [W]

P_icrh_ion

Ion cyclotron resonance heating to ions [W]

P_icrh_tot

Total ion cyclotron resonance heating power [W]

P_LH_hi_dens

H-mode transition power for high density branch [W]

P_LH_min

Minimum H-mode transition power, attained at ne_min_P_LH [W]

P_LH

H-mode transition power, the maximum of P_LH_hi_dens and P_LH_min [W]

ne_min_P_LH

Density corresponding to P_LH_min [nref]

E_cumulative_fusion

Total cumulative fusion energy [J]

E_cumulative_external

Total external injected energy (Ohmic + auxiliary heating) [J]

te_volume_avg

Volume average electron temperature [keV]

ti_volume_avg

Volume average ion temperature [keV]

ne_volume_avg

Volume average electron density [nref m^-3]

ni_volume_avg

Volume average main ion density [nref m^-3]

ne_line_avg

Line averaged electron density [nref m^-3]

ni_line_avg

Line averaged main ion density [nref m^-3]

fgw_ne_volume_avg

Greenwald fraction from volume-averaged electron density [dimensionless]

fgw_ne_line_avg

Greenwald fraction from line-averaged electron density [dimensionless]

q95

q at 95% of the normalized poloidal flux

Wpol

Total poloidal magnetic energy [J]

li3

Normalized plasma internal inductance, ITER convention [dimensionless]

dW_th_dt

Time derivative of the total stored thermal energy [W]

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values

classmethod zeros(geo)

Returns a PostProcessedOutputs with all zeros, used for initializing.

Parameters:

geo (Geometry)

Return type:

Self
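The fgw_ne_volume_avg and fgw_ne_line_avg outputs are Greenwald fractions, f_GW = n_e / n_GW. A worked example with the standard Greenwald density limit, n_GW [10^20 m^-3] = I_p [MA] / (π a^2) with minor radius a in metres, follows. The ITER-like input values are hypothetical, and whether TORAX evaluates exactly this expression is an assumption.

```python
import math


def greenwald_fraction(n_e, ip_ma, minor_radius):
    """Return n_e / n_GW for n_e in m^-3, Ip in MA and minor radius in m."""
    n_gw = ip_ma / (math.pi * minor_radius**2) * 1e20  # Greenwald density [m^-3]
    return n_e / n_gw


# Hypothetical ITER-like values: Ip = 15 MA, a = 2 m, n_e = 1e20 m^-3.
f_gw = greenwald_fraction(1e20, 15.0, 2.0)
print(round(f_gw, 3))  # 0.838
```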

class torax.state.SimError(value)

Bases: Enum

Integer enum for sim error handling.

class torax.state.StepperNumericOutputs(outer_stepper_iterations=0, stepper_error_state=0, inner_solver_iterations=0)

Bases: Mapping

Numerical quantities related to the stepper.

outer_stepper_iterations

Number of iterations performed in the outer loop of the stepper.

stepper_error_state

Error state of the solver for this step:
  • 0: the solver converged within the fine tolerance.

  • 1: the solver did not converge (the result was above the coarse tolerance).

  • 2: the solver converged within the coarse tolerance. This is allowed to pass with a warning; an occasional error state of 2 has low impact on the final sim state.

inner_solver_iterations

Total number of iterations performed in the solver across all iterations of the stepper.

Parameters:
  • outer_stepper_iterations (int)

  • stepper_error_state (int)

  • inner_solver_iterations (int)

items()

Returns a set-like view of the mapping's items.

keys()

Returns a set-like view of the mapping's keys.

values()

Returns a view of the mapping's values.

class torax.state.ToraxSimState(t, dt, core_profiles, core_transport, core_sources, geometry, stepper_numeric_outputs, sawtooth_crash=False)

Bases: Mapping

Full simulator state.

The simulation stepping in sim.py evolves core_profiles, which includes all the attributes the simulation is advancing. Beyond those, there are additional stateful elements that may evolve on each simulation step, such as sources and transport.

This class includes both core_profiles and these additional elements.

t

Time coordinate.

dt

Timestep interval.

core_profiles

Core plasma profiles at time t.

core_transport

Core plasma transport coefficients computed at time t.

core_sources

Profiles for all sources/sinks. These are the profiles used to calculate the coefficients for the t+dt time step. For the explicit sources, these are calculated at the start of the time step, so they are the values at time t. For the implicit sources, these are the most recent guess for time t+dt. The profiles here are the merged version of the explicit and implicit profiles.

post_processed_outputs

Variables for output or intermediate observations for overarching workflows, calculated after each simulation step.

geometry

Geometry at this time step used for the simulation.

time_step_calculator_state

The state of the TimeStepper.

stepper_numeric_outputs

Numerical quantities related to the stepper.

sawtooth_crash

True if a sawtooth model is active and the state corresponds to a post-sawtooth-crash state.

check_for_errors()

Checks for errors in the simulation state.

Return type:

SimError

items()

Returns a set-like view of the mapping's items.

keys()

Returns a set-like view of the mapping's keys.

values()

Returns a view of the mapping's values.

torax.state.check_for_errors(sim_state, post_processed_outputs)

Checks for errors in the simulation state.

Return type:

SimError

torax.version module

Torax version information.

Module contents

Library functionality for TORAX.

class torax.InterpolatedVarSingleAxis(value, interpolation_mode=InterpolationMode.PIECEWISE_LINEAR, is_bool_param=False)

Bases: InterpolatedParamBase

Parameter that may vary based on an input coordinate.

This class is useful for defining time-dependent runtime parameters, but can be used to define any parameters that vary across some range.

This class allows the interpolation of a 1D array xs against either a 1D or a 2D array ys. For example, xs can be time, and ys can be either a 1D array of scalars associated with the times in xs, or a 2D array in which index 0 corresponds to the times in xs and index 1 holds a radial array. The interpolation of the 2D array is then carried out element-wise and accelerated with vmap. A 2D ys is intended for cases where the radial slices along index 1 have already been interpolated onto the appropriate TORAX grid, such as cell_centers, faces, or the hires grid. NOTE: this means that the 2D array should have shape (n, m), where n is the number of elements in the 1D array xs and m is the spatial grid size of the InterpolatedVar1d instance.

See config.runtime_params.RuntimeParams and its associated tests for examples of how this is used.
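The 2D case above interpolates every radial point independently along xs. The NumPy sketch below shows this for the piecewise-linear mode only. It loops over columns where TORAX uses jax.vmap. interpolate_single_axis is a hypothetical name, and the sketch is not the class's actual implementation.

```python
import numpy as np


def interpolate_single_axis(x, xs, ys):
    """Hypothetical sketch: piecewise-linear interpolation of ys against xs at x.

    xs has shape (n,). ys has shape (n,) for a scalar parameter, or (n, m)
    for a radial profile already on an m-point grid. Outside
    [xs[0], xs[-1]] the end values are returned (np.interp clamps).
    """
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    if ys.ndim == 1:
        return np.interp(x, xs, ys)
    if ys.shape[0] != xs.shape[0]:
        raise ValueError("ys must have shape (n, m) with n == len(xs)")
    # Interpolate each radial point (column) independently along xs.
    return np.array([np.interp(x, xs, ys[:, j]) for j in range(ys.shape[1])])


# Two time slices of a 3-point radial profile, evaluated halfway between them.
profile = interpolate_single_axis(5.0, [0.0, 10.0], [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
# profile -> [2.0, 3.0, 4.0]
```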

get_value(x)

Returns a single value for this range at the given coordinate.

Parameters:

x (Union[Array, ndarray, bool, number, float, int])

Return type:

Union[Array, ndarray, bool, number]

property interpolation_mode: InterpolationMode

Returns the interpolation mode used by this param.

property is_bool_param: bool

Returns whether this param represents a bool.

property param: InterpolatedParamBase

Returns the JAX-friendly interpolated param used under the hood.

class torax.InterpolatedVarTimeRho(values, rho_norm, time_interpolation_mode=InterpolationMode.PIECEWISE_LINEAR, rho_interpolation_mode=InterpolationMode.PIECEWISE_LINEAR)

Bases: InterpolatedParamBase

Interpolates on a grid (time, rho).

This class linearly interpolates along time to provide a value at any (time, rho) pair. For time values outside the defined range, the closest defined InterpolatedVarSingleAxis is used.

  • NOTE: We assume that rho interpolation is fixed per simulation, so rho is taken at init and only time is taken at get_value.
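The split between rho interpolation at init and time interpolation at call time can be sketched as follows. The class name TimeRhoInterpolator and the input format are assumptions for illustration. The input is taken to be a dict that maps each time to a (rho_norm, values) pair. Only linear interpolation is shown on both axes.

```python
import numpy as np


class TimeRhoInterpolator:
    """Hypothetical (time, rho) interpolator: rho fixed at init, time at call."""

    def __init__(self, values, rho_norm):
        # values: {time: (rho_points, profile)}; rho_norm: the fixed target grid.
        self._rho_norm = np.asarray(rho_norm, dtype=float)
        times = sorted(values)
        self._times = np.asarray(times, dtype=float)
        rows = []
        for t in times:
            rho_points, profile = values[t]
            # Rho interpolation happens once, here, onto the fixed grid.
            rows.append(np.interp(self._rho_norm, rho_points, profile))
        self._profiles = np.stack(rows)  # shape (n_times, n_rho)

    def get_value(self, time):
        # Only time is interpolated at call time. np.interp clamps, so times
        # outside the range return the first or last profile.
        return np.array([
            np.interp(time, self._times, self._profiles[:, j])
            for j in range(self._profiles.shape[1])
        ])


interp = TimeRhoInterpolator(
    {0.0: ([0.0, 1.0], [1.0, 0.0]), 2.0: ([0.0, 1.0], [3.0, 2.0])},
    rho_norm=[0.0, 0.5, 1.0],
)
mid_profile = interp.get_value(1.0)  # [2.0, 1.5, 1.0]
```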

get_value(x)

Returns the value of this parameter interpolated at x=time.

Parameters:

x (Union[Array, ndarray, bool, number, float, int])

Return type:

Union[Array, ndarray, bool, number]

property rho_interpolation_mode: InterpolationMode

Returns the rho interpolation mode used by this param.

property time_interpolation_mode: InterpolationMode

Returns the time interpolation mode used by this param.

class torax.InterpolationMode(value)

Bases: Enum

Defines how to do the interpolation.

InterpolatedParams have many values to interpolate between, and this enum defines how exactly that interpolation is computed.

Assuming inputs [x_0, …, x_n] and [y_0, …, y_n], for all modes, the interpolated param outputs y_0 for any input less than x_0 and y_n for any input greater than x_n.

Options:

  • PIECEWISE_LINEAR: Does piecewise-linear interpolation between the values provided. See numpy.interp for a longer description of how it works. (This uses JAX, but the behavior is the same.)

  • STEP: Step-function interpolation. For any input value x in the range [x_k, x_{k+1}), the output will be y_k.
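The two modes, including the end-point behavior described above, can be sketched with NumPy. To keep the example self-contained, modes are passed here as strings. The real API takes InterpolationMode members, and the interpolate helper is hypothetical.

```python
import numpy as np


def interpolate(x, xs, ys, mode="PIECEWISE_LINEAR"):
    """Hypothetical sketch of the two documented modes; xs must be ascending."""
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    if mode == "PIECEWISE_LINEAR":
        # np.interp already returns ys[0] below xs[0] and ys[-1] above xs[-1].
        return np.interp(x, xs, ys)
    if mode == "STEP":
        # Index k with xs[k] <= x < xs[k+1]; clipping handles both ends.
        k = np.searchsorted(xs, x, side="right") - 1
        return ys[np.clip(k, 0, len(xs) - 1)]
    raise ValueError(f"unknown mode: {mode}")


xs, ys = [0.0, 1.0, 2.0], [10.0, 20.0, 30.0]
linear_mid = interpolate(0.5, xs, ys)  # 15.0
step_mid = interpolate(0.5, xs, ys, mode="STEP")  # 10.0
```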

class torax.SimError(value)

Bases: Enum

Integer enum for sim error handling.

class torax.ToraxConfig(**data)

Bases: BaseModelFrozen

Base config class for Torax.

profile_conditions

Config for the profile conditions.

numerics

Config for the numerics.

plasma_composition

Config for the plasma composition.

geometry

Config for the geometry.

pedestal

Config for the pedestal model. If an empty dictionary is passed in, the pedestal model will be set to no_pedestal.

sources

Config for the sources.

stepper

Config for the stepper. If an empty dictionary is passed in, the stepper model will be set to linear.

transport

Config for the transport model. If an empty dictionary is passed in, the transport model will be set to constant.

mhd

Optional config for MHD models. If None, no MHD models are used.

time_step_calculator

Optional config for the time step calculator. If not provided, the default chi time step calculator is used.

restart

Optional config for file restart. If None, no file restart is performed.

Parameters:

data (Any)

model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'frozen': True}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

update_fields(x)

Safely update fields in the config.

This works with Frozen models.

This method will invalidate all functools.cached_property caches of all ancestral models in the nested tree, as these could have a dependency on the updated model. In addition, these nodes will be re-validated.

Parameters:

x (Mapping[str, Any]) – A dictionary whose keys are paths of the form ‘some.path.to.field_name’ and whose values are the new values for the corresponding field_name. The path can consist of dictionary keys or attribute names, but field_name must be an attribute of a Pydantic model.

Raises:

ValueError – All submodels must be unique object instances; a ValueError is raised if this is not the case.
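The dotted-path addressing used by update_fields can be illustrated on plain nested dictionaries. The hypothetical update_nested helper below only shows how a key like 'outer.inner.value' is resolved and applied to a copy. It does not reproduce the Pydantic re-validation or the cached_property invalidation that update_fields performs, and the field names in the example are made up.

```python
import copy
from typing import Any, Mapping


def update_nested(config: dict, updates: Mapping[str, Any]) -> dict:
    """Hypothetical helper: return a copy of config with dotted-path keys replaced.

    Each key looks like 'some.path.to.field_name'; every segment must already
    exist. The input is left unmodified, loosely mirroring a frozen model.
    """
    new_config = copy.deepcopy(config)
    for path, value in updates.items():
        *parents, field_name = path.split(".")
        node = new_config
        for key in parents:
            node = node[key]  # KeyError if an intermediate segment is missing
        if field_name not in node:
            raise KeyError(f"unknown field: {path}")
        node[field_name] = value
    return new_config


base = {"outer": {"inner": {"value": 1, "other": "x"}}}
updated = update_nested(base, {"outer.inner.value": 2})
# updated["outer"]["inner"]["value"] == 2, and base is unchanged.
```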

torax.import_module(module_name, config_package=None)

Imports a module.

torax.run_simulation(torax_config, log_timestep_info=False, progress_bar=True)

Runs a TORAX simulation using the config and returns the outputs.

Return type:

StateHistory