Source code for torax.transport_model.tests.qlknn_model_wrapper_test

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from fusion_surrogates.qlknn import qlknn_model
from fusion_surrogates.qlknn import qlknn_model_test_utils
import jax.numpy as jnp
from torax.transport_model import qlknn_model_wrapper


[docs] class QlknnModelWrapperTest(parameterized.TestCase):
[docs] def setUp(self): super().setUp() if qlknn_model_test_utils is None: self.skipTest('fusion_surrogates is not available.') # Create a test model on disk to be loaded by the wrapper. self._config = qlknn_model_test_utils.get_test_model_config() self._batch_dim = 10 batch_dims = (1, self._batch_dim) model = qlknn_model_test_utils.init_model(self._config, batch_dims) self._model_file = tempfile.NamedTemporaryFile( 'wb', suffix='.pkl', delete=False ) self._flux_name_map = dict( (flux_name, f'torax_{flux_name}') for flux_name in qlknn_model_test_utils.get_test_flux_map().keys() ) model.export_model(self._model_file.name) self._qlknn_model_wrapper = qlknn_model_wrapper.QLKNNModelWrapper( path=self._model_file.name, name='', flux_name_map=self._flux_name_map )
[docs] def test_predict_shape(self): """Tests model output shape.""" inputs = jnp.empty((self._batch_dim, len(self._config.input_names))) outputs = self._qlknn_model_wrapper.predict(inputs) self.assertLen(outputs, len(self._flux_name_map)) for output in outputs.values(): self.assertEqual(output.shape, (self._batch_dim, 1))
[docs] def test_predict_names(self): """Tests model output names are the TORAX flux names.""" inputs = jnp.empty((self._batch_dim, len(self._config.input_names))) outputs = self._qlknn_model_wrapper.predict(inputs) for flux_name in self._flux_name_map.values(): self.assertIn(flux_name, outputs)
[docs] @mock.patch.object( qlknn_model.QLKNNModel, 'load_model_from_path', autospec=True ) def test_load_model_from_path(self, mock_load_model_from_path): """Tests that the model is loaded from the path.""" qlknn_model_wrapper.QLKNNModelWrapper(path='/my/foo.qlknn', name='bar') mock_load_model_from_path.assert_called_once_with('/my/foo.qlknn', 'bar')
[docs] @mock.patch.object( qlknn_model.QLKNNModel, 'load_model_from_name', autospec=True ) def test_load_model_from_name(self, mock_load_model_from_name): """Tests that the model is loaded from the name.""" qlknn_model_wrapper.QLKNNModelWrapper(path='', name='bar') mock_load_model_from_name.assert_called_once_with('bar')
[docs] @mock.patch.object( qlknn_model.QLKNNModel, 'load_default_model', autospec=True ) def test_load_default_model(self, mock_load_default_model): """Tests that the default model is loaded.""" qlknn_model_wrapper.QLKNNModelWrapper(path='', name='') mock_load_default_model.assert_called_once()
# TODO(b/381134347): Add tests for get_model_inputs_from_qualikiz_inputs # and inputs_and_ranges. if __name__ == '__main__': absltest.main()