Source code for torax.state

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes defining the TORAX state that evolves over time."""
import enum
from typing import Optional

from absl import logging
import chex
import jax
from jax import numpy as jnp
from torax import array_typing
from torax import jax_utils
from torax.config import config_args
from torax.fvm import cell_variable
from torax.geometry import geometry
from torax.sources import source_profiles
import typing_extensions


@chex.dataclass(frozen=True)
class Currents:
  """Groups the plasma currents with related quantities (e.g. conductivity).

  The library does not consume every field. Some, such as j_bootstrap and
  I_bootstrap, are written during the sim loop but never read back; they are
  kept as outputs because end users may want to inspect or plot them.
  """

  jtot: array_typing.ArrayFloat
  jtot_face: array_typing.ArrayFloat
  johm: array_typing.ArrayFloat
  external_current_source: array_typing.ArrayFloat
  j_bootstrap: array_typing.ArrayFloat
  j_bootstrap_face: array_typing.ArrayFloat
  # pylint: disable=invalid-name
  # The names below follow physics notation rather than snake_case.
  I_bootstrap: array_typing.ScalarFloat  # [A]
  Ip_profile_face: array_typing.ArrayFloat  # [A]
  sigma: array_typing.ArrayFloat
  jtot_hires: Optional[array_typing.ArrayFloat] = None

  @property
  def Ip_total(self) -> array_typing.ScalarFloat:
    """Total plasma current [A]: the last entry of Ip_profile_face."""
    # Index only the final axis so a stacked history of profiles yields one
    # total per time slice.
    return self.Ip_profile_face[..., -1]

  @classmethod
  def zeros(cls, geo: geometry.Geometry) -> "Currents":
    """Builds a Currents instance in which every field is zero.

    Args:
      geo: Geometry whose face grid (`geo.rho_face`) determines the shape of
        every array-valued field.

    Returns:
      A Currents with zero arrays for all profiles and a zero scalar for
      I_bootstrap.
    """
    face_shape = geo.rho_face.shape

    def face_zeros():
      # A new array on each call, exactly like a direct jnp.zeros(...) call.
      return jnp.zeros(face_shape)

    return cls(
        jtot=face_zeros(),
        jtot_face=face_zeros(),
        johm=face_zeros(),
        external_current_source=face_zeros(),
        j_bootstrap=face_zeros(),
        j_bootstrap_face=face_zeros(),
        I_bootstrap=jnp.array(0.0, dtype=jax_utils.get_dtype()),
        Ip_profile_face=face_zeros(),
        sigma=face_zeros(),
        jtot_hires=face_zeros(),
    )
@chex.dataclass(frozen=True, eq=False)
class CoreProfiles:
  """Dataclass for holding the evolving core plasma profiles.

  This dataclass is inspired by the IMAS `core_profiles` IDS.

  Many of the profiles in this class are evolved by the PDE system in TORAX, and
  therefore are stored as CellVariables. Other profiles are computed outside the
  internal PDE system, and are simple JAX arrays.

  Attributes:
      temp_ion: Ion temperature [keV].
      temp_el: Electron temperature [keV].
      psi: Poloidal flux [Wb].
      psidot: Time derivative of poloidal flux (loop voltage) [V].
      ne: Electron density [nref m^-3].
      ni: Main ion density [nref m^-3].
      nimp: Impurity density [nref m^-3].
      currents: Instance of the Currents dataclass.
      q_face: Safety factor.
      s_face: Magnetic shear.
      nref: Reference density [m^-3].
      vloop_lcfs: Loop voltage at LCFS (V).
      Zi: Main ion charge on cell grid [dimensionless].
      Zi_face: Main ion charge on face grid [dimensionless].
      Ai: Main ion mass [amu].
      Zimp: Impurity charge on cell grid [dimensionless].
      Zimp_face: Impurity charge on face grid [dimensionless].
      Aimp: Impurity mass [amu].
  """

  temp_ion: cell_variable.CellVariable
  temp_el: cell_variable.CellVariable
  psi: cell_variable.CellVariable
  psidot: cell_variable.CellVariable
  ne: cell_variable.CellVariable
  ni: cell_variable.CellVariable
  nimp: cell_variable.CellVariable
  currents: Currents
  q_face: array_typing.ArrayFloat
  s_face: array_typing.ArrayFloat
  nref: array_typing.ScalarFloat
  vloop_lcfs: array_typing.ScalarFloat
  # pylint: disable=invalid-name
  Zi: array_typing.ArrayFloat
  Zi_face: array_typing.ArrayFloat
  Ai: array_typing.ScalarFloat
  Zimp: array_typing.ArrayFloat
  Zimp_face: array_typing.ArrayFloat
  Aimp: array_typing.ScalarFloat
  # pylint: enable=invalid-name

  def quasineutrality_satisfied(self) -> bool:
    """Checks if quasineutrality is satisfied.

    Returns:
      True when ni * Zi + nimp * Zimp matches ne within `jnp.allclose`'s
      default tolerances (cell values only).
    """
    return jnp.allclose(
        self.ni.value * self.Zi + self.nimp.value * self.Zimp,
        self.ne.value,
    ).item()

  def negative_temperature_or_density(self) -> bool:
    """Checks if any temperature or density is negative.

    Every pytree leaf of the temperature and density CellVariables is
    inspected, not only their `value` arrays.
    """
    profiles_to_check = (
        self.temp_ion,
        self.temp_el,
        self.ne,
        self.ni,
        self.nimp,
    )
    # A generator (rather than a list) lets `any` stop at the first hit.
    return any(
        jnp.any(jnp.less(x, 0.0)) for x in jax.tree.leaves(profiles_to_check)
    )

  def index(self, i: int) -> typing_extensions.Self:
    """If the CoreProfiles is a history, returns the i-th CoreProfiles.

    Args:
      i: Index along the leading (history) axis of every leaf.

    Returns:
      A CoreProfiles holding the i-th slice, with the `history` flag cleared
      on every CellVariable field.
    """
    idx = lambda x: x[i]
    state = jax.tree_util.tree_map(idx, self)
    # These variables track whether they are histories, so when we collapse down
    # to a single state we need to explicitly clear the history flag. This must
    # list every CellVariable field; nimp was previously missing.
    history_vars = [
        "temp_ion",
        "temp_el",
        "psi",
        "psidot",
        "ne",
        "ni",
        "nimp",
    ]
    history_replace = {"history": None}
    replace_dict = {var: history_replace for var in history_vars}
    state = config_args.recursive_replace(state, **replace_dict)
    return state

  def sanity_check(self):
    """Calls `sanity_check()` on every field that defines one."""
    for field in CoreProfiles.__dataclass_fields__:
      value = getattr(self, field)
      if hasattr(value, "sanity_check"):
        value.sanity_check()

  def __str__(self) -> str:
    """Returns a short summary of the temperature, flux and density fields."""
    return f""" CoreProfiles( temp_ion={self.temp_ion}, temp_el={self.temp_el}, psi={self.psi}, ne={self.ne}, nimp={self.nimp}, ni={self.ni}, ) """
@chex.dataclass
class CoreTransport:
  """Transport coefficients of the plasma, as produced by TORAX transport models.

  The transport_model/ folder contains the models that compute these values.

  NOTE: The class name borrows from the IMAS `core_transport` IDS, but the
  fields here are not a one-to-one mapping of that schema.

  Attributes:
    chi_face_ion: Ion heat conductivity on the face grid.
    chi_face_el: Electron heat conductivity on the face grid.
    d_face_el: Electron density diffusivity on the face grid.
    v_face_el: Electron density convection strength on the face grid.
    chi_face_el_bohm: (Optional) Bohm part of the electron heat conductivity.
    chi_face_el_gyrobohm: (Optional) GyroBohm part of the electron heat
      conductivity.
    chi_face_ion_bohm: (Optional) Bohm part of the ion heat conductivity.
    chi_face_ion_gyrobohm: (Optional) GyroBohm part of the ion heat
      conductivity.
  """

  chi_face_ion: jax.Array
  chi_face_el: jax.Array
  d_face_el: jax.Array
  v_face_el: jax.Array
  chi_face_el_bohm: Optional[jax.Array] = None
  chi_face_el_gyrobohm: Optional[jax.Array] = None
  chi_face_ion_bohm: Optional[jax.Array] = None
  chi_face_ion_gyrobohm: Optional[jax.Array] = None

  def __post_init__(self):
    # Every Bohm/GyroBohm component left as None is replaced by zeros shaped
    # like chi_face_el, which serves as the reference array.
    template = self.chi_face_el
    for component in (
        "chi_face_el_bohm",
        "chi_face_el_gyrobohm",
        "chi_face_ion_bohm",
        "chi_face_ion_gyrobohm",
    ):
      if getattr(self, component) is None:
        setattr(self, component, jnp.zeros_like(template))

  def chi_max(
      self,
      geo: geometry.Geometry,
  ) -> jax.Array:
    """Largest geometry-weighted heat conductivity over ions and electrons.

    Args:
      geo: Geometry of the torus; its `g1_over_vpr2_face` weights each chi.

    Returns:
      chi_max: The maximum of the weighted ion and electron conductivities.
    """
    weight = geo.g1_over_vpr2_face
    ion_peak = jnp.max(self.chi_face_ion * weight)
    el_peak = jnp.max(self.chi_face_el * weight)
    return jnp.maximum(ion_peak, el_peak)

  @classmethod
  def zeros(cls, geo: geometry.Geometry) -> typing_extensions.Self:
    """Builds an all-zero CoreTransport on the face grid, for initialization."""
    face_shape = geo.rho_face.shape
    names = (
        "chi_face_ion",
        "chi_face_el",
        "d_face_el",
        "v_face_el",
        "chi_face_el_bohm",
        "chi_face_el_gyrobohm",
        "chi_face_ion_bohm",
        "chi_face_ion_gyrobohm",
    )
    return cls(**{name: jnp.zeros(face_shape) for name in names})
@chex.dataclass(frozen=True, eq=False)
class PostProcessedOutputs:
  """Collection of outputs calculated after each simulation step.

  These variables are not used internally, but are useful as outputs or
  intermediate observations for overarching workflows.

  Attributes:
    pressure_thermal_ion_face: Ion thermal pressure on the face grid [Pa]
    pressure_thermal_el_face: Electron thermal pressure on the face grid [Pa]
    pressure_thermal_tot_face: Total thermal pressure on the face grid [Pa]
    pprime_face: Derivative of total pressure with respect to poloidal flux on
      the face grid [Pa/Wb]
    W_thermal_ion: Ion thermal stored energy [J]
    W_thermal_el: Electron thermal stored energy [J]
    W_thermal_tot: Total thermal stored energy [J]
    tauE: Thermal energy confinement time [s]
    H89P: L-mode confinement quality factor with respect to the ITER89P
      scaling law derived from the ITER L-mode confinement database
    H98: H-mode confinement quality factor with respect to the ITER98y2
      scaling law derived from the ITER H-mode confinement database
    H97L: L-mode confinement quality factor with respect to the ITER97L
      scaling law derived from the ITER L-mode confinement database
    H20: H-mode confinement quality factor with respect to the ITER20 scaling
      law derived from the updated (2020) ITER H-mode confinement database
    FFprime_face: FF' on the face grid, where F is the toroidal flux function
    psi_norm_face: Normalized poloidal flux on the face grid [dimensionless]
    psi_face: Poloidal flux on the face grid [Wb]
    P_sol_ion: Total ion heating power exiting the plasma with all sources:
      auxiliary heating + ion-electron exchange + fusion [W]
    P_sol_el: Total electron heating power exiting the plasma with all sources
      and sinks: auxiliary heating + ion-electron exchange + Ohmic + fusion +
      radiation sinks [W]
    P_sol_tot: Total heating power exiting the plasma with all sources and
      sinks [W]
    P_external_ion: Total external ion heating power: auxiliary heating +
      Ohmic [W]
    P_external_el: Total external electron heating power: auxiliary heating +
      Ohmic [W]
    P_external_tot: Total external heating power: auxiliary heating + Ohmic
      [W]
    P_external_injected: Total external injected power before absorption [W]
    P_ei_exchange_ion: Electron-ion heat exchange power to ions [W]
    P_ei_exchange_el: Electron-ion heat exchange power to electrons [W]
    P_generic_ion: Total generic_ion_el_heat_source power to ions [W]
    P_generic_el: Total generic_ion_el_heat_source power to electrons [W]
    P_generic_tot: Total generic_ion_el_heat power [W]
    P_alpha_ion: Total fusion power to ions [W]
    P_alpha_el: Total fusion power to electrons [W]
    P_alpha_tot: Total fusion power to plasma [W]
    P_ohmic: Ohmic heating power to electrons [W]
    P_brems: Bremsstrahlung electron heat sink [W]
    P_cycl: Cyclotron radiation electron heat sink [W]
    P_ecrh: Total electron cyclotron source power [W]
    P_rad: Impurity radiation heat sink [W]
    I_ecrh: Total electron cyclotron source current [A]
    I_generic: Total generic source current [A]
    Q_fusion: Fusion power gain
    P_icrh_el: Ion cyclotron resonance heating to electrons [W]
    P_icrh_ion: Ion cyclotron resonance heating to ions [W]
    P_icrh_tot: Total ion cyclotron resonance heating power [W]
    P_LH_hi_dens: H-mode transition power for high density branch [W]
    P_LH_min: Minimum H-mode transition power, reached at ne_min_P_LH [W]
    P_LH: H-mode transition power from maximum of P_LH_hi_dens and P_LH_min
      [W]
    ne_min_P_LH: Density corresponding to the P_LH_min [nref]
    E_cumulative_fusion: Total cumulative fusion energy [J]
    E_cumulative_external: Total external injected energy (Ohmic + auxiliary
      heating) [J]
    te_volume_avg: Volume average electron temperature [keV]
    ti_volume_avg: Volume average ion temperature [keV]
    ne_volume_avg: Volume average electron density [nref m^-3]
    ni_volume_avg: Volume average main ion density [nref m^-3]
    ne_line_avg: Line averaged electron density [nref m^-3]
    ni_line_avg: Line averaged main ion density [nref m^-3]
    fgw_ne_volume_avg: Greenwald fraction from volume-averaged electron
      density [dimensionless]
    fgw_ne_line_avg: Greenwald fraction from line-averaged electron density
      [dimensionless]
    q95: q at 95% of the normalized poloidal flux
    Wpol: Total magnetic energy [J]
    li3: Normalized plasma internal inductance, ITER convention
      [dimensionless]
    dW_th_dt: Time derivative of the total stored thermal energy [W]
  """

  pressure_thermal_ion_face: array_typing.ArrayFloat
  pressure_thermal_el_face: array_typing.ArrayFloat
  pressure_thermal_tot_face: array_typing.ArrayFloat
  pprime_face: array_typing.ArrayFloat
  # pylint: disable=invalid-name
  W_thermal_ion: array_typing.ScalarFloat
  W_thermal_el: array_typing.ScalarFloat
  W_thermal_tot: array_typing.ScalarFloat
  tauE: array_typing.ScalarFloat
  H89P: array_typing.ScalarFloat
  H98: array_typing.ScalarFloat
  H97L: array_typing.ScalarFloat
  H20: array_typing.ScalarFloat
  FFprime_face: array_typing.ArrayFloat
  psi_norm_face: array_typing.ArrayFloat
  # psi_face included in post_processed output for convenience, since the
  # CellVariable history method destroys class methods like `face_value`.
  psi_face: array_typing.ArrayFloat
  # Integrated heat sources
  P_sol_ion: array_typing.ScalarFloat  # SOL stands for "Scrape Off Layer"
  P_sol_el: array_typing.ScalarFloat
  P_sol_tot: array_typing.ScalarFloat
  P_external_ion: array_typing.ScalarFloat
  P_external_el: array_typing.ScalarFloat
  P_external_tot: array_typing.ScalarFloat
  P_external_injected: array_typing.ScalarFloat
  P_ei_exchange_ion: array_typing.ScalarFloat
  P_ei_exchange_el: array_typing.ScalarFloat
  P_generic_ion: array_typing.ScalarFloat
  P_generic_el: array_typing.ScalarFloat
  P_generic_tot: array_typing.ScalarFloat
  P_alpha_ion: array_typing.ScalarFloat
  P_alpha_el: array_typing.ScalarFloat
  P_alpha_tot: array_typing.ScalarFloat
  P_ohmic: array_typing.ScalarFloat
  P_brems: array_typing.ScalarFloat
  P_cycl: array_typing.ScalarFloat
  P_ecrh: array_typing.ScalarFloat
  P_rad: array_typing.ScalarFloat
  I_ecrh: array_typing.ScalarFloat
  I_generic: array_typing.ScalarFloat
  Q_fusion: array_typing.ScalarFloat
  P_icrh_el: array_typing.ScalarFloat
  P_icrh_ion: array_typing.ScalarFloat
  P_icrh_tot: array_typing.ScalarFloat
  P_LH_hi_dens: array_typing.ScalarFloat
  P_LH_min: array_typing.ScalarFloat
  P_LH: array_typing.ScalarFloat
  ne_min_P_LH: array_typing.ScalarFloat
  E_cumulative_fusion: array_typing.ScalarFloat
  E_cumulative_external: array_typing.ScalarFloat
  te_volume_avg: array_typing.ScalarFloat
  ti_volume_avg: array_typing.ScalarFloat
  ne_volume_avg: array_typing.ScalarFloat
  ni_volume_avg: array_typing.ScalarFloat
  ne_line_avg: array_typing.ScalarFloat
  ni_line_avg: array_typing.ScalarFloat
  fgw_ne_volume_avg: array_typing.ScalarFloat
  fgw_ne_line_avg: array_typing.ScalarFloat
  q95: array_typing.ScalarFloat
  Wpol: array_typing.ScalarFloat
  li3: array_typing.ScalarFloat
  dW_th_dt: array_typing.ScalarFloat
  # pylint: enable=invalid-name

  @classmethod
  def zeros(cls, geo: geometry.Geometry) -> typing_extensions.Self:
    """Returns a PostProcessedOutputs with all zeros, used for initializing.

    Args:
      geo: Geometry whose face grid (`geo.rho_face`) sets the shape of every
        profile-valued (`*_face`) field.

    Returns:
      An instance whose profile fields are zero arrays on the face grid and
      whose scalar fields are 0.0 in the configured JAX dtype.
    """
    face_shape = geo.rho_face.shape
    dtype = jax_utils.get_dtype()

    # Both helpers build a new array on every call, so (as before) no two
    # fields share an array object.
    def face():
      return jnp.zeros(face_shape)

    def scalar():
      return jnp.array(0.0, dtype=dtype)

    return cls(
        pressure_thermal_ion_face=face(),
        pressure_thermal_el_face=face(),
        pressure_thermal_tot_face=face(),
        pprime_face=face(),
        W_thermal_ion=scalar(),
        W_thermal_el=scalar(),
        W_thermal_tot=scalar(),
        tauE=scalar(),
        H89P=scalar(),
        H98=scalar(),
        H97L=scalar(),
        H20=scalar(),
        FFprime_face=face(),
        psi_norm_face=face(),
        psi_face=face(),
        P_sol_ion=scalar(),
        P_sol_el=scalar(),
        P_sol_tot=scalar(),
        P_external_ion=scalar(),
        P_external_el=scalar(),
        P_external_tot=scalar(),
        P_external_injected=scalar(),
        P_ei_exchange_ion=scalar(),
        P_ei_exchange_el=scalar(),
        P_generic_ion=scalar(),
        P_generic_el=scalar(),
        P_generic_tot=scalar(),
        P_alpha_ion=scalar(),
        P_alpha_el=scalar(),
        P_alpha_tot=scalar(),
        P_ohmic=scalar(),
        P_brems=scalar(),
        P_cycl=scalar(),
        P_ecrh=scalar(),
        P_rad=scalar(),
        I_ecrh=scalar(),
        I_generic=scalar(),
        Q_fusion=scalar(),
        P_icrh_ion=scalar(),
        P_icrh_el=scalar(),
        P_icrh_tot=scalar(),
        P_LH_hi_dens=scalar(),
        P_LH_min=scalar(),
        P_LH=scalar(),
        ne_min_P_LH=scalar(),
        E_cumulative_fusion=scalar(),
        E_cumulative_external=scalar(),
        te_volume_avg=scalar(),
        ti_volume_avg=scalar(),
        ne_volume_avg=scalar(),
        ni_volume_avg=scalar(),
        ne_line_avg=scalar(),
        ni_line_avg=scalar(),
        fgw_ne_volume_avg=scalar(),
        fgw_ne_line_avg=scalar(),
        q95=scalar(),
        Wpol=scalar(),
        li3=scalar(),
        dW_th_dt=scalar(),
    )

  def check_for_errors(self) -> "SimError":
    """Returns SimError.NAN_DETECTED if any field contains a NaN.

    Returns:
      SimError.NAN_DETECTED when `has_nan` finds a NaN in any leaf, otherwise
      SimError.NO_ERROR.
    """
    if has_nan(self):
      return SimError.NAN_DETECTED
    else:
      return SimError.NO_ERROR
@chex.dataclass
class StepperNumericOutputs:
  """Numerical diagnostics reported by the stepper for one step.

  Attributes:
    outer_stepper_iterations: How many outer-loop iterations the stepper ran.
    stepper_error_state: Convergence status of the solver for this step:
      * 0 -- the solver converged within the fine tolerance.
      * 1 -- the solver did not converge (it stayed above the coarse
        tolerance).
      * 2 -- the solver converged only within the coarse tolerance. The step
        is allowed through with a warning; an occasional 2 has little effect
        on the final sim state.
    inner_solver_iterations: Solver iterations summed over every iteration
      of the stepper.
  """

  # All counters start at zero; field order is part of the constructor API.
  outer_stepper_iterations: int = 0
  stepper_error_state: int = 0
  inner_solver_iterations: int = 0
@enum.unique
class SimError(enum.Enum):
  """Integer enum for sim error handling."""

  NO_ERROR = 0
  NAN_DETECTED = 1
  QUASINEUTRALITY_BROKEN = 2
  NEGATIVE_CORE_PROFILES = 3

  def log_error(self):
    """Logs, at ERROR level, why the simulation stopped for this error.

    NO_ERROR logs nothing.

    Raises:
      ValueError: If the member has no associated message.
    """
    if self is SimError.NO_ERROR:
      return
    messages = {
        SimError.NEGATIVE_CORE_PROFILES: (
            """ Simulation stopped due to negative values in core profiles. """
        ),
        SimError.NAN_DETECTED: (
            """ Simulation stopped due to NaNs in state. Output file contains all profiles up to the last valid step. """
        ),
        SimError.QUASINEUTRALITY_BROKEN: (
            """ Simulation stopped due to quasineutrality being violated. Possible cause is bad handling of impurity species. Output file contains all profiles up to the last valid step. """
        ),
    }
    message = messages.get(self)
    if message is None:
      raise ValueError(f"Unknown SimError: {self}")
    logging.error(message)
@chex.dataclass
class ToraxSimState:
  """Full simulator state.

  The simulation stepping in sim.py evolves core_profiles which includes all
  the attributes the simulation is advancing. But beyond those, there are
  additional stateful elements which may evolve on each simulation step, such
  as sources and transport.

  This class includes both core_profiles and these additional elements.

  Attributes:
    t: time coordinate.
    dt: timestep interval.
    core_profiles: Core plasma profiles at time t.
    core_transport: Core plasma transport coefficients computed at time t.
    core_sources: Profiles for all sources/sinks. These are the profiles that
      are used to calculate the coefficients for the t+dt time step. For the
      explicit sources, these are calculated at the start of the time step, so
      are the values at time t. For the implicit sources, these are the most
      recent guess for time t+dt. The profiles here are the merged version of
      the explicit and implicit profiles.
    geometry: Geometry at this time step used for the simulation.
    stepper_numeric_outputs: Numerical quantities related to the stepper.
    sawtooth_crash: True if a sawtooth model is active and the state
      corresponds to a post-sawtooth-crash state.
  """

  t: jax.Array
  dt: jax.Array
  core_profiles: CoreProfiles
  core_transport: CoreTransport
  core_sources: source_profiles.SourceProfiles
  geometry: geometry.Geometry
  stepper_numeric_outputs: StepperNumericOutputs
  sawtooth_crash: bool = False

  def check_for_errors(self) -> SimError:
    """Checks for errors in the simulation state.

    Checks run in priority order, and the first failure is returned:
    negative core profiles, then NaNs anywhere in the state, then a
    quasineutrality violation.

    Returns:
      The matching SimError, or SimError.NO_ERROR if every check passes.
    """
    profiles = self.core_profiles
    if profiles.negative_temperature_or_density():
      logging.info("%s", profiles)
      log_negative_profile_names(profiles)
      return SimError.NEGATIVE_CORE_PROFILES
    # If there are NaNs that occurred without negative core profiles, log this
    # as a separate error.
    if has_nan(self):
      logging.info("%s", profiles)
      return SimError.NAN_DETECTED
    if not profiles.quasineutrality_satisfied():
      return SimError.QUASINEUTRALITY_BROKEN
    return SimError.NO_ERROR
def has_nan(inputs: ToraxSimState | PostProcessedOutputs) -> bool:
  """Returns True if any leaf of the `inputs` pytree contains a NaN.

  Args:
    inputs: A sim state or post-processed outputs; every pytree leaf is
      checked.
  """
  # A generator (not a list) lets `any` stop at the first leaf with a NaN.
  return any(jnp.any(jnp.isnan(x)) for x in jax.tree.leaves(inputs))


def log_negative_profile_names(inputs: CoreProfiles):
  """Logs, at INFO level, the pytree path of every leaf with a negative entry.

  Args:
    inputs: Core profiles to scan.
  """
  path_vals, _ = jax.tree.flatten_with_path(inputs)
  for path, value in path_vals:
    if jnp.any(jnp.less(value, 0.0)):
      logging.info("Found negative value in %s", jax.tree_util.keystr(path))
def check_for_errors(
    sim_state: ToraxSimState,
    post_processed_outputs: PostProcessedOutputs,
) -> SimError:
  """Checks for errors in the simulation state.

  The sim state is checked first. Only when it reports SimError.NO_ERROR is
  the result of the post-processed outputs' own check returned.

  Args:
    sim_state: State of the simulation at the current step.
    post_processed_outputs: Outputs computed after the current step.

  Returns:
    The first SimError found, or SimError.NO_ERROR.
  """
  state_error = sim_state.check_for_errors()
  if state_error == SimError.NO_ERROR:
    return post_processed_outputs.check_for_errors()
  return state_error