# Source code for torax.physics.psi_calculations

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Calculations related to derived quantities from poloidal flux (psi).

Functions:
    - calc_q_face: Calculates the safety factor (q) profile on the face grid.
    - calc_jtot: Calculates the flux-surface-averaged toroidal current density.
    - calc_s_face: Calculates magnetic shear (s) on the face grid.
    - calc_s_rmid: Calculates magnetic shear (s), using midplane r as radial
      coordinate.
    - calc_Wpol: Calculates total poloidal magnetic energy (Wpol).
    - calc_li3: Calculates normalized internal inductance li3 (ITER convention).
    - calc_q95: Calculates q at 95% of the normalized poloidal flux.
    - calculate_psi_grad_constraint_from_Ip_tot: Calculates the gradient
      constraint on the poloidal flux (psi) from Ip.
    - calculate_psidot_from_psi_sources: Calculates psidot (loop voltage) from
      the psi sources.
    - _calc_bpol2: Calculates square of poloidal field (Bp).
"""
import chex
import jax
from jax import numpy as jnp
from torax import array_typing
from torax import constants
from torax import jax_utils
from torax.fvm import cell_variable
from torax.fvm import convection_terms
from torax.fvm import diffusion_terms
from torax.geometry import geometry

_trapz = jax.scipy.integrate.trapezoid

# pylint: disable=invalid-name


def calc_q_face(
    geo: geometry.Geometry,
    psi: cell_variable.CellVariable,
) -> chex.Array:
  """Calculates the q-profile on the face grid given poloidal flux (psi).

  Away from the axis, q = 1/iota = |2 * Phib * rho_norm / (dpsi/drho_norm)|.
  On the axis both numerator and denominator vanish, so the value there is
  obtained with L'Hôpital's rule. The whole profile is finally multiplied by
  ``geo.q_correction_factor``.

  Args:
    geo: Torus geometry.
    psi: Poloidal flux.

  Returns:
    The safety factor q on the face grid.
  """
  dpsi_drhon = psi.face_grad()
  two_phib = 2 * geo.Phib

  # Regular expression for q at every face except the magnetic axis.
  q_off_axis = jnp.abs(two_phib * geo.rho_face_norm[1:] / dpsi_drhon[1:])

  # L'Hôpital on axis (dpsi_drhon[0] is taken to be zero): the second
  # derivative of psi is approximated by dpsi_drhon[1] / drho_norm.
  q_on_axis = jnp.expand_dims(
      jnp.abs(two_phib * geo.drho_norm / dpsi_drhon[1]), axis=0
  )

  q_face = jnp.concatenate([q_on_axis, q_off_axis])
  return q_face * geo.q_correction_factor
def calc_jtot(
    geo: geometry.Geometry,
    psi: cell_variable.CellVariable,
) -> tuple[chex.Array, chex.Array, chex.Array]:
  """Calculate flux-surface-averaged toroidal current density from poloidal flux.

  The enclosed current I(rho_norm) is first derived from the psi gradient.
  The current density then follows from jtot = dI/dS, evaluated as
  (dI/drho_norm) / spr.

  Args:
    geo: Torus geometry.
    psi: Poloidal flux.

  Returns:
    A tuple ``(jtot, jtot_face, Ip_profile_face)``:
      jtot: total current density [A/m2] on the cell grid.
      jtot_face: total current density [A/m2] on the face grid.
      Ip_profile_face: cumulative total plasma current profile [A] on the
        face grid.
  """
  denominator = 16 * jnp.pi**3 * constants.CONSTANTS.mu0

  def _enclosed_current(dpsi_drhon, g2g3_over_rhon, f_profile):
    # I = dpsi/drho_norm * g2g3/rho_norm * F / Phib / (16 pi^3 mu0)
    return dpsi_drhon * g2g3_over_rhon * f_profile / geo.Phib / denominator

  Ip_profile_face = _enclosed_current(
      psi.face_grad(), geo.g2g3_over_rhon_face, geo.F_face
  )
  Ip_profile = _enclosed_current(psi.grad(), geo.g2g3_over_rhon, geo.F)

  # jtot = (dI/drho_norm) / spr. The first (near-axis) entry is discarded
  # here and replaced by the extrapolated value below.
  jtot_bulk = jnp.gradient(Ip_profile, geo.rho_norm)[1:] / geo.spr[1:]
  jtot_face_bulk = (
      jnp.gradient(Ip_profile_face, geo.rho_face_norm)[1:] / geo.spr_face[1:]
  )

  # Near-axis numerical derivatives are very sensitive, so the first cell
  # value is linearly extrapolated from the next two cell values. This is an
  # exact linear extrapolation for equally spaced cells. The same value is
  # reused for the on-axis face.
  jtot_axis = jtot_bulk[0] - (jtot_bulk[1] - jtot_bulk[0])
  axis_entry = jnp.array([jtot_axis])

  jtot = jnp.concatenate([axis_entry, jtot_bulk])
  jtot_face = jnp.concatenate([axis_entry, jtot_face_bulk])
  return jtot, jtot_face, Ip_profile_face
def calc_s_face(
    geo: geometry.Geometry, psi: cell_variable.CellVariable
) -> jax.Array:
  """Calculates magnetic shear on the face grid from poloidal flux (psi).

  Uses s = -(rho_norm / iota) * d(iota)/d(rho_norm), with iota = 1/q.

  Args:
    geo: Torus geometry.
    psi: Poloidal flux.

  Returns:
    Magnetic shear on the face grid.
  """
  dpsi = psi.face_grad()

  # iota is proportional to (dpsi/drho_norm) / rho_norm. The 1/(2*Phib)
  # prefactor cancels in the logarithmic derivative, so it is left out.
  iota_off_axis = jnp.abs((dpsi[1:] / geo.rho_face_norm[1:]))
  # On-axis value via L'Hôpital's rule: dpsi_face_grad / drho_norm.
  iota_on_axis = jnp.abs(dpsi[1] / geo.drho_norm)

  iota_scaled = jnp.concatenate(
      [jnp.expand_dims(iota_on_axis, axis=0), iota_off_axis]
  )
  diota_drhon = jnp.gradient(iota_scaled, geo.rho_face_norm)
  return -geo.rho_face_norm * diota_drhon / iota_scaled
def calc_s_rmid(
    geo: geometry.Geometry, psi: cell_variable.CellVariable
) -> jax.Array:
  """Calculates magnetic shear (s) from poloidal flux (psi).

  Unlike ``calc_s_face``, the derivative of iota is taken with respect to the
  midplane minor radius r = (Rout - Rin) / 2. This matches the expectations
  of circular-derived models like QuaLiKiz.

  Args:
    geo: Torus geometry.
    psi: Poloidal flux.

  Returns:
    s_face: Magnetic shear, on the face grid.
  """
  dpsi = psi.face_grad()

  # iota (1/q) up to a 1/(2*Phib) factor, which cancels in the shear.
  iota_off_axis = jnp.abs((dpsi[1:] / geo.rho_face_norm[1:]))
  # On-axis value via L'Hôpital's rule: dpsi_face_grad / drho_norm.
  iota_on_axis = jnp.abs(dpsi[1] / geo.drho_norm)
  iota_scaled = jnp.concatenate(
      [jnp.expand_dims(iota_on_axis, axis=0), iota_off_axis]
  )

  r_mid = (geo.Rout_face - geo.Rin_face) * 0.5
  diota_dr = jnp.gradient(iota_scaled, r_mid)
  return -r_mid * diota_dr / iota_scaled
def _calc_bpol2(
    geo: geometry.Geometry, psi: cell_variable.CellVariable
) -> jax.Array:
  r"""Calculates square of poloidal field (Bp) from poloidal flux (psi).

  Based on the identity
  :math:`B_p = 1/R \partial \psi / \partial \rho (\nabla \rho \times e_\phi)`,
  where :math:`e_\phi` is the unit vector in the toroidal direction. This is
  evaluated here as (dpsi/drho_norm / 2pi)^2 * g2 / vpr^2.

  Args:
    geo: Torus geometry.
    psi: Poloidal flux.

  Returns:
    bpol2_face: Square of poloidal magnetic field, on the face grid.
  """
  dpsi_over_2pi = psi.face_grad()[1:] / (2 * jnp.pi)
  bpol2_off_axis = dpsi_over_2pi**2 * geo.g2_face[1:] / geo.vpr_face[1:] ** 2

  # The formula is not evaluated at the axis face (it would divide by
  # vpr_face[0]); Bp^2 there is set to exactly zero instead.
  on_axis = jnp.zeros((1,), dtype=jax_utils.get_dtype())
  return jnp.concatenate([on_axis, bpol2_off_axis])
def calc_Wpol(
    geo: geometry.Geometry, psi: cell_variable.CellVariable
) -> jax.Array:
  """Calculates total poloidal magnetic energy (Wpol) from poloidal flux (psi).

  Wpol = integral of Bp^2 / (2 mu0) over the plasma volume. It is computed
  with the trapezoidal rule over rho_norm on the face grid, using vpr_face as
  the volume Jacobian (presumably dV/drho_norm).

  Args:
    geo: Torus geometry.
    psi: Poloidal flux.

  Returns:
    Wpol: total poloidal magnetic energy.
  """
  integrand = _calc_bpol2(geo, psi) * geo.vpr_face
  volume_integral = _trapz(integrand, geo.rho_face_norm)
  return volume_integral / (2 * constants.CONSTANTS.mu0)
def calc_li3(
    Rmaj: jax.Array,
    Wpol: jax.Array,
    Ip_total: jax.Array,
) -> jax.Array:
  """Calculates li3 based on a formulation using Wpol.

  Normalized internal inductance is generally defined as

    li = <Bpol^2>_V / <Bpol^2>_LCFS,

  where <>_V is a volume average and <>_LCFS is the average on the last
  closed flux surface. Here the ITER convention is used:

    li3 = 2*V*<Bpol^2>_V / (mu0^2 Ip^2*Rmaj) = 4 * Wpol / (mu0 Ip^2*Rmaj).

  Ip enters through the integral form of Ampere's law. Wpol is itself a
  volume integral of Bpol^2 / (2 mu0), so li3 can be written in terms of Wpol.

  Args:
    Rmaj: Major radius.
    Wpol: Total poloidal magnetic energy.
    Ip_total: Total plasma current.

  Returns:
    li3: Normalized internal inductance, ITER convention.
  """
  denominator = constants.CONSTANTS.mu0 * Ip_total**2 * Rmaj
  return 4 * Wpol / denominator
def calc_q95(
    psi_norm_face: array_typing.ArrayFloat,
    q_face: array_typing.ArrayFloat,
) -> array_typing.ScalarFloat:
  """Calculates q95 from the q profile and the normalized poloidal flux.

  Linearly interpolates q at psi_norm = 0.95 with ``jnp.interp``. That
  function expects ``psi_norm_face`` to be increasing; ensuring this is left
  to the caller (NOTE(review): confirm for all sign conventions).

  Args:
    psi_norm_face: normalized poloidal flux on the face grid.
    q_face: safety factor on the face grid.

  Returns:
    q95: q at 95% of the normalized poloidal flux.
  """
  target_psi_norm = 0.95
  return jnp.interp(target_psi_norm, psi_norm_face, q_face)
def calculate_psi_grad_constraint_from_Ip_tot(
    Ip_tot: array_typing.ScalarFloat,
    geo: geometry.Geometry,
) -> jax.Array:
  """Calculates the gradient constraint on the poloidal flux (psi) from Ip.

  This inverts the enclosed-current relation
  I = dpsi/drho_norm * g2g3/rho_norm * F / Phib / (16 pi^3 mu0)
  at the last face, giving the edge value of dpsi/drho_norm.

  Args:
    Ip_tot: Total plasma current. It is multiplied by 1e6 here, so it is
      presumably given in MA.
    geo: Torus geometry.

  Returns:
    The required dpsi/drho_norm at the last face.
  """
  Ip_amps = Ip_tot * 1e6
  numerator = Ip_amps * (16 * jnp.pi**3 * constants.CONSTANTS.mu0 * geo.Phib)
  edge_factor = geo.g2g3_over_rhon_face[-1] * geo.F_face[-1]
  return numerator / edge_factor
def calculate_psidot_from_psi_sources(
    *,
    psi_sources: array_typing.ArrayFloat,
    sigma: array_typing.ArrayFloat,
    sigma_face: array_typing.ArrayFloat,
    resistivity_multiplier: float,
    psi: cell_variable.CellVariable,
    geo: geometry.Geometry,
) -> jax.Array:
  """Calculates psidot (loop voltage) from the sum of the psi sources.

  Assembles the spatial operator of the poloidal flux equation (diffusion,
  Phibdot convection and the effective Phibdot source) as a matrix
  ``c_mat`` plus a vector ``c``. psidot is then (c_mat @ psi + c) / toc_psi,
  where toc_psi is the transient coefficient.

  Args:
    psi_sources: Sum of the psi sources on the cell grid. Not modified.
    sigma: Conductivity on the cell grid.
    sigma_face: Conductivity on the face grid.
    resistivity_multiplier: Divides the transient coefficient.
    psi: Poloidal flux.
    geo: Torus geometry.

  Returns:
    psidot on the cell grid.
  """
  consts = constants.CONSTANTS

  # Transient term coefficient.
  toc_psi = (
      1.0
      / resistivity_multiplier
      * geo.rho_norm
      * sigma
      * consts.mu0
      * 16
      * jnp.pi**2
      * geo.Phib**2
      / geo.F**2
  )

  # Diffusion term coefficient.
  d_face_psi = geo.g2g3_over_rhon_face

  # Phibdot contribution to poloidal flux convection.
  v_face_psi = (
      -8.0
      * jnp.pi**2
      * consts.mu0
      * geo.Phibdot
      * geo.Phib
      * sigma_face
      * geo.rho_face_norm**2
      / geo.F_face**2
  )

  # Effective Phibdot poloidal flux source term.
  ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient(
      sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm
  )
  # Rebind instead of using `+=`: if the caller passed a mutable (e.g. NumPy)
  # array, augmented assignment would silently modify the caller's data.
  psi_sources = psi_sources + (
      -8.0
      * jnp.pi**2
      * consts.mu0
      * geo.Phibdot
      * geo.Phib
      * ddrnorm_sigma_rnorm2_over_f2
  )

  diffusion_mat, diffusion_vec = diffusion_terms.make_diffusion_terms(
      d_face_psi, psi
  )

  # Set the psi convection term for psidot used in ohmic power, always with
  # the default 'ghost' mode. Impact of different modes would mildly impact
  # Ohmic power at the LCFS which has negligible impact on simulations.
  # Allowing it to be configurable introduces more complexity in the code by
  # needing to pass in the mode from the static_runtime_params across multiple
  # functions.
  conv_mat, conv_vec = convection_terms.make_convection_terms(
      v_face_psi, d_face_psi, psi
  )

  c_mat = diffusion_mat + conv_mat
  c = diffusion_vec + conv_vec + psi_sources

  psidot = (jnp.dot(c_mat, psi.value) + c) / toc_psi
  return psidot