# Source code for torax.stepper.tests.pydantic_model_test

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
from torax.sources import pydantic_model as source_pydantic_model
from torax.sources import source_models as source_models_lib
from torax.stepper import linear_theta_method
from torax.stepper import nonlinear_theta_method
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.transport_model import pydantic_model as transport_pydantic_model


class PydanticModelTest(parameterized.TestCase):
  """Tests that each stepper pydantic config builds the expected stepper."""

  @parameterized.named_parameters(
      dict(
          testcase_name='linear',
          stepper_model=stepper_pydantic_model.LinearThetaMethod,
          expected_type=linear_theta_method.LinearThetaMethod,
      ),
      dict(
          testcase_name='newton_raphson',
          stepper_model=stepper_pydantic_model.NewtonRaphsonThetaMethod,
          expected_type=nonlinear_theta_method.NewtonRaphsonThetaMethod,
      ),
      dict(
          testcase_name='optimizer',
          stepper_model=stepper_pydantic_model.OptimizerThetaMethod,
          expected_type=nonlinear_theta_method.OptimizerThetaMethod,
      ),
  )
  def test_build_stepper_from_config(self, stepper_model, expected_type):
    """Builds a stepper from the config.

    Args:
      stepper_model: Stepper config class from `stepper_pydantic_model`; it is
        instantiated here with `theta_imp=0.5`.
      expected_type: Class the object returned by `build_stepper` must be an
        instance of.
    """
    # Keep the config instance and the built stepper under separate names
    # rather than rebinding the `stepper_model` parameter.
    stepper_config = stepper_model(theta_imp=0.5)

    # Dependencies that `build_stepper` takes as keyword arguments.
    transport = transport_pydantic_model.ConstantTransportModel()
    transport_model = transport.build_transport_model()
    pedestal = pedestal_pydantic_model.NoPedestal()
    pedestal_model = pedestal.build_pedestal_model()
    sources = source_pydantic_model.Sources.from_dict({})
    source_models = source_models_lib.SourceModels(
        sources=sources.source_model_config
    )

    built_stepper = stepper_config.build_stepper(
        transport_model=transport_model,
        source_models=source_models,
        pedestal_model=pedestal_model,
    )

    self.assertIsInstance(built_stepper, expected_type)
    # The value passed to the config constructor is readable from the config.
    self.assertEqual(stepper_config.theta_imp, 0.5)
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()