Source code for torax.geometry.geometry_provider

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GeometryProvider interface and implementations.

NOTE: Time dependent providers currently live in `geometry.py` and match the
protocol defined here.
"""
from collections.abc import Mapping
import dataclasses
from typing import Protocol, Type

import chex
import numpy as np
from torax import interpolated_param
from torax.geometry import geometry
from torax.torax_pydantic import torax_pydantic
import typing_extensions

# Using invalid-name because we are using the same naming convention as the
# external physics implementations
# pylint: disable=invalid-name


[docs] class GeometryProvider(Protocol): """Returns the geometry to use during one time step of the simulation. A GeometryProvider is any callable (class or function) which takes the time of a time step and returns the Geometry for that time step. See `SimulationStepFn` for how this callable is used. This class is a typing.Protocol, meaning it defines an interface, but any function asking for a GeometryProvider as an argument can accept any function or class that implements this API without specifically extending this class. For instance, the following is an equivalent implementation of the ConstantGeometryProvider without actually creating a class, and equally valid. .. code-block:: python geo = circular_geometry.build_circular_geometry(...) constant_geo_provider = lamdba t: geo def func_expecting_geo_provider(gp: GeometryProvider): ... # do something with the provider. func_expecting_geo_provider(constant_geo_provider) # this works. NOTE: In order to maintain consistency between the DynamicRuntimeParamsSlice and the geometry, `sim.get_consistent_dynamic_runtime_params_slice_and_geometry` should be used to get a Geometry and a corresponding DynamicRuntimeParamsSlice. """ def __call__( self, t: chex.Numeric, ) -> geometry.Geometry: """Returns the geometry to use during one time step of the simulation. The geometry may change from time step to time step, so the sim needs a callable to provide which geometry to use for a given time step (this is that callable). Args: t: The time at which the geometry is being requested. Returns: Geometry of the torus to use for the time step. """ @property def torax_mesh(self) -> torax_pydantic.Grid1D: """Returns the mesh used by Torax, this is consistent across time."""
[docs] class ConstantGeometryProvider(GeometryProvider): """Returns the same Geometry for all calls.""" def __init__(self, geo: geometry.Geometry): self._geo = geo def __call__(self, t: chex.Numeric) -> geometry.Geometry: # The API includes time as an arg even though it is unused in order # to match the API of a GeometryProvider. del t # Ignored. return self._geo @property def torax_mesh(self) -> torax_pydantic.Grid1D: return self._geo.torax_mesh
[docs] @chex.dataclass(frozen=True) class TimeDependentGeometryProvider: """A geometry provider which holds values to interpolate based on time.""" geometry_type: geometry.GeometryType torax_mesh: torax_pydantic.Grid1D drho_norm: interpolated_param.InterpolatedVarSingleAxis Phi: interpolated_param.InterpolatedVarSingleAxis Phi_face: interpolated_param.InterpolatedVarSingleAxis Rmaj: interpolated_param.InterpolatedVarSingleAxis Rmin: interpolated_param.InterpolatedVarSingleAxis B0: interpolated_param.InterpolatedVarSingleAxis volume: interpolated_param.InterpolatedVarSingleAxis volume_face: interpolated_param.InterpolatedVarSingleAxis area: interpolated_param.InterpolatedVarSingleAxis area_face: interpolated_param.InterpolatedVarSingleAxis vpr: interpolated_param.InterpolatedVarSingleAxis vpr_face: interpolated_param.InterpolatedVarSingleAxis spr: interpolated_param.InterpolatedVarSingleAxis spr_face: interpolated_param.InterpolatedVarSingleAxis delta_face: interpolated_param.InterpolatedVarSingleAxis elongation: interpolated_param.InterpolatedVarSingleAxis elongation_face: interpolated_param.InterpolatedVarSingleAxis g0: interpolated_param.InterpolatedVarSingleAxis g0_face: interpolated_param.InterpolatedVarSingleAxis g1: interpolated_param.InterpolatedVarSingleAxis g1_face: interpolated_param.InterpolatedVarSingleAxis g2: interpolated_param.InterpolatedVarSingleAxis g2_face: interpolated_param.InterpolatedVarSingleAxis g3: interpolated_param.InterpolatedVarSingleAxis g3_face: interpolated_param.InterpolatedVarSingleAxis g2g3_over_rhon: interpolated_param.InterpolatedVarSingleAxis g2g3_over_rhon_face: interpolated_param.InterpolatedVarSingleAxis g2g3_over_rhon_hires: interpolated_param.InterpolatedVarSingleAxis F: interpolated_param.InterpolatedVarSingleAxis F_face: interpolated_param.InterpolatedVarSingleAxis F_hires: interpolated_param.InterpolatedVarSingleAxis Rin: interpolated_param.InterpolatedVarSingleAxis Rin_face: interpolated_param.InterpolatedVarSingleAxis Rout: interpolated_param.InterpolatedVarSingleAxis Rout_face: interpolated_param.InterpolatedVarSingleAxis spr_hires: interpolated_param.InterpolatedVarSingleAxis rho_hires_norm: interpolated_param.InterpolatedVarSingleAxis rho_hires: interpolated_param.InterpolatedVarSingleAxis _z_magnetic_axis: interpolated_param.InterpolatedVarSingleAxis | None
[docs] @classmethod def create_provider( cls, geometries: Mapping[float, geometry.Geometry] ) -> typing_extensions.Self: """Creates a GeometryProvider from a mapping of times to geometries.""" # Create a list of times and geometries. times = np.asarray(list(geometries.keys())) geos = list(geometries.values()) initial_geometry = geos[0] for geo in geos: if geo.geometry_type != initial_geometry.geometry_type: raise ValueError('All geometries must have the same geometry type.') if geo.torax_mesh != initial_geometry.torax_mesh: raise ValueError('All geometries must have the same mesh.') # Create a list of interpolated parameters for each geometry attribute. kwargs = { 'geometry_type': initial_geometry.geometry_type, 'torax_mesh': initial_geometry.torax_mesh, } if hasattr(initial_geometry, 'Ip_from_parameters'): kwargs['Ip_from_parameters'] = initial_geometry.Ip_from_parameters for attr in dataclasses.fields(cls): if ( attr.name == 'geometry_type' or attr.name == 'torax_mesh' or attr.name == 'Ip_from_parameters' ): continue if attr.name == '_z_magnetic_axis': if initial_geometry._z_magnetic_axis is None: # pylint: disable=protected-access kwargs[attr.name] = None continue kwargs[attr.name] = interpolated_param.InterpolatedVarSingleAxis( (times, np.stack([getattr(g, attr.name) for g in geos], axis=0)) ) return cls(**kwargs)
def _get_geometry_base( self, t: chex.Numeric, geometry_class: Type[geometry.Geometry] ): """Returns a Geometry instance of the given type at the given time.""" kwargs = { 'geometry_type': self.geometry_type, 'torax_mesh': self.torax_mesh, } if hasattr(self, 'Ip_from_parameters'): kwargs['Ip_from_parameters'] = self.Ip_from_parameters for attr in dataclasses.fields(geometry_class): if ( attr.name == 'geometry_type' or attr.name == 'torax_mesh' or attr.name == 'Ip_from_parameters' ): continue # always initialize Phibdot as zero. It will be replaced once both geo_t # and geo_t_plus_dt are provided, and set to be the same for geo_t and # geo_t_plus_dt for each given time interval. if attr.name == 'Phibdot': kwargs[attr.name] = 0.0 continue if attr.name == '_z_magnetic_axis': if self._z_magnetic_axis is None: kwargs[attr.name] = None continue kwargs[attr.name] = getattr(self, attr.name).get_value(t) return geometry_class(**kwargs) # pytype: disable=wrong-keyword-args def __call__(self, t: chex.Numeric) -> geometry.Geometry: """Returns a Geometry instance at the given time.""" return self._get_geometry_base(t, geometry.Geometry)