# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for representing the problem geometry."""
from collections.abc import Sequence
import dataclasses
import enum
import chex
import jax
import jax.numpy as jnp
import numpy as np
from torax.torax_pydantic import torax_pydantic
[docs]
def face_to_cell(face: chex.Array) -> chex.Array:
  """Estimates cell values from a vector of face values.

  Each cell value is the midpoint (linear interpolation) of the two face
  values that bound it, so the result has one element fewer than the input.

  Args:
    face: An array containing face values.

  Returns:
    cell: An array containing cell values.
  """
  lower_faces = face[:-1]
  upper_faces = face[1:]
  return 0.5 * (lower_faces + upper_faces)
[docs]
@enum.unique
class GeometryType(enum.IntEnum):
  """Integer identifiers for the available geometry types.

  Since the members are integers, the geometry type can be inspected from
  within JAX expressions without resorting to isinstance checks.
  """

  CIRCULAR = 0
  CHEASE = 1
  FBT = 2
  EQDSK = 3
# pylint: disable=invalid-name
[docs]
@chex.dataclass(frozen=True)
class Geometry:
  r"""Container for the magnetic geometry of the plasma.

  StandardGeometry is the recommended default for most users, regardless of
  whether the geometry originates from CHEASE, MEQ, EQDSK, etc.

  The properties support both 1D radial arrays and 2D stacked arrays whose
  leading dimension is time.

  Attributes:
    geometry_type: Geometry model in use. See `GeometryType` for options.
    torax_mesh: `Grid1D` object holding the radial mesh used by TORAX.
    Phi: Toroidal magnetic flux at each radial grid point
      [:math:`\mathrm{Wb}`].
    Phi_face: Toroidal magnetic flux at each radial face
      [:math:`\mathrm{Wb}`].
    Rmaj: Tokamak major radius (geometric center) [:math:`\mathrm{m}`].
    Rmin: Tokamak minor radius [:math:`\mathrm{m}`].
    B0: Vacuum toroidal magnetic field on axis [:math:`\mathrm{T}`].
    volume: Plasma volume enclosed by each flux surface, on the cell grid
      [:math:`\mathrm{m}^3`].
    volume_face: Plasma volume enclosed by each flux surface, on the face grid
      [:math:`\mathrm{m}^3`].
    area: Poloidal cross-sectional area of each flux surface, on the cell grid
      [:math:`\mathrm{m}^2`].
    area_face: Poloidal cross-sectional area of each flux surface, on the face
      grid [:math:`\mathrm{m}^2`].
    vpr: Derivative of the plasma volume enclosed by each flux surface with
      respect to the normalized toroidal flux coordinate rho_norm, on the cell
      grid [:math:`\mathrm{m}^3`].
    vpr_face: Derivative of the plasma volume enclosed by each flux surface
      with respect to the normalized toroidal flux coordinate rho_face_norm,
      on the face grid [:math:`\mathrm{m}^3`].
    spr: Derivative of the plasma surface area enclosed by each flux surface
      with respect to the normalized toroidal flux coordinate rho_norm, on the
      cell grid [:math:`\mathrm{m}^2`]. Equal to vpr / (:math:`2 \pi` Rmaj).
    spr_face: Derivative of the plasma surface area enclosed by each flux
      surface with respect to the normalized toroidal flux coordinate
      rho_face_norm, on the face grid [:math:`\mathrm{m}^2`]. Equal to
      vpr_face / (:math:`2 \pi` Rmaj).
    spr_hires: Derivative of the plasma surface area enclosed by each flux
      surface with respect to the normalized toroidal flux coordinate
      rho_norm, on a higher resolution grid [:math:`\mathrm{m}^2`].
    rho_hires: Toroidal flux coordinate on a higher resolution grid
      [:math:`\mathrm{m}`].
    rho_hires_norm: Normalized toroidal flux coordinate on a higher resolution
      grid [dimensionless].
    g0: Flux surface averaged radial derivative of the plasma volume,
      :math:`\langle \nabla V \rangle`, on the cell grid
      [:math:`\mathrm{m}^2`].
    g0_face: Flux surface averaged :math:`\langle \nabla V \rangle` on the
      faces [:math:`\mathrm{m}^2`].
    g1: Flux surface averaged :math:`\langle (\nabla V)^2 \rangle` on the cell
      grid [:math:`\mathrm{m}^4`].
    g1_face: Flux surface averaged :math:`\langle (\nabla V)^2 \rangle` on the
      faces [:math:`\mathrm{m}^4`].
    g2: Flux surface averaged :math:`\langle (\nabla V)^2 / R^2 \rangle` on
      the cell grid [:math:`\mathrm{m}^2`], with R the major radius along the
      flux surface being averaged.
    g2_face: Flux surface averaged :math:`\langle (\nabla V)^2 / R^2 \rangle`
      on the faces [:math:`\mathrm{m}^2`].
    g3: Flux surface averaged :math:`\langle 1 / R^2 \rangle` on the cell grid
      [:math:`\mathrm{m}^{-2}`].
    g3_face: Flux surface averaged :math:`\langle 1 / R^2 \rangle` on the
      faces [:math:`\mathrm{m}^{-2}`].
    g2g3_over_rhon: Ratio of g2g3 to the normalized toroidal flux coordinate
      rho_norm, on the cell grid [dimensionless].
    g2g3_over_rhon_face: Ratio of g2g3 to the normalized toroidal flux
      coordinate rho_norm, on the face grid [dimensionless].
    g2g3_over_rhon_hires: Ratio of g2g3 to the normalized toroidal flux
      coordinate rho_norm, on a higher resolution grid [dimensionless].
    F: Toroidal field flux function, :math:`F \equiv RB_\phi`, on the cell
      grid, where :math:`R` is the major radius and :math:`B_\phi` is the
      toroidal magnetic field [:math:`\mathrm{T m}`].
    F_face: Toroidal field flux function, :math:`F \equiv RB_\phi`, on the
      face grid [:math:`\mathrm{T m}`].
    F_hires: Toroidal field flux function, :math:`F \equiv RB_\phi`, on the
      high resolution grid [:math:`\mathrm{T m}`].
    Rin: Radius of the flux surface at the inboard side at midplane
      [:math:`\mathrm{m}`], on the cell grid. The inboard side is the minimum
      radial extent of the flux surface.
    Rin_face: Radius of the flux surface at the inboard side at midplane
      [:math:`\mathrm{m}`], on the face grid.
    Rout: Radius of the flux surface at the outboard side at midplane
      [:math:`\mathrm{m}`], on the cell grid. The outboard side is the maximum
      radial extent of the flux surface.
    Rout_face: Radius of the flux surface at the outboard side at midplane
      [:math:`\mathrm{m}`], on the face grid.
    delta_face: Mean of the upper and lower triangularity of each flux surface
      at the faces [dimensionless]. Upper triangularity is (Rmaj_local -
      R_upper) / Rmin_local, with Rmaj_local = (Rout+Rin)/2, Rmin_local =
      (Rout-Rin)/2, and R_upper the radial location of the upper extent of the
      flux surface. Lower triangularity is (Rmaj_local - R_lower) /
      Rmin_local, with R_lower the radial location of the lower extent of the
      flux surface.
    elongation: Plasma elongation profile on the cell grid [dimensionless].
      Elongation is (Z_upper - Z_lower) / (2.0 * Rmin_local), with Z_upper and
      Z_lower the Z coordinates of the upper and lower extent of the flux
      surface.
    elongation_face: Plasma elongation profile on the face grid
      [dimensionless].
    Phibdot: Time derivative of the toroidal magnetic flux
      [:math:`\mathrm{Wb/s}`]. Computed over a time interval from the ``Phi``
      of the Geometry objects at time t and t + dt. See
      ``torax.orchestration.step_function`` for more details.
    _z_magnetic_axis: Vertical position of the magnetic axis
      [:math:`\mathrm{m}`].
  """

  geometry_type: GeometryType
  torax_mesh: torax_pydantic.Grid1D
  Phi: chex.Array
  Phi_face: chex.Array
  Rmaj: chex.Array
  Rmin: chex.Array
  B0: chex.Array
  volume: chex.Array
  volume_face: chex.Array
  area: chex.Array
  area_face: chex.Array
  vpr: chex.Array
  vpr_face: chex.Array
  spr: chex.Array
  spr_face: chex.Array
  delta_face: chex.Array
  elongation: chex.Array
  elongation_face: chex.Array
  g0: chex.Array
  g0_face: chex.Array
  g1: chex.Array
  g1_face: chex.Array
  g2: chex.Array
  g2_face: chex.Array
  g3: chex.Array
  g3_face: chex.Array
  g2g3_over_rhon: chex.Array
  g2g3_over_rhon_face: chex.Array
  g2g3_over_rhon_hires: chex.Array
  F: chex.Array
  F_face: chex.Array
  F_hires: chex.Array
  Rin: chex.Array
  Rin_face: chex.Array
  Rout: chex.Array
  Rout_face: chex.Array
  spr_hires: chex.Array
  rho_hires_norm: chex.Array
  rho_hires: chex.Array
  Phibdot: chex.Array
  _z_magnetic_axis: chex.Array | None

  @property
  def q_correction_factor(self) -> chex.Numeric:
    """Ad-hoc correction for the non-physical circular geometry model.

    Chosen such that q(r=a) = 3 for standard ITER parameters. Evaluates to
    1.25 for circular geometry and to 1 for every other geometry type.
    """
    is_circular = self.geometry_type == GeometryType.CIRCULAR.value
    return jnp.where(is_circular, 1.25, 1)

  @property
  def rho_norm(self) -> chex.Array:
    r"""Normalized toroidal flux coordinate on cell grid [dimensionless]."""
    mesh = self.torax_mesh
    return mesh.cell_centers

  @property
  def rho_face_norm(self) -> chex.Array:
    r"""Normalized toroidal flux coordinate on face grid [dimensionless]."""
    mesh = self.torax_mesh
    return mesh.face_centers

  @property
  def drho_norm(self) -> chex.Array:
    r"""Grid size for rho_norm [dimensionless]."""
    mesh_spacing = self.torax_mesh.dx
    return jnp.array(mesh_spacing)

  @property
  def rho_face(self) -> chex.Array:
    r"""Toroidal flux coordinate on face grid :math:`\mathrm{m}`."""
    # A trailing axis is added to rho_b so that it broadcasts against the
    # radial axis in both the 1D and the stacked (time-leading) 2D case.
    rho_b_column = jnp.expand_dims(self.rho_b, axis=-1)
    return self.rho_face_norm * rho_b_column

  @property
  def rho(self) -> chex.Array:
    r"""Toroidal flux coordinate on cell grid :math:`\mathrm{m}`.

    Defined as :math:`\rho=\sqrt{\frac{\Phi}{\pi B_0}}`, with :math:`\Phi` the
    toroidal flux enclosed by the flux surface and :math:`B_0` the magnetic
    field on the magnetic axis.
    """
    # Same broadcasting trick as in rho_face.
    rho_b_column = jnp.expand_dims(self.rho_b, axis=-1)
    return self.rho_norm * rho_b_column

  @property
  def rmid(self) -> chex.Array:
    """Midplane radius of the plasma [m], defined as (Rout-Rin)/2."""
    midplane_width = self.Rout - self.Rin
    return midplane_width / 2

  @property
  def rmid_face(self) -> chex.Array:
    """Midplane radius of the plasma on the face grid [m]."""
    midplane_width_face = self.Rout_face - self.Rin_face
    return midplane_width_face / 2

  @property
  def drho(self) -> chex.Array:
    """Grid size for rho [m]."""
    boundary_rho = self.rho_b
    return self.drho_norm * boundary_rho

  @property
  def rho_b(self) -> chex.Array:
    """Toroidal flux coordinate [m] at boundary (LCFS)."""
    flux_over_pi = self.Phib / np.pi
    return jnp.sqrt(flux_over_pi / self.B0)

  @property
  def Phib(self) -> chex.Array:
    r"""Toroidal flux at boundary (LCFS) :math:`\mathrm{Wb}`."""
    # The boundary is the last entry along the radial (trailing) axis.
    face_flux = self.Phi_face
    return face_flux[..., -1]

  @property
  def g1_over_vpr(self) -> chex.Array:
    r"""g1/vpr [:math:`\mathrm{m}`]."""
    numerator, denominator = self.g1, self.vpr
    return numerator / denominator

  @property
  def g1_over_vpr2(self) -> chex.Array:
    r"""g1/vpr**2 [:math:`\mathrm{m}^{-2}`]."""
    vpr_squared = self.vpr**2
    return self.g1 / vpr_squared

  @property
  def g0_over_vpr_face(self) -> jax.Array:
    """g0_face/vpr_face [:math:`m^{-1}`], equal to 1/rho_b on-axis."""
    # The first (on-axis) face is excluded from the element-wise ratio to
    # avoid a division by zero; its value is supplied separately as 1/rho_b.
    axis_value = jnp.ones_like(self.rho_b) / self.rho_b
    off_axis = self.g0_face[..., 1:] / self.vpr_face[..., 1:]
    # Joining along the last axis covers both the 1D (no leading dimension)
    # and the 2D cases.
    pieces = [jnp.expand_dims(axis_value, axis=-1), off_axis]
    return jnp.concatenate(pieces, axis=-1)

  @property
  def g1_over_vpr_face(self) -> jax.Array:
    r"""g1_face/vpr_face [:math:`\mathrm{m}`]. Zero on-axis."""
    # On-axis entry is set to zero; the ratio is only evaluated off-axis.
    axis_value = jnp.zeros_like(self.rho_b)
    off_axis = self.g1_face[..., 1:] / self.vpr_face[..., 1:]
    pieces = [jnp.expand_dims(axis_value, axis=-1), off_axis]
    return jnp.concatenate(pieces, axis=-1)

  @property
  def g1_over_vpr2_face(self) -> jax.Array:
    """g1_face/vpr_face**2 [:math:`m^{-2}`], equal to 1/rho_b**2 on-axis."""
    # On-axis entry is set to 1/rho_b**2; the ratio is only evaluated off-axis.
    axis_value = jnp.ones_like(self.rho_b) / self.rho_b**2
    off_axis = self.g1_face[..., 1:] / self.vpr_face[..., 1:] ** 2
    pieces = [jnp.expand_dims(axis_value, axis=-1), off_axis]
    return jnp.concatenate(pieces, axis=-1)

  def z_magnetic_axis(self) -> chex.Numeric:
    """z position of magnetic axis [m].

    Returns:
      The stored vertical position of the magnetic axis.

    Raises:
      ValueError: If no z magnetic axis was provided for this geometry.
    """
    if self._z_magnetic_axis is None:
      raise ValueError('Geometry does not have a z magnetic axis.')
    return self._z_magnetic_axis
