# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Commonly repeated jax expressions."""
import contextlib
import functools
import os
from typing import Any, Callable, Optional, TypeVar
import chex
import equinox as eqx
import jax
from jax import numpy as jnp
import numpy as np
T = TypeVar('T')
BooleanNumeric = Any # A bool, or a Boolean array.
@functools.cache
def get_dtype() -> type(jnp.float32):
# Default TORAX JAX precision is f64
precision = os.getenv('JAX_PRECISION', 'f64')
assert precision == 'f64' or precision == 'f32', (
'Unknown JAX precision environment variable: %s' % precision
)
return jnp.float64 if precision == 'f64' else jnp.float32
@functools.cache
def get_np_dtype() -> type(np.float32):
# Default TORAX JAX precision is f64
precision = os.getenv('JAX_PRECISION', 'f64')
assert precision == 'f64' or precision == 'f32', (
'Unknown JAX precision environment variable: %s' % precision
)
return np.float64 if precision == 'f64' else np.float32
@functools.cache
def get_int_dtype() -> type(jnp.int32):
# Default TORAX JAX precision is f64
precision = os.getenv('JAX_PRECISION', 'f64')
assert precision == 'f64' or precision == 'f32', (
'Unknown JAX precision environment variable: %s' % precision
)
return jnp.int64 if precision == 'f64' else jnp.int32
[docs]
def env_bool(name: str, default: bool) -> bool:
"""Get a bool from an environment variable.
Args:
name: The name of the environment variable.
default: The default value of the bool.
Returns:
value: The value of the bool.
"""
if name not in os.environ:
return default
str_value = os.environ[name]
if str_value in ['1', 'True', 'true']:
return True
if str_value in ['0', 'False', 'false']:
return False
raise ValueError(f'Unrecognized boolean string {str_value}.')
# If True, `error_if` functions will raise errors. Otherwise they are
# pass throughs.
# Default to False, because host_callbacks are incompatible with the
# persistent compilation cache.
_ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', False)
[docs]
@contextlib.contextmanager
def enable_errors(value: bool):
"""Enables / disables `error_if` inside a code block.
Example:
with enable_errors(False):
my_sim.run() # NaNs etc will be ignored
Args:
value: Sets `errors_enabled` to this value
Yields:
Cleanup function restoring previous value
"""
global _ERRORS_ENABLED
previous_value = _ERRORS_ENABLED
_ERRORS_ENABLED = value
yield
if previous_value is not None:
_ERRORS_ENABLED = previous_value
[docs]
def error_if(
var: jax.Array,
cond: jax.Array,
msg: str,
) -> jax.Array:
"""Raises error if cond is true, and `errors_enabled` is True.
This is just a wrapper around `equinox.error_if`, gated by `errors_enabled`.
Args:
var: The variable to pass through.
cond: Boolean array, error if cond is true.
msg: Message to print on error.
Returns:
var: Identity wrapper that must be used for the check to be included.
"""
if not _ERRORS_ENABLED:
return var
return eqx.error_if(var, cond, msg)
[docs]
def error_if_negative(
var: jax.Array, name: str, to_wrap: Optional[jax.Array] = None
) -> jax.Array:
"""Check that a variable is non-negative.
Similar to error_if_not_positive, but 0 is allowed in this function.
Args:
var: The variable to check.
name: Name of the variable.
to_wrap: If `var` won't be used in your jax function, specify another
variable that will be.
Returns:
var: Identity wrapper that must be used for the check to be included.
"""
msg = f'{name} must be >= 0.'
min_var = jnp.min(var)
if to_wrap is None:
to_wrap = var
return error_if(to_wrap, min_var < 0, msg)
[docs]
def assert_rank(
inputs: chex.Numeric | jax.stages.ArgInfo,
rank: int,
) -> None:
"""Wrapper around chex.assert_rank that supports jax.stages.ArgInfo."""
if isinstance(inputs, jax.stages.ArgInfo):
chex.assert_rank(inputs.shape, rank)
else:
chex.assert_rank(inputs, rank)
[docs]
def jit(*args, **kwargs) -> Callable[..., Any]:
"""Calls jax.jit if TORAX_COMPILATION_ENABLED is True, otherwise no-op."""
if env_bool('TORAX_COMPILATION_ENABLED', True):
return jax.jit(*args, **kwargs)
return args[0]
[docs]
def py_while(
cond_fun: Callable[[T], BooleanNumeric],
body_fun: Callable[[T], T],
init_val: T,
) -> T:
"""Pure Python implementation of jax.lax.while_loop.
This gives us a way to write code that could easily be changed to be
Jax-compatible in the future (if we want to compute its gradient or
compile it, etc.) without having to pay the high compile time cost
of jax.lax.while_loop.
Args:
cond_fun: function of type ``a -> Bool``.
body_fun: function of type ``a -> a``.
init_val: value of type ``a``, a type that can be a scalar, array, or any
pytree (nested Python tuple/list/dict) thereof, representing the initial
loop carry value.
Returns:
The output from the final iteration of body_fun, of type ``a``.
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
"""
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
[docs]
def py_fori_loop(
lower: int, upper: int, body_fun: Callable[[int, T], T], init_val: T
) -> T:
"""Pure Python implementation of jax.lax.fori_loop.
This gives us a way to write code that could easily be changed to be
Jax-compatible in the future, if we want to expand the scope of the jit
compilation.
Args:
lower: lower integer of loop
upper: upper integer of loop. upper<=lower will produce no iterations.
body_fun: function of type ``a -> a``.
init_val: value of type ``a``, a type that can be a scalar, array, or any
pytree (nested Python tuple/list/dict) thereof, representing the initial
loop carry value.
Returns:
The output from the final iteration of body_fun, of type ``a``.
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
"""
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
# pylint: disable=g-bare-generic
[docs]
def py_cond(
cond: bool,
true_fun: Callable,
false_fun: Callable,
) -> Any:
"""Pure Python implementation of jax.lax.cond.
This gives us a way to write code that could easily be changed to be
Jax-compatible in the future, if we want to expand the scope of the jit
compilation.
Args:
cond: The condition
true_fun: Function to be called if cond==True.
false_fun: Function to be called if cond==False.
Returns:
The output from either true_fun or false_fun.
"""
if cond:
return true_fun()
else:
return false_fun()
[docs]
def get_number_of_compiles(
jitted_function: Callable[..., Any],
) -> int:
"""Helper function for debugging JAX compilation.
This counts the number of times the function has been JIT compiled. This does
not include any uses of the AOT compile workflow.
Args:
jitted_function: A function that has been wrapped with `jax.jit`.
Returns:
The number of times the function has been compiled.
Raises:
RuntimeError: If the function does not have a _cache_size attribute.
"""
# pylint: disable=protected-access
if not hasattr(jitted_function, '_cache_size'):
raise RuntimeError(
'The function does not have a _cache_size attribute. Possibly because'
' the function was not jitted.'
)
return jitted_function._cache_size()
# pylint: enable=protected-access
# pylint: enable=g-bare-generic