# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.sources import electron_cyclotron_source
from torax.sources import generic_current_source
from torax.sources import runtime_params as runtime_params_lib
def _get_source_value(source, mode, prescribed_values, geo=mock.ANY):
  """Calls `source.get_value` with mocked runtime-params slices.

  The static and dynamic slices are autospec mocks whose `sources` dict holds
  only this source's params, keyed by the source class's `SOURCE_NAME`.

  Args:
    source: The source instance under test.
    mode: A `runtime_params_lib.Mode` member; its `.value` is stored in the
      static params.
    prescribed_values: Tuple of arrays stored as the dynamic params'
      `prescribed_values`.
    geo: Geometry forwarded to `get_value`. Defaults to `mock.ANY`, which the
      PRESCRIBED-mode tests below use because they never pass a geometry.

  Returns:
    Whatever `source.get_value` returns.
  """
  source_name = type(source).SOURCE_NAME
  static_slice = mock.create_autospec(
      runtime_params_slice.StaticRuntimeParamsSlice,
      sources={
          source_name: runtime_params_lib.StaticRuntimeParams(
              mode=mode.value,
              is_explicit=False,
          )
      },
  )
  dynamic_slice = mock.create_autospec(
      runtime_params_slice.DynamicRuntimeParamsSlice,
      sources={
          source_name: runtime_params_lib.DynamicRuntimeParams(
              prescribed_values=prescribed_values,
          )
      },
  )
  return source.get_value(
      dynamic_runtime_params_slice=dynamic_slice,
      static_runtime_params_slice=static_slice,
      geo=geo,
      core_profiles=mock.ANY,
      calculated_source_profiles=None,
  )


class SourceTest(parameterized.TestCase):
  """Tests for the base class Source.

  The tests exercise `Source.get_value` through concrete subclasses. They
  rely on `get_value` returning a tuple with one profile per affected core
  profile.
  """

  def test_zero_profile_works_by_default(self):
    """The default source impl should support profiles with all zeros."""
    source = generic_current_source.GenericCurrentSource()
    # geo needs rho_norm, which provides the shape of the zero profile.
    geo = mock.create_autospec(
        geometry.Geometry, rho_norm=np.array([1, 1, 1, 1])
    )
    profile = _get_source_value(
        source,
        runtime_params_lib.Mode.ZERO,
        # Tuple form, matching how prescribed values are passed elsewhere.
        prescribed_values=(np.zeros_like(geo.rho_norm),),
        geo=geo,
    )
    np.testing.assert_allclose(profile[0], np.zeros_like(geo.rho_norm))

  @parameterized.parameters(
      (runtime_params_lib.Mode.ZERO, np.array([0, 0, 0, 0])),
      (runtime_params_lib.Mode.MODEL_BASED, np.array([42, 42, 42, 42])),
      (runtime_params_lib.Mode.PRESCRIBED, np.array([3, 3, 3, 3])),
  )
  def test_correct_mode_called(self, mode, expected_profile):
    """get_value takes the code path selected by the static mode."""
    model_func = mock.MagicMock()
    # Return a 1-tuple, like the other modes, so profile[0] is the full array.
    # With a bare array, profile[0] would be the scalar 42.0, which silently
    # broadcasts against expected_profile in assert_allclose. (This assumes
    # MODEL_BASED returns model_func's result unchanged.)
    model_func.return_value = (np.full([4], 42.0),)
    source = generic_current_source.GenericCurrentSource(model_func=model_func)
    # Make a geo with rho_norm as we need it for the zero profile shape.
    geo = mock.create_autospec(
        geometry.Geometry, rho_norm=np.array([1, 1, 1, 1])
    )
    profile = _get_source_value(
        source,
        mode,
        prescribed_values=(np.full([4], 3.0),),
        geo=geo,
    )
    # A single prescribed value is accepted for this source (no mismatch
    # error), so every mode should yield exactly one profile.
    self.assertLen(profile, 1)
    np.testing.assert_allclose(
        profile[0],
        expected_profile,
        atol=1e-6,
        rtol=1e-6,
    )

  def test_prescribed_values_for_multiple_affected_profiles(self):
    """Each prescribed value is returned for its own affected profile."""
    source = electron_cyclotron_source.ElectronCyclotronSource()
    profile = _get_source_value(
        source,
        runtime_params_lib.Mode.PRESCRIBED,
        prescribed_values=(np.full([4], 3.0), np.full([4], 4.0)),
    )
    self.assertLen(profile, 2)
    np.testing.assert_allclose(
        profile[0],
        np.full([4], 3.0),
        atol=1e-6,
        rtol=1e-6,
    )
    np.testing.assert_allclose(
        profile[1],
        np.full([4], 4.0),
        atol=1e-6,
        rtol=1e-6,
    )

  def test_source_with_mismatched_prescribed_values_raises_error(self):
    """PRESCRIBED mode rejects a value count != affected profile count."""
    source = electron_cyclotron_source.ElectronCyclotronSource()
    # This source takes two prescribed values (cf.
    # test_prescribed_values_for_multiple_affected_profiles); one must fail.
    with self.assertRaisesRegex(
        ValueError,
        'the number of prescribed values must match the number of affected',
    ):
      _get_source_value(
          source,
          runtime_params_lib.Mode.PRESCRIBED,
          prescribed_values=(np.full([4], 3.0),),
      )
if __name__ == '__main__':
  # Hand control to the absltest runner only when run as a script; importing
  # this module (e.g. from a test collector) does nothing here.
  absltest.main()