[docs]
def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
  """Combines a sequence of geometries into one batched geometry.

  Args:
    geometries: The geometries to stack. They must all share the same mesh and
      the same geometry type.

  Returns:
    A Geometry object in which every array attribute carries an extra leading
    axis (e.g. for the time dimension) relative to the input geometries.
    Non-array attributes are taken from the first geometry. The result has the
    same class as the first geometry, which may be a child class of Geometry.

  Raises:
    ValueError: If the sequence is empty, or if the geometries differ in mesh
      or geometry type.
  """
  if not geometries:
    raise ValueError('No geometries provided.')

  # The first geometry is the reference every other entry is checked against.
  reference = geometries[0]
  for other in geometries[1:]:
    if other.torax_mesh != reference.torax_mesh:
      raise ValueError('All geometries must have the same mesh.')
    if other.geometry_type != reference.geometry_type:
      raise ValueError('All geometries must have the same geometry type.')

  def _stack_or_keep(name):
    """Stacks array-valued fields; otherwise keeps the reference's value."""
    reference_value = getattr(reference, name)
    if not isinstance(reference_value, chex.Array):
      return reference_value
    return jnp.stack([getattr(geo, name) for geo in geometries])

  stacked_fields = {
      field.name: _stack_or_keep(field.name)
      for field in dataclasses.fields(reference)
  }
  # Instantiate the reference's own class so that subclasses are preserved.
  return type(reference)(**stacked_fields)
