# Source code for torax.stepper.stepper

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The Stepper class.

Abstract base class defining updates to State.
"""

import abc

import jax
from torax import state
from torax.config import runtime_params_slice
from torax.core_profiles import updaters
from torax.fvm import cell_variable
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles
from torax.transport_model import transport_model as transport_model_lib


class Stepper(abc.ABC):
  """Base class that computes one time step's update to the core State.

  Attributes:
    transport_model: TransportModel subclass used to compute transport
      coefficients.
    source_models: The full set of TORAX sources. They are used to build both
      the explicit and the implicit source profiles that enter the state
      evolution equations as terms at every time step. The explicit profiles
      are computed outside the Stepper call, but must come from these same
      sources; keeping them here provides a single source of truth for which
      sources a run uses.
    pedestal_model: PedestalModel subclass used to compute pedestal values.
  """

  def __init__(
      self,
      transport_model: transport_model_lib.TransportModel,
      source_models: source_models_lib.SourceModels,
      pedestal_model: pedestal_model_lib.PedestalModel,
  ):
    self.transport_model = transport_model
    self.source_models = source_models
    self.pedestal_model = pedestal_model

  def __call__(
      self,
      dt: jax.Array,
      static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
      dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
      dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice,
      geo_t: geometry.Geometry,
      geo_t_plus_dt: geometry.Geometry,
      core_profiles_t: state.CoreProfiles,
      core_profiles_t_plus_dt: state.CoreProfiles,
      explicit_source_profiles: source_profiles.SourceProfiles,
  ) -> tuple[
      state.CoreProfiles,
      source_profiles.SourceProfiles,
      state.CoreTransport,
      state.StepperNumericOutputs,
  ]:
    """Advances the core profiles by a single time step.

    Args:
      dt: Duration of the time step.
      static_runtime_params_slice: Input parameters whose change triggers
        recompilation. They need not be JAX-friendly types and may drive
        Python control flow.
      dynamic_runtime_params_slice_t: Runtime parameters at time t, the start
        of the step. These may change from step to step without triggering a
        recompilation.
      dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time
        t + dt, used for the implicit calculations in the solver.
      geo_t: Geometry of the torus at time t.
      geo_t_plus_dt: Geometry of the torus at time t + dt.
      core_profiles_t: Core plasma profiles at the start of the time step.
      core_profiles_t_plus_dt: Core plasma profiles holding every quantity
        already known at the end of the time step: evolving boundary
        conditions and prescribed time-dependent profiles that the PDE system
        does not evolve.
      explicit_source_profiles: Profiles of all explicit sources, as
        configured by the input params; every implicit source's profile is
        zero in this object. They were computed either from the core profiles
        at the start of the step or independently of the core profiles.
        Because they are built outside the possibly-JAX-jitted
        JointStateStepperCallable, they may be computed in non-JAX-friendly
        ways.

    Returns:
      A 4-tuple of:
        new_core_profiles: The core profiles after the update.
        core_sources: Merged profiles of all sources, explicit and implicit.
          This is the version used to compute the coefficients for the t + dt
          step. Explicit entries match `explicit_source_profiles`; implicit
          entries are the latest guess for time t + dt.
        core_transport: Transport coefficients for time t + dt (all zeros when
          no equation is evolved).
        stepper_numeric_output: Error and iteration information.
    """
    # A subclass may replace this method wholesale. Most subclasses, however,
    # only need to implement `_x_new` and can reuse the bookkeeping below.

    # Map each static equation switch to the core_profiles variable it
    # evolves. The order of this table fixes the order of `evolving_names`.
    static_params = static_runtime_params_slice
    equation_switches = (
        (static_params.ion_heat_eq, 'temp_ion'),
        (static_params.el_heat_eq, 'temp_el'),
        (static_params.current_eq, 'psi'),
        (static_params.dens_eq, 'ne'),
    )
    evolving_names = tuple(
        name for is_enabled, name in equation_switches if is_enabled
    )

    if not evolving_names:
      # Nothing to solve, so the solver is never invoked on an empty set of
      # variables. Implicit source profiles are still built and merged so that
      # prescribed sources can be inspected in the output state.
      x_new = tuple()
      merged_sources = source_profile_builders.build_source_profiles(
          source_models=self.source_models,
          dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt,
          static_runtime_params_slice=static_runtime_params_slice,
          geo=geo_t_plus_dt,
          core_profiles=core_profiles_t_plus_dt,
          explicit=False,
          explicit_source_profiles=explicit_source_profiles,
      )
      transport = state.CoreTransport.zeros(geo_t)
      numeric_output = state.StepperNumericOutputs()
    else:
      x_new, merged_sources, transport, numeric_output = self._x_new(
          dt=dt,
          static_runtime_params_slice=static_runtime_params_slice,
          dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t,
          dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
          geo_t=geo_t,
          geo_t_plus_dt=geo_t_plus_dt,
          core_profiles_t=core_profiles_t,
          core_profiles_t_plus_dt=core_profiles_t_plus_dt,
          explicit_source_profiles=explicit_source_profiles,
          evolving_names=evolving_names,
      )

    # `x_new` carries the new cell-grid values of the evolved variables. Fold
    # them back into the core profiles and refresh derived quantities such as
    # q_face and psidot.
    updated_core_profiles = updaters.update_all_core_profiles_after_step(
        x_new,
        static_runtime_params_slice,
        dynamic_runtime_params_slice_t_plus_dt,
        geo_t_plus_dt,
        merged_sources,
        core_profiles_t,
        core_profiles_t_plus_dt,
        evolving_names,
        dt=dt,
    )
    return updated_core_profiles, merged_sources, transport, numeric_output

  def _x_new(
      self,
      dt: jax.Array,
      static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
      dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
      dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice,
      geo_t: geometry.Geometry,
      geo_t_plus_dt: geometry.Geometry,
      core_profiles_t: state.CoreProfiles,
      core_profiles_t_plus_dt: state.CoreProfiles,
      explicit_source_profiles: source_profiles.SourceProfiles,
      evolving_names: tuple[str, ...],
  ) -> tuple[
      tuple[cell_variable.CellVariable, ...],
      source_profiles.SourceProfiles,
      state.CoreTransport,
      state.StepperNumericOutputs,
  ]:
    """Computes the new values of the evolving variables.

    This is the hook used by `Stepper.__call__`. A subclass must either
    implement it or provide its own `__call__` that does not rely on it.

    Args:
      dt: Duration of the time step.
      static_runtime_params_slice: Input parameters whose change triggers
        recompilation. They need not be JAX-friendly types and may drive
        Python control flow.
      dynamic_runtime_params_slice_t: Runtime parameters at time t, the start
        of the step. These may change from step to step without triggering a
        recompilation.
      dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time
        t + dt, used for the implicit calculations in the solver.
      geo_t: Geometry of the torus at time t.
      geo_t_plus_dt: Geometry of the torus at time t + dt.
      core_profiles_t: Core plasma profiles at the start of the time step.
      core_profiles_t_plus_dt: Core plasma profiles holding every quantity
        already known at the end of the time step: evolving boundary
        conditions and prescribed time-dependent profiles that the PDE system
        does not evolve.
      explicit_source_profiles: Same meaning as in `__call__`.
      evolving_names: Names of the core_profiles variables to evolve.

    Returns:
      A 4-tuple of:
        x_new: Values of the evolving variables at time t + dt.
        core_sources: Same meaning as in `__call__`.
        core_transport: Transport coefficients for time t + dt.
        stepper_numeric_output: Error and iteration information.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError(
        f'{type(self)} must implement `_x_new` or '
        'implement a different `__call__` that does not'
        ' need `_x_new`.'
    )