Source code for torax.sources.tests.cyclotron_radiation_heat_sink_test

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
from torax.sources import cyclotron_radiation_heat_sink


class CyclotronRadiationHeatSinkTest(parameterized.TestCase):
  """Unit tests for CyclotronRadiationHeatSink."""

  @parameterized.product(
      alpha_expected=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
      beta=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
      profile_edge_value=[0.3, 1.0],
      profile_core_multiplier=[5.0, 10.0],
  )
  def test_alpha_closed_form(
      self,
      alpha_expected,
      beta,
      profile_edge_value,
      profile_core_multiplier,
  ):
    """Test _alpha_closed_form in cyclotron_radiation_heat_sink.

    Builds a synthetic profile
    ``core * (1 - rho_norm**beta)**alpha_expected + edge`` on a 25-point
    rho_norm grid. It then checks that the closed-form estimator, given the
    same ``beta`` and ``edge`` used to build the profile, recovers
    ``alpha_expected`` to within 1e-3.
    """
    rho_norm = jnp.linspace(0.0, 1.0, 25)
    profile_data = (
        profile_core_multiplier * (1 - rho_norm**beta) ** alpha_expected
        + profile_edge_value
    )
    alpha_closed_form_jitted = jax.jit(
        cyclotron_radiation_heat_sink._alpha_closed_form
    )

    # Estimate alpha with the closed-form formula, supplying the true beta.
    alpha = alpha_closed_form_jitted(
        beta=beta,
        rho_norm=rho_norm,
        profile_data=profile_data,
        profile_edge_value=profile_edge_value,
    )

    # The recovered alpha should match the one used to build the profile.
    np.testing.assert_allclose(alpha, alpha_expected, atol=1e-3)

  @parameterized.product(
      alpha=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
      beta=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
      profile_edge_value=[0.3, 1.0],
      profile_core_multiplier=[5.0, 10.0],
      beta_scan_parameters=[(0.5, 8, 32), (0.4, 6, 16)],
  )
  def test_solve_alpha_t_beta_t_grid_search(
      self,
      alpha,
      beta,
      profile_edge_value,
      profile_core_multiplier,
      beta_scan_parameters,
  ):
    """Test _solve_alpha_t_beta_t_grid_search in cyclotron_radiation_heat_sink.

    Builds a synthetic profile ``core * (1 - rho_norm**beta)**alpha + edge``
    and checks that the grid search returns the trial beta closest to the
    true ``beta``.

    The expected trial grid is rebuilt here as
    ``jnp.linspace(*beta_scan_parameters)``, i.e. ``(start, stop, num)``.
    This assumes the function under test builds its trial grid the same way.
    """
    # Rebuild the beta trial grid. The expected answer is the grid point
    # nearest the beta used to build the profile.
    beta_start, beta_stop, beta_num = beta_scan_parameters
    beta_trials = jnp.linspace(beta_start, beta_stop, beta_num)
    beta_expected = beta_trials[jnp.argmin(jnp.abs(beta_trials - beta))]

    rho_norm = jnp.linspace(0.0, 1.0, 25)
    profile_data = (
        profile_core_multiplier * (1 - rho_norm**beta) ** alpha
        + profile_edge_value
    )
    # beta_scan_parameters is a hashable tuple, so it can be marked static
    # for jit.
    solve_alpha_t_beta_t_grid_search_jitted = jax.jit(
        cyclotron_radiation_heat_sink._solve_alpha_t_beta_t_grid_search,
        static_argnames=["beta_scan_parameters"],
    )

    # Solve for (alpha, beta) by grid search over the beta trials. Only the
    # returned beta is checked here; the alpha result is discarded.
    _, beta_grid_search = solve_alpha_t_beta_t_grid_search_jitted(
        rho_norm=rho_norm,
        te_data=profile_data,
        beta_scan_parameters=beta_scan_parameters,
    )

    # The selected beta should be exactly the nearest grid point.
    np.testing.assert_allclose(beta_grid_search, beta_expected, atol=1e-7)
if __name__ == "__main__":
  # Only when executed as a script: hand control to absl's test runner.
  absltest.main()