[docs]
def update_geometries_with_Phibdot(
    *,
    dt: chex.Numeric,
    geo_t: Geometry,
    geo_t_plus_dt: Geometry,
) -> tuple[Geometry, Geometry]:
  """Sets a common Phibdot on both geometries bounding a time interval.

  Phibdot is used in calc_coeffs to calculate terms related to time-dependent
  geometry. Within a given time interval it must be identical for geo_t and
  geo_t_plus_dt. As a consequence, geo_t_plus_dt.Phibdot is not necessarily
  equal to the geo_t.Phibdot of the following time step.

  Args:
    dt: Time step duration.
    geo_t: The geometry of the torus during this time step of the simulation.
    geo_t_plus_dt: The geometry of the torus during the next time step of the
      simulation.

  Returns:
    Tuple containing:
      - A copy of geo_t with Phibdot replaced.
      - A copy of geo_t_plus_dt with Phibdot replaced.
  """
  # Finite-difference rate of change of the boundary toroidal flux.
  boundary_flux_change = geo_t_plus_dt.Phib - geo_t.Phib
  Phibdot = boundary_flux_change / dt
  return (
      dataclasses.replace(geo_t, Phibdot=Phibdot),
      dataclasses.replace(geo_t_plus_dt, Phibdot=Phibdot),
  )