# Source code for torax.sources.tests.source_profiles_test

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.sources import source_profiles as source_profiles_lib


class SourceProfilesTest(parameterized.TestCase):
  """Tests for summing the per-source profiles held by `SourceProfiles`."""

  def test_summed_temp_ion_profiles_dont_change_when_jitting(self):
    """Checks `total_sources` gives the same totals eagerly and under jit.

    For both 'temp_ion' and 'temp_el', the expected total is the sum of the
    per-source entries multiplied by `geo.vpr`.
    """
    geo = geometry_pydantic_model.CircularConfig().build_geometry()
    unit = jnp.ones_like(geo.rho)

    # Dummy per-source profiles, as if they had come from the named sources.
    ion_heating = {
        'generic_ion_el_heat_source': unit,
        'fusion_heat_source': unit * 3,
    }
    electron_heating = {
        'generic_ion_el_heat_source': unit * 2,
        'fusion_heat_source': unit * 4,
        'bremsstrahlung_heat_sink': -unit,
        'ohmic_heat_source': unit * 5,
    }
    profiles = source_profiles_lib.SourceProfiles(
        j_bootstrap=source_profiles_lib.BootstrapCurrentProfile.zero_profile(
            geo
        ),
        qei=source_profiles_lib.QeiInfo.zeros(geo),
        temp_ion=ion_heating,
        temp_el=electron_heating,
        ne={},
        psi={},
    )

    # Ion entries add up to 1 + 3 = 4; electron entries to 2 + 4 - 1 + 5 = 10.
    expected_totals = {
        'temp_ion': unit * 4 * geo.vpr,
        'temp_el': unit * 10 * geo.vpr,
    }

    with self.subTest('without_jit'):
      for source_type, expected in expected_totals.items():
        eager_total = profiles.total_sources(source_type, geo)
        np.testing.assert_allclose(eager_total, expected)

    with self.subTest('with_jit'):
      # `source_type` is passed as a plain string, so it is marked static.
      jitted_total_sources = jax.jit(
          profiles.total_sources, static_argnames='source_type'
      )
      for source_type, expected in expected_totals.items():
        jitted_total = jitted_total_sources(source_type, geo)
        np.testing.assert_allclose(jitted_total, expected)
# Run the absl test runner only when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()