# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A wrapper for QLKNN transport surrogate models."""
from collections.abc import Mapping
from typing import Final
from fusion_surrogates.qlknn import qlknn_model
import immutabledict
import jax
import jax.numpy as jnp
from torax import jax_utils
from torax.transport_model import base_qlknn_model
from torax.transport_model import qualikiz_based_transport_model
# Translation table for flux names: keys are the names used by Qualikiz (and
# hence by the QLKNN model outputs), values are the corresponding TORAX names.
# The mapping is immutable so it can be shared safely as a default.
_FLUX_NAME_MAP: Final[Mapping[str, str]] = immutabledict.immutabledict(
    efiITG='qi_itg',
    efeITG='qe_itg',
    pfeITG='pfe_itg',
    efeTEM='qe_tem',
    efiTEM='qi_tem',
    pfeTEM='pfe_tem',
    efeETG='qe_etg',
)
class QLKNNModelWrapper(base_qlknn_model.BaseQLKNNModel):
    """A TORAX wrapper for a QLKNN Model from the fusion_surrogates library.

    The wrapped model is loaded once at construction time. Predictions are
    delegated to it, and the flux names it returns are renamed according to
    the configured flux name map before being handed back to the caller.
    """

    def __init__(
        self,
        path: str,
        name: str = '',
        flux_name_map: Mapping[str, str] | None = None,
    ):
        """Loads the underlying QLKNN model and initialises the base class.

        The model source is chosen in this order:
          1. If `path` is non-empty, the model is loaded from that path
             (`name` is forwarded to the loader as well).
          2. Otherwise, if `name` is non-empty, the model is loaded by name.
          3. Otherwise, the library's default model is loaded.

        Args:
          path: Path to load the model from. An empty string means "do not
            load from a path".
          name: Name of the model. Used together with `path` when `path` is
            given, or on its own to select a model when `path` is empty.
          flux_name_map: Mapping from the flux names produced by the model to
            the names returned by `predict`. If None, `_FLUX_NAME_MAP` is
            used.
        """
        if flux_name_map is None:
            flux_name_map = _FLUX_NAME_MAP
        self._flux_name_map = flux_name_map

        if path:
            self._model = qlknn_model.QLKNNModel.load_model_from_path(
                path, name
            )
        elif name:
            self._model = qlknn_model.QLKNNModel.load_model_from_name(name)
        else:
            self._model = qlknn_model.QLKNNModel.load_default_model()

        # The base class is initialised from the loaded model's own path and
        # name rather than from the raw arguments, which may have been empty.
        super().__init__(path=self._model.path, name=self._model.name)

    @property
    def inputs_and_ranges(self) -> base_qlknn_model.InputsAndRanges:
        """The `inputs_and_ranges` of the wrapped model, returned unchanged."""
        return self._model.inputs_and_ranges

    def predict(self, inputs: jax.Array) -> dict[str, jax.Array]:
        """Predicts the fluxes given the inputs.

        Args:
          inputs: Model inputs, forwarded unchanged to the wrapped model.

        Returns:
          A dict of the wrapped model's predictions. Each flux name found in
          the flux name map is replaced by its mapped name; names that are not
          in the map are kept as they are.
        """
        model_predictions = self._model.predict(inputs)
        return {
            self._flux_name_map.get(flux_name, flux_name): flux_value
            for flux_name, flux_value in model_predictions.items()
        }