# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing functions for saving and loading simulation output."""
import dataclasses
import inspect
from absl import logging
import chex
import jax
import numpy as np
from torax import state
from torax.geometry import geometry as geometry_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles
from torax.torax_pydantic import file_restart as file_restart_pydantic_model
from torax.torax_pydantic import model_config
import xarray as xr
import os
# Core profiles.
CORE_PROFILES = "core_profiles"
TEMP_EL = "temp_el"
TEMP_EL_RIGHT_BC = "temp_el_right_bc"
TEMP_ION = "temp_ion"
TEMP_ION_RIGHT_BC = "temp_ion_right_bc"
PSI = "psi"
PSIDOT = "psidot"
PSI_RIGHT_GRAD_BC = "psi_right_grad_bc"
PSI_RIGHT_BC = "psi_right_bc"
NE = "ne"
NE_RIGHT_BC = "ne_right_bc"
NI = "ni"
NI_RIGHT_BC = "ni_right_bc"
JTOT = "jtot"
JTOT_FACE = "jtot_face"
JOHM = "johm"
EXTERNAL_CURRENT = "external_current_source"
J_BOOTSTRAP = "j_bootstrap"
J_BOOTSTRAP_FACE = "j_bootstrap_face"
I_BOOTSTRAP = "I_bootstrap"
SIGMA = "sigma"
Q_FACE = "q_face"
S_FACE = "s_face"
NREF = "nref"
ZIMP = "Zimp"
NIMP = "nimp"
IP_PROFILE_FACE = "Ip_profile_face"
IP_TOTAL = "Ip_total"
VLOOP_LCFS = "vloop_lcfs"
# Core transport.
CORE_TRANSPORT = "core_transport"
CHI_FACE_ION = "chi_face_ion"
CHI_FACE_EL = "chi_face_el"
D_FACE_EL = "d_face_el"
V_FACE_EL = "v_face_el"
# Geometry.
GEOMETRY = "geometry"
# Coordinates.
RHO_FACE_NORM = "rho_face_norm"
RHO_CELL_NORM = "rho_cell_norm"
RHO_FACE = "rho_face"
RHO_CELL = "rho_cell"
TIME = "time"
# Post processed outputs
POST_PROCESSED_OUTPUTS = "post_processed_outputs"
Q_FUSION = "Q_fusion"
# Simulation error state.
SIM_ERROR = "sim_error"
# Sources.
CORE_SOURCES = "core_sources"
# Boolean array indicating whether the state corresponds to a
# post-sawtooth-crash state.
SAWTOOTH_CRASH = "sawtooth_crash"
# ToraxConfig.
CONFIG = "config"
# Excluded coordinates from geometry since they are at the top DataTree level.
# Exclude q_correction_factor as it is not an interesting quantity to save.
# TODO(b/338033916): consolidate on either rho or rho_cell naming for cell grid
EXCLUDED_GEOMETRY_NAMES = frozenset({
RHO_FACE,
RHO_CELL,
RHO_CELL_NORM,
RHO_FACE_NORM,
"rho",
"rho_norm",
"q_correction_factor",
})
def safe_load_dataset(filepath: str) -> xr.DataTree:
with open(filepath, "rb") as f:
with xr.open_datatree(f) as dt_open:
data_tree = dt_open.compute()
return data_tree
[docs]
def load_state_file(
filepath: str,
) -> xr.DataTree:
"""Loads a state file from a filepath."""
if os.path.exists(filepath):
data_tree = safe_load_dataset(filepath)
logging.info("Loading state file %s", filepath)
return data_tree
else:
raise ValueError(f"File {filepath} does not exist.")
[docs]
def concat_datatrees(
tree1: xr.DataTree,
tree2: xr.DataTree,
) -> xr.DataTree:
"""Concats two xr.DataTrees along the time dimension.
For any duplicate time steps, the values from the first dataset are kept.
Args:
tree1: The first xr.DataTree to concatenate.
tree2: The second xr.DataTree to concatenate.
Returns:
A xr.DataTree containing the concatenation of the two datasets.
"""
def _concat_datasets(
previous_ds: xr.Dataset,
ds: xr.Dataset,
) -> xr.Dataset:
"""Concats two xr.Datasets."""
# Do a minimal concat to avoid concatting any non time indexed vars.
ds = xr.concat([previous_ds, ds], dim=TIME, data_vars="minimal")
# Drop any duplicate time steps. Using "first" imposes
# keeping the restart state from the earlier dataset. In the case of TORAX
# restarts this contains more complete information e.g. transport and post
# processed outputs.
ds = ds.drop_duplicates(dim=TIME, keep="first")
return ds
return xr.map_over_datasets(_concat_datasets, tree1, tree2)
[docs]
def stitch_state_files(
file_restart: file_restart_pydantic_model.FileRestart, datatree: xr.DataTree
) -> xr.DataTree:
"""Stitch a datatree to the end of a previous state file.
Args:
file_restart: Contains information on a file this sim was restarted from.
datatree: The xr.DataTree to stitch to the end of the previous state file.
Returns:
A xr.DataTree containing the stitched dataset.
"""
previous_datatree = load_state_file(file_restart.filename)
# Reduce previous_ds to all times before the first time step in this
# sim output. We use ds.time[0] instead of file_restart.time because
# we are uncertain if file_restart.time is the exact time of the
# first time step in this sim output (it takes the nearest time).
previous_datatree = previous_datatree.sel(time=slice(None, datatree.time[0]))
return concat_datatrees(previous_datatree, datatree)
[docs]
class StateHistory:
"""A history of the state of the simulation and its error state."""
def __init__(
self,
state_history: tuple[state.ToraxSimState, ...],
post_processed_outputs_history: tuple[state.PostProcessedOutputs, ...],
sim_error: state.SimError,
source_models: source_models_lib.SourceModels,
torax_config: model_config.ToraxConfig,
):
core_profiles = [state.core_profiles for state in state_history]
core_sources = [state.core_sources for state in state_history]
transport = [state.core_transport for state in state_history]
geometries = [state.geometry for state in state_history]
self.geometry = geometry_lib.stack_geometries(geometries)
stack = lambda *ys: np.stack(ys)
self.core_profiles: state.CoreProfiles = jax.tree_util.tree_map(
stack, *core_profiles
)
self.core_sources: source_profiles.SourceProfiles = jax.tree_util.tree_map(
stack, *core_sources
)
self.core_transport: state.CoreTransport = jax.tree_util.tree_map(
stack, *transport
)
self.post_processed_outputs: state.PostProcessedOutputs = (
jax.tree_util.tree_map(stack, *post_processed_outputs_history)
)
self.times = np.array([state.t for state in state_history])
# The rho grid does not change in time so we can just take the first one.
self.rho_norm = state_history[0].geometry.rho_norm
self.rho_face_norm = state_history[0].geometry.rho_face_norm
chex.assert_rank(self.times, 1)
self.sim_error = sim_error
self.source_models = source_models
self.sawtooth_crash = np.array(
[state.sawtooth_crash for state in state_history]
)
self.torax_config = torax_config
def _pack_into_data_array(
self,
name: str,
data: jax.Array | None,
) -> xr.DataArray | None:
"""Packs the data into an xr.DataArray."""
if data is None:
return None
is_face_var = lambda x: x.ndim == 2 and x.shape == (
len(self.times),
len(self.rho_face_norm),
)
is_cell_var = lambda x: x.ndim == 2 and x.shape == (
len(self.times),
len(self.rho_norm),
)
is_scalar = lambda x: x.ndim == 1 and x.shape == (len(self.times),)
is_constant = lambda x: x.ndim == 0
match data:
case data if is_face_var(data):
dims = [TIME, RHO_FACE_NORM]
case data if is_cell_var(data):
dims = [TIME, RHO_CELL_NORM]
case data if is_scalar(data):
dims = [TIME]
case data if is_constant(data):
dims = []
case _:
logging.warning(
"Unsupported data shape for %s: %s. Skipping persisting.",
name,
data.shape, # pytype: disable=attribute-error
)
return None
return xr.DataArray(data, dims=dims, name=name)
def _get_core_profiles(
self,
) -> dict[str, xr.DataArray | None]:
"""Saves the core profiles to a dict."""
xr_dict = {}
xr_dict[TEMP_EL] = self.core_profiles.temp_el.value
xr_dict[TEMP_EL_RIGHT_BC] = self.core_profiles.temp_el.right_face_constraint
xr_dict[TEMP_ION] = self.core_profiles.temp_ion.value
xr_dict[TEMP_ION_RIGHT_BC] = (
self.core_profiles.temp_ion.right_face_constraint
)
xr_dict[PSI] = self.core_profiles.psi.value
xr_dict[PSI_RIGHT_GRAD_BC] = (
self.core_profiles.psi.right_face_grad_constraint
)
xr_dict[PSI_RIGHT_BC] = self.core_profiles.psi.right_face_constraint
xr_dict[PSIDOT] = self.core_profiles.psidot.value
xr_dict[NE] = self.core_profiles.ne.value
xr_dict[NE_RIGHT_BC] = self.core_profiles.ne.right_face_constraint
xr_dict[NI] = self.core_profiles.ni.value
xr_dict[NI_RIGHT_BC] = self.core_profiles.ni.right_face_constraint
xr_dict[ZIMP] = self.core_profiles.Zimp
xr_dict[NIMP] = self.core_profiles.nimp.value
# Currents.
xr_dict[JTOT] = self.core_profiles.currents.jtot
xr_dict[JTOT_FACE] = self.core_profiles.currents.jtot_face
xr_dict[JOHM] = self.core_profiles.currents.johm
xr_dict[EXTERNAL_CURRENT] = (
self.core_profiles.currents.external_current_source
)
xr_dict[J_BOOTSTRAP] = self.core_profiles.currents.j_bootstrap
xr_dict[J_BOOTSTRAP_FACE] = self.core_profiles.currents.j_bootstrap_face
xr_dict[IP_PROFILE_FACE] = self.core_profiles.currents.Ip_profile_face
xr_dict[IP_TOTAL] = self.core_profiles.currents.Ip_total
xr_dict[I_BOOTSTRAP] = self.core_profiles.currents.I_bootstrap
xr_dict[SIGMA] = self.core_profiles.currents.sigma
xr_dict[Q_FACE] = self.core_profiles.q_face
xr_dict[S_FACE] = self.core_profiles.s_face
xr_dict[NREF] = self.core_profiles.nref
xr_dict[VLOOP_LCFS] = self.core_profiles.vloop_lcfs
xr_dict = {
name: self._pack_into_data_array(
name,
data,
)
for name, data in xr_dict.items()
}
return xr_dict
def _save_core_transport(
self,
) -> dict[str, xr.DataArray | None]:
"""Saves the core transport to a dict."""
xr_dict = {}
xr_dict[CHI_FACE_ION] = self.core_transport.chi_face_ion
xr_dict[CHI_FACE_EL] = self.core_transport.chi_face_el
xr_dict[D_FACE_EL] = self.core_transport.d_face_el
xr_dict[V_FACE_EL] = self.core_transport.v_face_el
# Save optional BohmGyroBohm attributes if nonzero.
core_transport = self.core_transport
if (
np.any(core_transport.chi_face_el_bohm != 0)
or np.any(core_transport.chi_face_el_gyrobohm != 0)
or np.any(core_transport.chi_face_ion_bohm != 0)
or np.any(core_transport.chi_face_ion_gyrobohm != 0)
):
xr_dict["chi_face_el_bohm"] = core_transport.chi_face_el_bohm
xr_dict["chi_face_el_gyrobohm"] = core_transport.chi_face_el_gyrobohm
xr_dict["chi_face_ion_bohm"] = core_transport.chi_face_ion_bohm
xr_dict["chi_face_ion_gyrobohm"] = core_transport.chi_face_ion_gyrobohm
xr_dict = {
name: self._pack_into_data_array(
name,
data,
)
for name, data in xr_dict.items()
}
return xr_dict
def _save_core_sources(
self,
) -> dict[str, xr.DataArray | None]:
"""Saves the core sources to a dict."""
xr_dict = {}
xr_dict[self.source_models.qei_source_name] = (
self.core_sources.qei.qei_coef
* (self.core_profiles.temp_el.value - self.core_profiles.temp_ion.value)
)
# Add source profiles with suffixes indicating which profile they affect.
for profile in self.core_sources.temp_ion:
xr_dict[f"{profile}_ion"] = self.core_sources.temp_ion[profile]
for profile in self.core_sources.temp_el:
xr_dict[f"{profile}_el"] = self.core_sources.temp_el[profile]
for profile in self.core_sources.psi:
xr_dict[f"{profile}_j"] = self.core_sources.psi[profile]
for profile in self.core_sources.ne:
xr_dict[f"{profile}_ne"] = self.core_sources.ne[profile]
xr_dict = {
name: self._pack_into_data_array(name, data)
for name, data in xr_dict.items()
}
return xr_dict
def _save_post_processed_outputs(
self,
) -> dict[str, xr.DataArray | None]:
"""Saves the post processed outputs to a dict."""
xr_dict = {}
for field_name, data in dataclasses.asdict(
self.post_processed_outputs
).items():
xr_dict[field_name] = self._pack_into_data_array(field_name, data)
return xr_dict
def _save_geometry(
self,
) -> dict[str, xr.DataArray]:
"""Save geometry to a dict. We skip over hires and non-array quantities."""
xr_dict = {}
# Get the variables from dataclass fields.
for field_name, data in dataclasses.asdict(self.geometry).items():
if (
"hires" in field_name
or field_name == "geometry_type"
or field_name == "Ip_from_parameters"
or not isinstance(data, jax.Array)
):
continue
data_array = self._pack_into_data_array(
field_name,
data,
)
if data_array is not None:
xr_dict[field_name] = data_array
# Get variables from property methods
for name, value in inspect.getmembers(type(self.geometry)):
if name in EXCLUDED_GEOMETRY_NAMES:
continue
if isinstance(value, property):
property_data = value.fget(self.geometry)
data_array = self._pack_into_data_array(name, property_data)
if data_array is not None:
xr_dict[name] = data_array
return xr_dict
[docs]
def simulation_output_to_xr(
self,
file_restart: file_restart_pydantic_model.FileRestart | None = None,
) -> xr.DataTree:
"""Build an xr.DataTree of the simulation output.
Args:
file_restart: If provided, contains information on a file this sim was
restarted from, this is useful in case we want to stitch that to the
beggining of this sim output.
Returns:
A xr.DataTree containing a single top level xr.Dataset and four child
datasets. The top level dataset contains the following variables:
- time: The time of the simulation.
- rho_face_norm: The normalized toroidal coordinate on the face grid.
- rho_cell_norm: The normalized toroidal coordinate on the cell grid.
- rho_face: The toroidal coordinate on the face grid.
- rho_cell: The toroidal coordinate on the cell grid.
- sawtooth_crash: Time-series boolean indicating whether the
state corresponds to a post-sawtooth-crash state.
- sim_error: The simulation error state.
- config: The ToraxConfig used to run the simulation serialized to JSON.
The child datasets contain the following variables:
- core_profiles: Contains data variables for quantities in the
CoreProfiles.
- core_transport: Contains data variables for quantities in the
CoreTransport.
- core_sources: Contains data variables for quantities in the
CoreSources.
- post_processed_outputs: Contains data variables for quantities in the
PostProcessedOutputs.
- geometry: Contains data variables for quantities in the Geometry.
"""
# Cleanup structure by excluding QeiInfo from core_sources altogether.
# Add attribute to dataset variables with explanation of contents + units.
# Get coordinate variables for dimensions ("time", "rho_face", "rho_cell")
time = xr.DataArray(self.times, dims=[TIME], name=TIME)
rho_face_norm = xr.DataArray(
self.rho_face_norm, dims=[RHO_FACE_NORM], name=RHO_FACE_NORM
)
rho_cell_norm = xr.DataArray(
self.rho_norm, dims=[RHO_CELL_NORM], name=RHO_CELL_NORM
)
coords = {
TIME: time,
RHO_FACE_NORM: rho_face_norm,
RHO_CELL_NORM: rho_cell_norm,
}
# Update dict with flattened StateHistory dataclass containers
core_profiles_ds = xr.Dataset(self._get_core_profiles(), coords=coords)
core_transport_ds = xr.Dataset(self._save_core_transport(), coords=coords)
core_sources_ds = xr.Dataset(
self._save_core_sources(),
coords=coords,
)
post_processed_outputs_ds = xr.Dataset(
self._save_post_processed_outputs(), coords=coords
)
geometry_ds = xr.Dataset(self._save_geometry(), coords=coords)
top_level_xr_dict = {
SIM_ERROR: self.sim_error.value,
SAWTOOTH_CRASH: xr.DataArray(
self.sawtooth_crash, dims=[TIME], name=SAWTOOTH_CRASH
),
}
data_tree = xr.DataTree(
children={
CORE_PROFILES: xr.DataTree(dataset=core_profiles_ds),
CORE_TRANSPORT: xr.DataTree(dataset=core_transport_ds),
CORE_SOURCES: xr.DataTree(dataset=core_sources_ds),
POST_PROCESSED_OUTPUTS: xr.DataTree(
dataset=post_processed_outputs_ds
),
GEOMETRY: xr.DataTree(dataset=geometry_ds),
},
dataset=xr.Dataset(
top_level_xr_dict,
coords=coords,
attrs={CONFIG: self.torax_config.model_dump_json()},
),
)
if file_restart is not None and file_restart.stitch:
data_tree = stitch_state_files(file_restart, data_tree)
return data_tree