# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax.sources import base
from torax.sources import bootstrap_current_source
from torax.sources import fusion_heat_source
from torax.sources import gas_puff_source
from torax.sources import generic_current_source
from torax.sources import pydantic_model
from torax.sources import qei_source
from torax.sources import runtime_params as source_runtime_params_lib
from torax.sources import source_models as source_models_lib
from torax.sources.impurity_radiation_heat_sink import impurity_radiation_constant_fraction
from torax.sources.impurity_radiation_heat_sink import impurity_radiation_mavrin_fit
from torax.torax_pydantic import torax_pydantic
[docs]
class PydanticModelTest(parameterized.TestCase):
@parameterized.parameters(
dict(
config={
'gas_puff_source': {
'puff_decay_length': 0.3,
'S_puff_tot': 0.0,
}
},
expected_sources_model=gas_puff_source.GasPuffSourceConfig,
),
dict(
config={
'j_bootstrap': {
'bootstrap_mult': 0.3,
}
},
expected_sources_model=bootstrap_current_source.BootstrapCurrentSourceConfig,
),
dict(
config={
'fusion_heat_source': {},
},
expected_sources_model=fusion_heat_source.FusionHeatSourceConfig,
),
dict(
config={
'impurity_radiation_heat_sink': {
'model_function_name': 'impurity_radiation_mavrin_fit'
},
},
expected_sources_model=impurity_radiation_mavrin_fit.ImpurityRadiationHeatSinkMavrinFitConfig,
),
dict(
config={
'impurity_radiation_heat_sink': {
'model_function_name': 'radially_constant_fraction_of_Pin'
},
},
expected_sources_model=impurity_radiation_constant_fraction.ImpurityRadiationHeatSinkConstantFractionConfig,
),
)
def test_correct_source_model(
self,
config: dict[str, Any],
expected_sources_model: type[base.SourceModelBase],
):
sources_model = pydantic_model.Sources.from_dict(config)
self.assertIsInstance(
sources_model.source_model_config[list(config.keys())[0]],
expected_sources_model,
)
# Check that the 3 default sources are always present.
for key in [
bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME,
qei_source.QeiSource.SOURCE_NAME,
generic_current_source.GenericCurrentSource.SOURCE_NAME,
]:
self.assertIn(key, sources_model.source_model_config.keys())
[docs]
def test_adding_standard_source_via_config(self):
"""Tests that a source can be added with overriding defaults."""
sources = pydantic_model.Sources.from_dict({
'gas_puff_source': {
'puff_decay_length': 1.23,
},
'ohmic_heat_source': {
'is_explicit': True,
'mode': 'ZERO', # turn it off.
},
})
source_models = source_models_lib.SourceModels(sources.source_model_config)
# The non-standard ones are still off.
self.assertEqual(
sources.source_model_config['j_bootstrap'].mode,
source_runtime_params_lib.Mode.ZERO,
)
self.assertEqual(
sources.source_model_config['generic_current_source'].mode,
source_runtime_params_lib.Mode.ZERO,
)
self.assertEqual(
sources.source_model_config['qei_source'].mode,
source_runtime_params_lib.Mode.ZERO,
)
# But these new sources have been added.
self.assertLen(source_models.sources, 5)
self.assertLen(source_models.standard_sources, 3)
# With the overriding params.
gas_puff_config = sources.source_model_config['gas_puff_source']
self.assertIsInstance(gas_puff_config, gas_puff_source.GasPuffSourceConfig)
self.assertEqual(
gas_puff_config.puff_decay_length.get_value(0.0),
1.23,
)
self.assertEqual(
sources.source_model_config['gas_puff_source'].mode,
source_runtime_params_lib.Mode.MODEL_BASED, # On by default.
)
self.assertEqual(
sources.source_model_config['ohmic_heat_source'].mode,
source_runtime_params_lib.Mode.ZERO,
)
[docs]
def test_empty_source_config_only_has_defaults_turned_off(self):
"""Tests that an empty source config has all sources turned off."""
sources = pydantic_model.Sources.from_dict({})
self.assertEqual(
sources.source_model_config['j_bootstrap'].mode,
source_runtime_params_lib.Mode.ZERO,
)
self.assertEqual(
sources.source_model_config['generic_current_source'].mode,
source_runtime_params_lib.Mode.ZERO,
)
self.assertEqual(
sources.source_model_config['qei_source'].mode,
source_runtime_params_lib.Mode.ZERO,
)
self.assertLen(sources.source_model_config, 3)
[docs]
def test_adding_a_source_with_prescribed_values(self):
"""Tests that a source can be added with overriding defaults."""
sources = pydantic_model.Sources.from_dict({
'generic_current_source': {
'mode': 'PRESCRIBED',
'prescribed_values': ((
np.array([0.0, 1.0, 2.0, 3.0]),
np.array([0., 0.5, 1.0]),
np.full([4, 3], 42)
),),
},
'electron_cyclotron_source': {
'mode': 'PRESCRIBED',
'prescribed_values': (
3.,
4.,
),
}
})
mesh = torax_pydantic.Grid1D(nx=4, dx=0.25)
torax_pydantic.set_grid(sources, mesh)
source = sources.source_model_config['generic_current_source']
self.assertLen(source.prescribed_values, 1)
self.assertIsInstance(
source.prescribed_values[0], torax_pydantic.TimeVaryingArray)
source = sources.source_model_config['electron_cyclotron_source']
self.assertLen(source.prescribed_values, 2)
self.assertIsInstance(
source.prescribed_values[0], torax_pydantic.TimeVaryingArray)
self.assertIsInstance(
source.prescribed_values[1], torax_pydantic.TimeVaryingArray)
value = source.prescribed_values[0].get_value(0.0)
np.testing.assert_equal(value, 3.)
value = source.prescribed_values[1].get_value(0.0)
np.testing.assert_equal(value, 4.)
def test_bremsstrahlung_and_mavrin_validator_with_bremsstrahlung_zero(self):
valid_config = {
'bremsstrahlung_heat_sink': {'mode': 'ZERO'},
'impurity_radiation_heat_sink': {
'mode': 'PRESCRIBED',
'model_function_name': 'impurity_radiation_mavrin_fit',
},
}
pydantic_model.Sources.from_dict(valid_config)
def test_bremsstrahlung_and_mavrin_validator_with_mavrin_zero(self):
valid_config = {
'bremsstrahlung_heat_sink': {'mode': 'PRESCRIBED'},
'impurity_radiation_heat_sink': {
'mode': 'ZERO',
'model_function_name': 'impurity_radiation_mavrin_fit',
},
}
pydantic_model.Sources.from_dict(valid_config)
def test_bremsstrahlung_and_mavrin_validator_with_constant_fraction(self):
valid_config = {
'bremsstrahlung_heat_sink': {'mode': 'PRESCRIBED'},
'impurity_radiation_heat_sink': {
'mode': 'PRESCRIBED',
'model_function_name': 'radially_constant_fraction_of_Pin',
},
}
pydantic_model.Sources.from_dict(valid_config)
def test_bremsstrahlung_and_mavrin_validator_with_invalid_config(self):
invalid_config = {
'bremsstrahlung_heat_sink': {'mode': 'PRESCRIBED'},
'impurity_radiation_heat_sink': {
'mode': 'PRESCRIBED',
'model_function_name': 'impurity_radiation_mavrin_fit',
},
}
with self.assertRaisesRegex(
ValueError,
'Both bremsstrahlung_heat_sink and impurity_radiation_heat_sink',
):
pydantic_model.Sources.from_dict(invalid_config)
if __name__ == '__main__':
absltest.main()