# Source code for torax.fvm.calc_coeffs

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Calculates Block1DCoeffs for a time step."""
import functools

import jax
import jax.numpy as jnp
from torax import constants
from torax import jax_utils
from torax import state
from torax.config import runtime_params_slice
from torax.core_profiles import updaters
from torax.fvm import block_1d_coeffs
from torax.fvm import cell_variable
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.transport_model import transport_model as transport_model_lib


class CoeffsCallback:
  """Callable that computes Block1DCoeffs for a given state.

  The objects passed to the constructor are stored on the instance and reused
  on every call, so each call only needs the quantities that vary: the dynamic
  runtime parameters, the geometry, the core profiles and the state `x`.

  Attributes:
    static_runtime_params_slice: See the docstring for `stepper.Stepper`.
    transport_model: See the docstring for `stepper.Stepper`.
    explicit_source_profiles: See the docstring for `stepper.Stepper`.
    source_models: See the docstring for `stepper.Stepper`.
    evolving_names: The names of the evolving variables.
    pedestal_model: See the docstring for `stepper.Stepper`.
  """

  def __init__(
      self,
      static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
      transport_model: transport_model_lib.TransportModel,
      explicit_source_profiles: source_profiles_lib.SourceProfiles,
      source_models: source_models_lib.SourceModels,
      evolving_names: tuple[str, ...],
      pedestal_model: pedestal_model_lib.PedestalModel,
  ):
    # Runtime configuration and the names of the evolving variables.
    self.static_runtime_params_slice = static_runtime_params_slice
    self.evolving_names = evolving_names
    # Models and precomputed profiles forwarded to `calc_coeffs`.
    self.transport_model = transport_model
    self.pedestal_model = pedestal_model
    self.source_models = source_models
    self.explicit_source_profiles = explicit_source_profiles

  def __call__(
      self,
      dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
      geo: geometry.Geometry,
      core_profiles: state.CoreProfiles,
      x: tuple[cell_variable.CellVariable, ...],
      allow_pereverzev: bool = False,
      # Whether the reduced calc_coeffs for explicit terms (used when
      # theta_imp=1) should be called.
      explicit_call: bool = False,
  ) -> block_1d_coeffs.Block1DCoeffs:
    """Returns coefficients given a state x.

    Used to calculate the coefficients for the implicit or explicit components
    of the PDE system.

    Args:
      dynamic_runtime_params_slice: Runtime configuration parameters. These
        values are potentially time-dependent and should correspond to the
        time step of the state x.
      geo: The geometry of the system at this time step.
      core_profiles: The core profiles of the system at this time step.
      x: The state with cell-grid values of the evolving variables.
      allow_pereverzev: If True, then the coeffs are being called within a
        linear solver. Thus could be either the predictor_corrector solver or
        as part of calculating the initial guess for the nonlinear solver. In
        either case, we allow the inclusion of Pereverzev-Corrigan terms which
        aim to stabilize the linear solver when being used with highly
        nonlinear (stiff) transport coefficients. The nonlinear solver solves
        the system more rigorously and Pereverzev-Corrigan terms are not
        needed.
      explicit_call: If True, then if theta_imp=1, only a reduced
        Block1DCoeffs is calculated since most explicit coefficients will not
        be used.

    Returns:
      coeffs: The diffusion, convection, etc. coefficients for this state.
    """
    # Fold the new values of the evolving variables into the core profiles.
    updated_core_profiles = updaters.update_core_profiles_during_step(
        x,
        self.static_runtime_params_slice,
        dynamic_runtime_params_slice,
        geo,
        core_profiles,
        self.evolving_names,
    )

    # Pereverzev terms are only used when the caller allows them AND the
    # stepper configuration enables them.
    use_pereverzev = (
        self.static_runtime_params_slice.stepper.use_pereverzev
        if allow_pereverzev
        else False
    )

    return calc_coeffs(
        self.static_runtime_params_slice,
        dynamic_runtime_params_slice=dynamic_runtime_params_slice,
        geo=geo,
        core_profiles=updated_core_profiles,
        transport_model=self.transport_model,
        explicit_source_profiles=self.explicit_source_profiles,
        source_models=self.source_models,
        evolving_names=self.evolving_names,
        use_pereverzev=use_pereverzev,
        explicit_call=explicit_call,
        pedestal_model=self.pedestal_model,
    )
def _calculate_pereverzev_flux(
    dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
    geo: geometry.Geometry,
    core_profiles: state.CoreProfiles,
    pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
  """Computes the Pereverzev-Corrigan diffusion and convection terms.

  Args:
    dynamic_runtime_params_slice: Supplies `numerics.nref`, `stepper.chi_per`
      and `stepper.d_per`.
    geo: Geometry providing the face-grid metric factors.
    core_profiles: Core profiles providing ne, ni, temp_ion and temp_el.
    pedestal_model_output: Supplies `rho_norm_ped_top`; faces with
      `rho_face_norm` above it get zero Pereverzev terms.

  Returns:
    Tuple of face-grid arrays, in order: (chi_face_per_ion, chi_face_per_el,
    v_heat_face_ion, v_heat_face_el, d_face_per_el, v_face_per_el).
  """
  kev_to_joule = constants.CONSTANTS.keV2J
  nref = dynamic_runtime_params_slice.numerics.nref
  stepper_params = dynamic_runtime_params_slice.stepper

  # Densities on the face grid, scaled by the reference density.
  ne_face_scaled = core_profiles.ne.face_value() * nref
  ni_face_scaled = core_profiles.ni.face_value() * nref

  # Ratio g1_over_vpr / g0 on the faces; the first entry is fixed to one
  # rather than computed (presumably to avoid dividing by g0_face[0] -- TODO
  # confirm).
  geo_factor = jnp.concatenate(
      [jnp.ones(1), geo.g1_over_vpr_face[1:] / geo.g0_face[1:]]
  )

  chi_ion = (
      geo.g1_over_vpr_face
      * ni_face_scaled
      * kev_to_joule
      * stepper_params.chi_per
  )
  chi_el = (
      geo.g1_over_vpr_face
      * ne_face_scaled
      * kev_to_joule
      * stepper_params.chi_per
  )

  d_el = stepper_params.d_per
  v_el = (
      core_profiles.ne.face_grad()
      / core_profiles.ne.face_value()
      * d_el
      * geo_factor
  )

  # Remove Pereverzev flux from the boundary region if the pedestal model is
  # on (for PDE stability).
  beyond_ped_top = geo.rho_face_norm > pedestal_model_output.rho_norm_ped_top

  def _zero_beyond_ped_top(values):
    return jnp.where(beyond_ped_top, 0.0, values)

  chi_ion = _zero_beyond_ped_top(chi_ion)
  chi_el = _zero_beyond_ped_top(chi_el)

  # Set heat convection terms to zero out Pereverzev-Corrigan heat diffusion.
  # These use the already-masked chi values.
  v_heat_ion = (
      core_profiles.temp_ion.face_grad()
      / core_profiles.temp_ion.face_value()
      * chi_ion
  )
  v_heat_el = (
      core_profiles.temp_el.face_grad()
      / core_profiles.temp_el.face_value()
      * chi_el
  )

  # Particle terms: apply the geometric factors, then mask the same region.
  d_el = _zero_beyond_ped_top(d_el * geo.g1_over_vpr_face)
  v_el = _zero_beyond_ped_top(v_el * geo.g0_face)

  # Overwrite the first face of each chi array with the value at face 1.
  # This happens after the heat convection terms were computed.
  chi_ion = chi_ion.at[0].set(chi_ion[1])
  chi_el = chi_el.at[0].set(chi_el[1])

  return (
      chi_ion,
      chi_el,
      v_heat_ion,
      v_heat_el,
      d_el,
      v_el,
  )
def calc_coeffs(
    static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
    dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
    geo: geometry.Geometry,
    core_profiles: state.CoreProfiles,
    transport_model: transport_model_lib.TransportModel,
    explicit_source_profiles: source_profiles_lib.SourceProfiles,
    source_models: source_models_lib.SourceModels,
    pedestal_model: pedestal_model_lib.PedestalModel,
    evolving_names: tuple[str, ...],
    use_pereverzev: bool = False,
    explicit_call: bool = False,
) -> block_1d_coeffs.Block1DCoeffs:
  """Calculates Block1DCoeffs for the time step described by `core_profiles`.

  Dispatches to a cheap reduced calculation when only the explicit terms of a
  fully implicit scheme are requested, and to the full calculation otherwise.

  Args:
    static_runtime_params_slice: General input parameters which are fixed
      through a simulation run, and if changed, would trigger a recompile.
    dynamic_runtime_params_slice: General input parameters that can change
      from time step to time step or simulation run to run, and do so without
      triggering a recompile.
    geo: Geometry describing the torus.
    core_profiles: Core plasma profiles for this time step during this
      iteration of the solver. Depending on the type of stepper being used,
      this may or may not be equal to the original plasma profiles at the
      beginning of the time step.
    transport_model: A TransportModel subclass, calculates transport coeffs.
    explicit_source_profiles: Precomputed explicit source profiles. These
      profiles either do not depend on the core profiles or depend on the
      original core profiles at the start of the time step, not the "live"
      updating core profiles. For sources that are implicit, their explicit
      profiles are set to all zeros.
    source_models: All TORAX source/sink functions that generate the explicit
      and implicit source profiles used as terms for the core profiles
      equations.
    pedestal_model: A PedestalModel subclass, calculates pedestal values.
    evolving_names: The names of the evolving variables in the order that
      their coefficients should be written to `coeffs`.
    use_pereverzev: Toggle whether to calculate Pereverzev terms.
    explicit_call: If True, indicates that calc_coeffs is being called for the
      explicit component of the PDE. Then calculates a reduced Block1DCoeffs
      if theta_imp=1. This saves computation for the default fully implicit
      implementation.

  Returns:
    coeffs: Block1DCoeffs containing the coefficients at this time step.
  """
  # With a fully implicit scheme, a call for the explicit components of the
  # PDE only needs the cheaper reduced Block1DCoeffs.
  fully_implicit = static_runtime_params_slice.stepper.theta_imp == 1.0
  if explicit_call and fully_implicit:
    return _calc_coeffs_reduced(
        geo,
        core_profiles,
        evolving_names,
    )

  return _calc_coeffs_full(
      static_runtime_params_slice,
      dynamic_runtime_params_slice,
      geo,
      core_profiles,
      transport_model,
      explicit_source_profiles,
      source_models,
      pedestal_model,
      evolving_names,
      use_pereverzev,
  )
@functools.partial(
    jax_utils.jit,
    static_argnames=[
        'static_runtime_params_slice',
        'transport_model',
        'pedestal_model',
        'source_models',
        'evolving_names',
    ],
)
def _calc_coeffs_full(
    static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
    dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
    geo: geometry.Geometry,
    core_profiles: state.CoreProfiles,
    transport_model: transport_model_lib.TransportModel,
    explicit_source_profiles: source_profiles_lib.SourceProfiles,
    source_models: source_models_lib.SourceModels,
    pedestal_model: pedestal_model_lib.PedestalModel,
    evolving_names: tuple[str, ...],
    use_pereverzev: bool = False,
) -> block_1d_coeffs.Block1DCoeffs:
  """Calculates the complete Block1DCoeffs for the given `core_profiles`.

  Builds transient, diffusion, convection and source coefficients for
  temp_ion, temp_el, psi and ne, then selects those listed in
  `evolving_names`, in that order.

  Args:
    static_runtime_params_slice: General input parameters which are fixed
      through a simulation run, and if changed, would trigger a recompile.
    dynamic_runtime_params_slice: General input parameters that can change
      from time step to time step or simulation run to run, and do so without
      triggering a recompile.
    geo: Geometry describing the torus.
    core_profiles: Core plasma profiles for this time step during this
      iteration of the solver. Depending on the type of stepper being used,
      this may or may not be equal to the original plasma profiles at the
      beginning of the time step.
    transport_model: A TransportModel subclass, calculates transport coeffs.
    explicit_source_profiles: Precomputed explicit source profiles. These
      profiles either do not depend on the core profiles or depend on the
      original core profiles at the start of the time step, not the "live"
      updating core profiles. For sources that are implicit, their explicit
      profiles are set to all zeros.
    source_models: All TORAX source/sink functions that generate the explicit
      and implicit source profiles used as terms for the core profiles
      equations.
    pedestal_model: A PedestalModel subclass, calculates pedestal values.
    evolving_names: The names of the evolving variables in the order that
      their coefficients should be written to `coeffs`.
    use_pereverzev: Toggle whether to calculate Pereverzev terms.

  Returns:
    coeffs: Block1DCoeffs containing the coefficients at this time step.

  Raises:
    NotImplementedError: If the density equation is enabled but the transport
      model returns no `d_face_el` or `v_face_el`.
  """
  consts = constants.CONSTANTS
  numerics = dynamic_runtime_params_slice.numerics
  nref = numerics.nref

  pedestal_model_output = pedestal_model(
      dynamic_runtime_params_slice, geo, core_profiles
  )

  # Boolean mask for enforcing internal temperature boundary conditions to
  # model the pedestal.
  # If rho_norm_ped_top_idx is outside of bounds of the mesh, the pedestal is
  # not present and the mask is all False. This is what is used in the case
  # that set_pedestal is False.
  pedestal_mask = (
      jnp.zeros_like(geo.rho, dtype=bool)
      .at[pedestal_model_output.rho_norm_ped_top_idx]
      .set(True)
  )

  # Calculate the implicit source profiles and combine them with the explicit
  # ones.
  merged_source_profiles = source_profile_builders.build_source_profiles(
      source_models=source_models,
      dynamic_runtime_params_slice=dynamic_runtime_params_slice,
      static_runtime_params_slice=static_runtime_params_slice,
      geo=geo,
      core_profiles=core_profiles,
      explicit=False,
      explicit_source_profiles=explicit_source_profiles,
  )
  j_bootstrap = merged_source_profiles.j_bootstrap

  # psi source terms. Source matrix is zero for all psi sources.
  source_mat_psi = jnp.zeros_like(geo.rho)
  # Source vector based on both original and updated core profiles.
  source_psi = merged_source_profiles.total_psi_sources(geo)

  # Densities scaled by the reference density, on cells and on faces.
  true_ne = core_profiles.ne.value * nref
  true_ni = core_profiles.ni.value * nref
  true_ne_face = core_profiles.ne.face_value() * nref
  true_ni_face = core_profiles.ni.face_value() * nref

  # Transient term coefficient vectors (radial dependence through r, n).
  # The "out" coefficient has the same expression for both temperatures.
  toc_temp = 1.5 * geo.vpr ** (-2.0 / 3.0) * consts.keV2J * nref
  toc_temp_ion = toc_temp
  toc_temp_el = toc_temp
  tic_temp_ion = core_profiles.ni.value * geo.vpr ** (5.0 / 3.0)
  tic_temp_el = core_profiles.ne.value * geo.vpr ** (5.0 / 3.0)
  toc_psi = (
      1.0
      / numerics.resistivity_mult
      * geo.rho_norm
      * j_bootstrap.sigma
      * consts.mu0
      * 16
      * jnp.pi**2
      * geo.Phib**2
      / geo.F**2
  )
  tic_psi = jnp.ones_like(toc_psi)
  toc_dens_el = jnp.ones_like(geo.vpr)
  tic_dens_el = geo.vpr

  # Diffusion term coefficients.
  transport_coeffs = transport_model(
      dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_output
  )
  chi_face_ion = transport_coeffs.chi_face_ion
  chi_face_el = transport_coeffs.chi_face_el
  d_face_el = transport_coeffs.d_face_el
  v_face_el = transport_coeffs.v_face_el
  d_face_psi = geo.g2g3_over_rhon_face

  if static_runtime_params_slice.dens_eq and (
      d_face_el is None or v_face_el is None
  ):
    raise NotImplementedError(
        f'{type(transport_model)} does not support the density equation.'
    )

  # Entire coefficient preceding dT/dr in the heat transport equations.
  full_chi_face_ion = (
      geo.g1_over_vpr_face * true_ni_face * consts.keV2J * chi_face_ion
  )
  full_chi_face_el = (
      geo.g1_over_vpr_face * true_ne_face * consts.keV2J * chi_face_el
  )

  # Entire coefficient preceding dne/dr in the particle equation.
  full_d_face_el = geo.g1_over_vpr_face * d_face_el
  full_v_face_el = geo.g0_face * v_face_el

  # Density source terms. The source matrix starts at zero; the source vector
  # is based on both original and updated core profiles. The pedestal is
  # imposed through a large source/sink pair on the masked cell.
  pedestal_weight_n = pedestal_mask * numerics.largeValue_n
  source_mat_nn = jnp.zeros_like(geo.rho)
  source_ne = merged_source_profiles.total_sources('ne', geo)
  source_ne += pedestal_weight_n * pedestal_model_output.neped
  source_mat_nn += -pedestal_weight_n

  # Pereverzev-Corrigan correction for heat and particle transport
  # (deals with stiff nonlinearity of transport coefficients).
  # TODO(b/311653933) this forces us to include value 0
  # convection terms in discrete system, slowing compilation down by ~10%.
  # See if can improve with a different pattern.
  def _pereverzev_terms():
    return _calculate_pereverzev_flux(
        dynamic_runtime_params_slice,
        geo,
        core_profiles,
        pedestal_model_output,
    )

  def _zero_terms():
    return tuple([jnp.zeros_like(geo.rho_face)] * 6)

  (
      chi_face_per_ion,
      chi_face_per_el,
      v_heat_face_ion,
      v_heat_face_el,
      d_face_per_el,
      v_face_per_el,
  ) = jax.lax.cond(use_pereverzev, _pereverzev_terms, _zero_terms)

  full_chi_face_ion += chi_face_per_ion
  full_chi_face_el += chi_face_per_el
  full_d_face_el += d_face_per_el
  full_v_face_el += v_face_per_el

  # Add phibdot terms to heat transport convection. The leading factors are
  # shared by the ion and electron terms.
  heat_phibdot_prefix = (
      -3.0
      / 4.0
      * geo.Phibdot
      / geo.Phib
      * geo.rho_face_norm
      * geo.vpr_face
  )
  v_heat_face_ion += heat_phibdot_prefix * true_ni_face * consts.keV2J
  v_heat_face_el += heat_phibdot_prefix * true_ne_face * consts.keV2J

  # Add phibdot terms to particle transport convection.
  full_v_face_el += (
      -1.0
      / 2.0
      * geo.Phibdot
      / geo.Phib
      * geo.rho_face_norm
      * geo.vpr_face
  )

  # Leading factors shared by the psi convection and psi source phibdot terms.
  psi_phibdot_prefix = (
      -8.0 * jnp.pi**2 * consts.mu0 * geo.Phibdot * geo.Phib
  )

  # Phibdot term for poloidal flux convection.
  v_face_psi = (
      psi_phibdot_prefix
      * j_bootstrap.sigma_face
      * geo.rho_face_norm**2
      / geo.F_face**2
  )

  # Heat transport equation source vectors.
  source_i = merged_source_profiles.total_sources('temp_ion', geo)
  source_e = merged_source_profiles.total_sources('temp_el', geo)

  # Add the Qei effects: implicit parts go to the source matrices, explicit
  # parts to the source vectors.
  qei = merged_source_profiles.qei
  source_mat_ii = qei.implicit_ii * geo.vpr
  source_i += qei.explicit_i * geo.vpr
  source_mat_ee = qei.implicit_ee * geo.vpr
  source_e += qei.explicit_e * geo.vpr
  source_mat_ie = qei.implicit_ie * geo.vpr
  source_mat_ei = qei.implicit_ei * geo.vpr

  # Pedestal: same large source/sink construction as for the density.
  pedestal_weight_T = pedestal_mask * numerics.largeValue_T
  source_i += pedestal_weight_T * pedestal_model_output.Tiped
  source_e += pedestal_weight_T * pedestal_model_output.Teped
  source_mat_ii -= pedestal_weight_T
  source_mat_ee -= pedestal_weight_T

  # Add effective phibdot heat source terms.
  # Second derivative of volume profile with respect to r_norm.
  vprpr_norm = jnp.gradient(geo.vpr, geo.rho_norm)
  heat_source_phibdot_prefix = (
      1.0 / 2.0 * vprpr_norm * geo.Phibdot / geo.Phib * geo.rho_norm
  )
  source_i += (
      heat_source_phibdot_prefix
      * true_ni
      * core_profiles.temp_ion.value
      * consts.keV2J
  )
  source_e += (
      heat_source_phibdot_prefix
      * true_ne
      * core_profiles.temp_el.value
      * consts.keV2J
  )

  # Add effective phibdot poloidal flux source term.
  ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient(
      j_bootstrap.sigma * geo.rho_norm**2 / geo.F**2,
      geo.rho_norm,
  )
  source_psi += psi_phibdot_prefix * ddrnorm_sigma_rnorm2_over_f2

  # Build arguments to the solver based on which variables are evolving.
  toc_by_var = {
      'temp_ion': toc_temp_ion,
      'temp_el': toc_temp_el,
      'psi': toc_psi,
      'ne': toc_dens_el,
  }
  tic_by_var = {
      'temp_ion': tic_temp_ion,
      'temp_el': tic_temp_el,
      'psi': tic_psi,
      'ne': tic_dens_el,
  }
  d_face_by_var = {
      'temp_ion': full_chi_face_ion,
      'temp_el': full_chi_face_el,
      'psi': d_face_psi,
      'ne': full_d_face_el,
  }
  v_face_by_var = {
      'temp_ion': v_heat_face_ion,
      'temp_el': v_heat_face_el,
      'psi': v_face_psi,
      'ne': full_v_face_el,
  }
  source_by_var = {
      'temp_ion': source_i,
      'temp_el': source_e,
      'psi': source_psi,
      'ne': source_ne,
  }
  # Maps (row var, col var) to the coefficient for that block of the source
  # matrix. Pairs that are absent yield None below.
  source_mat_by_block = {
      ('temp_ion', 'temp_ion'): source_mat_ii,
      ('temp_ion', 'temp_el'): source_mat_ie,
      ('temp_el', 'temp_ion'): source_mat_ei,
      ('temp_el', 'temp_el'): source_mat_ee,
      ('ne', 'ne'): source_mat_nn,
      ('psi', 'psi'): source_mat_psi,
  }

  transient_out_cell = tuple(toc_by_var[name] for name in evolving_names)
  transient_in_cell = tuple(tic_by_var[name] for name in evolving_names)
  d_face = tuple(d_face_by_var[name] for name in evolving_names)
  v_face = tuple(v_face_by_var.get(name) for name in evolving_names)
  source_mat_cell = tuple(
      tuple(source_mat_by_block.get((row, col)) for col in evolving_names)
      for row in evolving_names
  )
  source_cell = tuple(source_by_var.get(name) for name in evolving_names)

  return block_1d_coeffs.Block1DCoeffs(
      transient_out_cell=transient_out_cell,
      transient_in_cell=transient_in_cell,
      d_face=d_face,
      v_face=v_face,
      source_mat_cell=source_mat_cell,
      source_cell=source_cell,
      auxiliary_outputs=(merged_source_profiles, transport_coeffs),
  )


@functools.partial(
    jax_utils.jit,
    static_argnames=[
        'evolving_names',
    ],
)
def _calc_coeffs_reduced(
    geo: geometry.Geometry,
    core_profiles: state.CoreProfiles,
    evolving_names: tuple[str, ...],
) -> block_1d_coeffs.Block1DCoeffs:
  """Calculates only the transient_in_cell terms in Block1DCoeffs.

  Args:
    geo: Geometry providing `vpr`.
    core_profiles: Core profiles providing the `ni` and `ne` cell values.
    evolving_names: The names of the evolving variables, in the order their
      coefficients should appear in the output.

  Returns:
    coeffs: Block1DCoeffs constructed with only `transient_in_cell` given.
  """
  # Only transient_in_cell is used for explicit terms if theta_imp=1.
  vpr_five_thirds = geo.vpr ** (5.0 / 3.0)
  tic_by_var = {
      'temp_ion': core_profiles.ni.value * vpr_five_thirds,
      'temp_el': core_profiles.ne.value * vpr_five_thirds,
      'psi': jnp.ones_like(geo.vpr),
      'ne': geo.vpr,
  }
  return block_1d_coeffs.Block1DCoeffs(
      transient_in_cell=tuple(tic_by_var[name] for name in evolving_names),
  )