# Source code for torax.config.config_loader

# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions to load a config from config file or directory."""

import importlib
import logging
import pathlib
import sys
import types
import typing
from typing import Any, Literal, TypeAlias
from torax.torax_pydantic import model_config

# Cache of the modules loaded through `import_module`, keyed by the module name
# string that was requested. A name already present here is reloaded instead of
# being imported a second time.
_ALL_MODULES: dict[str, types.ModuleType] = {}

# Names of the bundled example configs. `example_config_paths` resolves each
# name to a `<name>.py` file inside the `examples` directory.
ExampleConfig: TypeAlias = Literal[
    'basic_config', 'iterhybrid_predictor_corrector', 'iterhybrid_rampup'
]


def build_torax_config_from_config_module(
    config_module_str: str,
    config_package: str | None = None,
) -> model_config.ToraxConfig:
  """Builds a `ToraxConfig` from a Python module that defines `CONFIG`.

  The module is loaded with `import_module` and its module-level `CONFIG`
  dictionary is handed to `ToraxConfig.from_dict`.

  Args:
    config_module_str: Python package path to config module. E.g.
      torax.examples.iterhybrid_predictor_corrector.
    config_package: Optional, base package config is imported from. Forwarded
      unchanged to `import_module`. See config_package flag docs.

  Returns:
    The `ToraxConfig` built from the module's `CONFIG` dictionary.

  Raises:
    ValueError: If the module has no `CONFIG` attribute, or if the module could
      not be imported (raised by `import_module`).
  """
  loaded_module = import_module(config_module_str, config_package)

  # Only the "basic" setup is supported: a single CONFIG dictionary that
  # describes the full simulation.
  if not hasattr(loaded_module, 'CONFIG'):
    raise ValueError(
        f'Config module {config_module_str} must define a CONFIG dictionary.'
    )

  return model_config.ToraxConfig.from_dict(loaded_module.CONFIG)
def import_module(module_name: str, config_package: str | None = None):
  """Imports a module, or reloads it if it was imported here before.

  Args:
    module_name: Name of the module to import. Also used as the key under
      which the loaded module is remembered in `_ALL_MODULES`.
    config_package: Optional package passed as the `package` argument of
      `importlib.import_module`. Only used on the first import of a name.

  Returns:
    The imported (or reloaded) module object.

  Raises:
    ValueError: If importing or reloading raised any exception. The original
      exception is logged at INFO level and chained as the cause.
  """
  try:
    if module_name not in _ALL_MODULES:
      # First request for this name: import it and remember the module.
      fresh_module = importlib.import_module(module_name, config_package)
      _ALL_MODULES[module_name] = fresh_module
      return fresh_module
    # Seen before: re-execute the module so edits to the file are picked up.
    return importlib.reload(_ALL_MODULES[module_name])
  except Exception as err:
    logging.info('Exception raised: %s', err)
    raise ValueError('Exception while importing.') from err
def example_config_paths() -> dict[ExampleConfig, pathlib.Path]:
  """Returns a dict mapping each example config name to its file path.

  The files are looked up as `<name>.py` in the `examples` directory that sits
  next to the parent directory of this file.

  Returns:
    A dict with one entry per name listed in `ExampleConfig`, in that order.

  Raises:
    AssertionError: If one of the expected example files does not exist.
  """
  examples_root = pathlib.Path(__file__).parent.parent / 'examples'

  resolved = {}
  for name in typing.get_args(ExampleConfig):
    config_file = examples_root / (name + '.py')
    assert (
        config_file.is_file()
    ), f'Path {config_file} to the example config does not exist.'
    resolved[name] = config_file
  return resolved
# Taken from # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly def _import_from_path(module_name: str, file_path: str) -> types.ModuleType: spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module if module is None: raise ValueError(f'No loader found for module {module_name}.') else: spec.loader.exec_module(module) # pytype: disable=attribute-error return module
def import_config_dict(path: str | pathlib.Path) -> dict[str, Any]:
  """Loads the `CONFIG` dictionary defined in a Torax config file.

  Args:
    path: Location of the config file, given either as a string or as a
      `pathlib.Path` object.

  Returns:
    The object bound to the module-level name `CONFIG` in the file.

  Raises:
    ValueError: If `path` is not an existing file, or if the file does not
      define a `CONFIG` variable.
  """
  if isinstance(path, str):
    path = pathlib.Path(path)

  if not path.is_file():
    raise ValueError(f'Path {path} is not a file.')

  # The file is always imported under the same placeholder module name.
  loaded = _import_from_path('_torax_temp_config_import', path)

  if hasattr(loaded, 'CONFIG'):
    return loaded.CONFIG

  raise ValueError(
      f'The file {str(path)} is an invalid Torax config file, as it does not'
      ' have a `CONFIG` variable defined.'
  )