Source code for torax.post_processing

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for adding post-processed outputs to the simulation state."""

import jax
from jax import numpy as jnp
from torax import constants
from torax import jax_utils
from torax import math_utils
from torax import state
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.physics import formulas
from torax.physics import psi_calculations
from torax.physics import scaling_laws
from torax.sources import source_profiles

_trapz = jax.scipy.integrate.trapezoid

# TODO(b/376010694): use the various SOURCE_NAMES for the keys.
ION_EL_HEAT_SOURCE_TRANSFORMATIONS = {
    'generic_ion_el_heat_source': 'P_generic',
    'fusion_heat_source': 'P_alpha',
    'ion_cyclotron_source': 'P_icrh',
}
EL_HEAT_SOURCE_TRANSFORMATIONS = {
    'ohmic_heat_source': 'P_ohmic',
    'bremsstrahlung_heat_sink': 'P_brems',
    'cyclotron_radiation_heat_sink': 'P_cycl',
    'electron_cyclotron_source': 'P_ecrh',
    'impurity_radiation_heat_sink': 'P_rad',
}
EXTERNAL_HEATING_SOURCES = [
    'generic_ion_el_heat_source',
    'electron_cyclotron_source',
    'ohmic_heat_source',
    'ion_cyclotron_source',
]
CURRENT_SOURCE_TRANSFORMATIONS = {
    'generic_current_source': 'I_generic',
    'electron_cyclotron_source': 'I_ecrh',
}


def _calculate_integrated_sources(
    geo: geometry.Geometry,
    core_profiles: state.CoreProfiles,
    core_sources: source_profiles.SourceProfiles,
    dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
) -> dict[str, jax.Array]:
  """Calculates total integrated internal and external source power and current.

  Args:
    geo: Magnetic geometry
    core_profiles: Kinetic profiles such as temperature and density
    core_sources: Internal and external sources
    dynamic_runtime_params_slice: Runtime parameters slice for current time step

  Returns:
    Dictionary with integrated quantities for all existing sources.
    See state.PostProcessedOutputs for the full list of keys.

    Output dict used to update the `PostProcessedOutputs` object in the
    output `ToraxSimState`. Sources that don't exist do not have integrated
    quantities included in the returned dict. The corresponding
    `PostProcessedOutputs` attributes remain at their initialized zero values.
  """
  integrated = {}

  # Initialize total alpha power to zero. Needed for Q calculation.
  integrated['P_alpha_tot'] = jnp.array(0.0, dtype=jax_utils.get_dtype())

  # electron-ion heat exchange always exists, and is not in
  # core_sources.profiles, so we calculate it here.
  qei = core_sources.qei.qei_coef * (
      core_profiles.temp_el.value - core_profiles.temp_ion.value
  )
  integrated['P_ei_exchange_ion'] = math_utils.volume_integration(qei, geo)
  integrated['P_ei_exchange_el'] = -integrated['P_ei_exchange_ion']

  # Initialize total electron and ion powers
  # TODO(b/380848256): P_sol is now correct for stationary state. However,
  # for generality need to add dWth/dt to the equation (time dependence of
  # stored energy).
  integrated['P_sol_ion'] = integrated['P_ei_exchange_ion']
  integrated['P_sol_el'] = integrated['P_ei_exchange_el']
  integrated['P_external_ion'] = jnp.array(0.0, dtype=jax_utils.get_dtype())
  integrated['P_external_el'] = jnp.array(0.0, dtype=jax_utils.get_dtype())
  integrated['P_external_injected'] = jnp.array(
      0.0, dtype=jax_utils.get_dtype()
  )

  # Calculate integrated sources with convenient names, transformed from
  # TORAX internal names.
  for key, value in ION_EL_HEAT_SOURCE_TRANSFORMATIONS.items():
    ion_profiles = core_sources.temp_ion
    el_profiles = core_sources.temp_el
    if key in ion_profiles and key in el_profiles:
      profile_ion, profile_el = ion_profiles[key], el_profiles[key]
      integrated[f'{value}_ion'] = math_utils.volume_integration(
          profile_ion, geo
      )
      integrated[f'{value}_el'] = math_utils.volume_integration(profile_el, geo)
      integrated[f'{value}_tot'] = (
          integrated[f'{value}_ion'] + integrated[f'{value}_el']
      )
      integrated['P_sol_ion'] += integrated[f'{value}_ion']
      integrated['P_sol_el'] += integrated[f'{value}_el']
      if key in EXTERNAL_HEATING_SOURCES:
        integrated['P_external_ion'] += integrated[f'{value}_ion']
        integrated['P_external_el'] += integrated[f'{value}_el']

        # Track injected power for heating sources that have absorption_fraction
        # These are only for sources like ICRH or NBI that are
        # ion_el_heat_sources.
        source_params = dynamic_runtime_params_slice.sources[key]
        if hasattr(source_params, 'absorption_fraction'):
          total_absorbed = integrated[f'{value}_tot']
          injected_power = total_absorbed / source_params.absorption_fraction
          integrated['P_external_injected'] += injected_power
        else:
          integrated['P_external_injected'] += integrated[f'{value}_tot']
    else:
      integrated[f'{value}_ion'] = jnp.array(0.0, dtype=jax_utils.get_dtype())
      integrated[f'{value}_el'] = jnp.array(0.0, dtype=jax_utils.get_dtype())
      integrated[f'{value}_tot'] = jnp.array(0.0, dtype=jax_utils.get_dtype())

  for key, value in EL_HEAT_SOURCE_TRANSFORMATIONS.items():
    profiles = core_sources.temp_el
    if key in profiles:
      integrated[f'{value}'] = math_utils.volume_integration(profiles[key], geo)
      integrated['P_sol_el'] += integrated[f'{value}']
      if key in EXTERNAL_HEATING_SOURCES:
        integrated['P_external_el'] += integrated[f'{value}']
        integrated['P_external_injected'] += integrated[f'{value}']
    else:
      integrated[f'{value}'] = jnp.array(0.0, dtype=jax_utils.get_dtype())

  for key, value in CURRENT_SOURCE_TRANSFORMATIONS.items():
    profiles = core_sources.psi
    if key in profiles:
      integrated[f'{value}'] = math_utils.area_integration(profiles[key], geo)
    else:
      integrated[f'{value}'] = jnp.array(0.0, dtype=jax_utils.get_dtype())

  integrated['P_sol_tot'] = integrated['P_sol_ion'] + integrated['P_sol_el']
  integrated['P_external_tot'] = (
      integrated['P_external_ion'] + integrated['P_external_el']
  )

  return integrated


[docs] @jax_utils.jit def make_post_processed_outputs( sim_state: state.ToraxSimState, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, previous_post_processed_outputs: state.PostProcessedOutputs | None = None, ) -> state.PostProcessedOutputs: """Calculates post-processed outputs based on the latest state. Called at the beginning and end of each `sim.run_simulation` step. Args: sim_state: The state to add outputs to. dynamic_runtime_params_slice: Runtime parameters slice for the current time step, needed for calculating integrated power. previous_post_processed_outputs: The previous outputs, used to calculate cumulative quantities. Optional input. If None, then cumulative quantities are set at the initialized values in sim_state itself. This is used for the first time step of a the simulation. The initialized values are zero for a clean simulation, or the last value of the previous simulation for a restarted simulation. Returns: post_processed_outputs: The post_processed_outputs for the given state. """ ( pressure_thermal_el_face, pressure_thermal_ion_face, pressure_thermal_tot_face, ) = formulas.calculate_pressure(sim_state.core_profiles) pprime_face = formulas.calc_pprime(sim_state.core_profiles) # pylint: disable=invalid-name W_thermal_el, W_thermal_ion, W_thermal_tot = ( formulas.calculate_stored_thermal_energy( pressure_thermal_el_face, pressure_thermal_ion_face, pressure_thermal_tot_face, sim_state.geometry, ) ) FFprime_face = formulas.calc_FFprime( sim_state.core_profiles, sim_state.geometry ) # Calculate normalized poloidal flux. psi_face = sim_state.core_profiles.psi.face_value() psi_norm_face = (psi_face - psi_face[0]) / (psi_face[-1] - psi_face[0]) integrated_sources = _calculate_integrated_sources( sim_state.geometry, sim_state.core_profiles, sim_state.core_sources, dynamic_runtime_params_slice, ) # Calculate fusion gain with a zero division guard. # Total energy released per reaction is 5 times the alpha particle energy. Q_fusion = ( integrated_sources['P_alpha_tot'] * 5.0 / (integrated_sources['P_external_injected'] + constants.CONSTANTS.eps) ) P_LH_hi_dens, P_LH_min, P_LH, ne_min_P_LH = ( scaling_laws.calculate_plh_scaling_factor( sim_state.geometry, sim_state.core_profiles ) ) # Thermal energy confinement time is the stored energy divided by the total # input power into the plasma. # Ploss term here does not include the reduction of radiated power. Most # analysis of confinement times from databases have not included this term. # Therefore highly radiative scenarios can lead to skewed results. Ploss = ( integrated_sources['P_alpha_tot'] + integrated_sources['P_external_tot'] ) if previous_post_processed_outputs is not None: dW_th_dt = ( W_thermal_tot - previous_post_processed_outputs.W_thermal_tot ) / sim_state.dt else: dW_th_dt = 0.0 tauE = W_thermal_tot / Ploss tauH89P = scaling_laws.calculate_scaling_law_confinement_time( sim_state.geometry, sim_state.core_profiles, Ploss, 'H89P' ) tauH98 = scaling_laws.calculate_scaling_law_confinement_time( sim_state.geometry, sim_state.core_profiles, Ploss, 'H98' ) tauH97L = scaling_laws.calculate_scaling_law_confinement_time( sim_state.geometry, sim_state.core_profiles, Ploss, 'H97L' ) tauH20 = scaling_laws.calculate_scaling_law_confinement_time( sim_state.geometry, sim_state.core_profiles, Ploss, 'H20' ) H89P = tauE / tauH89P H98 = tauE / tauH98 H97L = tauE / tauH97L H20 = tauE / tauH20 # Calculate total external (injected) and fusion (generated) energies based on # interval average. if previous_post_processed_outputs is not None: # Factor 5 due to including neutron energy: E_fusion = 5.0 * E_alpha E_cumulative_fusion = ( previous_post_processed_outputs.E_cumulative_fusion + 5.0 * sim_state.dt * ( integrated_sources['P_alpha_tot'] + previous_post_processed_outputs.P_alpha_tot ) / 2.0 ) E_cumulative_external = ( previous_post_processed_outputs.E_cumulative_external + sim_state.dt * ( integrated_sources['P_external_tot'] + previous_post_processed_outputs.P_external_tot ) / 2.0 ) else: # Used during initiailization. Note for restarted simulations this should # be overwritten by the previous_post_processed_outputs. E_cumulative_fusion = 0.0 E_cumulative_external = 0.0 # Calculate q at 95% of the normalized poloidal flux q95 = psi_calculations.calc_q95(psi_norm_face, sim_state.core_profiles.q_face) # Calculate te and ti volume average [keV] te_volume_avg = math_utils.volume_average( sim_state.core_profiles.temp_el.value, sim_state.geometry ) ti_volume_avg = math_utils.volume_average( sim_state.core_profiles.temp_ion.value, sim_state.geometry ) # Calculate ne and ni (main ion) volume and line averages [nref m^-3] ne_volume_avg = math_utils.volume_average( sim_state.core_profiles.ne.value, sim_state.geometry ) ni_volume_avg = math_utils.volume_average( sim_state.core_profiles.ni.value, sim_state.geometry ) ne_line_avg = math_utils.line_average( sim_state.core_profiles.ne.value, sim_state.geometry ) ni_line_avg = math_utils.line_average( sim_state.core_profiles.ni.value, sim_state.geometry ) fgw_ne_volume_avg = formulas.calculate_greenwald_fraction( ne_volume_avg, sim_state.core_profiles, sim_state.geometry ) fgw_ne_line_avg = formulas.calculate_greenwald_fraction( ne_line_avg, sim_state.core_profiles, sim_state.geometry ) Wpol = psi_calculations.calc_Wpol( sim_state.geometry, sim_state.core_profiles.psi ) li3 = psi_calculations.calc_li3( sim_state.geometry.Rmaj, Wpol, sim_state.core_profiles.currents.Ip_profile_face[-1], ) # pylint: enable=invalid-name return state.PostProcessedOutputs( pressure_thermal_ion_face=pressure_thermal_ion_face, pressure_thermal_el_face=pressure_thermal_el_face, pressure_thermal_tot_face=pressure_thermal_tot_face, pprime_face=pprime_face, W_thermal_ion=W_thermal_ion, W_thermal_el=W_thermal_el, W_thermal_tot=W_thermal_tot, tauE=tauE, H89P=H89P, H98=H98, H97L=H97L, H20=H20, FFprime_face=FFprime_face, psi_norm_face=psi_norm_face, psi_face=sim_state.core_profiles.psi.face_value(), **integrated_sources, Q_fusion=Q_fusion, P_LH=P_LH, P_LH_min=P_LH_min, P_LH_hi_dens=P_LH_hi_dens, ne_min_P_LH=ne_min_P_LH, E_cumulative_fusion=E_cumulative_fusion, E_cumulative_external=E_cumulative_external, te_volume_avg=te_volume_avg, ti_volume_avg=ti_volume_avg, ne_volume_avg=ne_volume_avg, ni_volume_avg=ni_volume_avg, ne_line_avg=ne_line_avg, ni_line_avg=ni_line_avg, fgw_ne_volume_avg=fgw_ne_volume_avg, fgw_ne_line_avg=fgw_ne_line_avg, q95=q95, Wpol=Wpol, li3=li3, dW_th_dt=dW_th_dt, )