Running a completely JITted simulation
TORAX contains various helpers for running simulations completely under a
[jax.jit](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) context.
This can improve performance and makes it possible to use other JAX
functionality such as batching, automatic differentiation, and running on
accelerators such as GPUs or TPUs.
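As a rough illustration of why a fully JITted simulation is useful, the sketch below applies `jax.jit`, `jax.vmap` and `jax.grad` to a toy stand-in simulation. The `simulate` function and its parameters are hypothetical and are not part of TORAX; only the JAX calls are real.

```python
import jax
import jax.numpy as jnp


def simulate(decay_rate, x0, num_steps=50, dt=0.1):
  """Toy stand-in for a simulation: explicit Euler on dx/dt = -decay_rate * x."""
  x0 = jnp.asarray(x0, dtype=jnp.float32)

  def body(_, x):
    return x - dt * decay_rate * x

  # num_steps is a Python int, so the trip count is static and the loop
  # supports reverse-mode differentiation.
  return jax.lax.fori_loop(0, num_steps, body, x0)


# Compile the whole toy simulation once.
simulate_jit = jax.jit(simulate)
final_x = simulate_jit(0.5, 1.0)

# Batch over several decay rates in a single call.
rates = jnp.array([0.1, 0.5, 1.0])
batched_final_x = jax.vmap(simulate_jit, in_axes=(0, None))(rates, 1.0)

# Differentiate the final state with respect to the decay rate.
dfinal_drate = jax.grad(simulate)(0.5, 1.0)
```

Because the whole function is traceable, the same code path serves a single run, a batched parameter scan, and a gradient computation.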
JIT compatible version of run_loop
Under the experimental API we provide a run_loop_jit function that can be
used to run a simulation in a JITted context.
from torax import experimental as torax_experimental

# `torax_config` is assumed to be a TORAX config object that was built earlier.
step_fn = torax_experimental.make_step_fn(torax_config)
# The simulation loop will exit after executing at most this many time steps.
# This is needed to provide a constant size graph for JAX to compile but also
# means that a simulation could be incomplete if it needs to run for more than
# the provided max_steps.
max_steps = 100
sim_states, post_processed_outputs, final_i = torax_experimental.run_loop_jit(
    step_fn=step_fn,
    max_steps=max_steps,
)
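To see why a fixed `max_steps` is needed, here is a minimal sketch of a bounded loop in plain JAX. It is a toy illustration and not how TORAX implements `run_loop_jit`: the function `run_bounded` and its arguments are hypothetical. The output buffer has a static size, and the loop exits early once the toy end time is reached.

```python
import jax
import jax.numpy as jnp


def run_bounded(x0, t_final, dt, max_steps):
  """Steps dx/dt = -x until t >= t_final or max_steps is reached."""
  # The output buffer has a fixed size that is known at trace time.
  history = jnp.zeros((max_steps,), dtype=jnp.float32)

  def cond(carry):
    i, t, _, _ = carry
    return jnp.logical_and(i < max_steps, t < t_final)

  def body(carry):
    i, t, x, hist = carry
    x = x - dt * x
    hist = hist.at[i].set(x)
    return i + 1, t + dt, x, hist

  init = (
      jnp.asarray(0, dtype=jnp.int32),
      jnp.asarray(0.0, dtype=jnp.float32),
      jnp.asarray(x0, dtype=jnp.float32),
      history,
  )
  final_i, t_end, _, history = jax.lax.while_loop(cond, body, init)
  return history, final_i, t_end


# max_steps fixes an array shape, so it must be a static argument.
run_bounded_jit = jax.jit(run_bounded, static_argnames='max_steps')

history, final_i, t_end = run_bounded_jit(1.0, 1.0, 0.25, max_steps=100)
# Only the first final_i entries of the buffer hold real results.
valid = history[: int(final_i)]

# With too few steps the loop stops before reaching t_final.
_, short_i, short_t = run_bounded_jit(1.0, 1.0, 0.25, max_steps=3)
ran_out_of_steps = bool(short_t < 1.0)
```

In this toy, comparing the final time against the requested end time is how a caller notices that `max_steps` was too small. Checking for the analogous condition in a real run depends on the TORAX outputs and is not shown here.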
Simulation overrides
We also provide functionality for overriding the runtime parameters of a simulation. Importantly, these helpers are themselves JIT compatible, so they can be used as part of a larger JITted function that involves a TORAX simulation.
The mechanism for providing overrides is via a RuntimeParamsProvider object.
We can call update_provider_from_mapping on this object with a mapping from
dot-separated parameter paths to override values, which returns a new provider
with the overridden values.
See examples/iter_hybrid_rampup_grad_and_vmap.ipynb for an example of how
to use this functionality and the docstrings of the methods below for more
details on usage.
# Replace the `TimeVaryingScalar` `Ip`.
ip_update = torax_experimental.TimeVaryingScalarReplace(
    value=new_ip_value,
)
# Replace the `TimeVaryingArray` profile `T_e`.
T_e_update = torax_experimental.TimeVaryingArrayReplace(
    cell_value=T_e_cell_value * 3.0,
    rho_norm=old_T_e.grid.cell_centers,
)
new_provider = step_fn.runtime_params_provider.update_provider_from_mapping(
    {
        'profile_conditions.Ip': ip_update,
        'profile_conditions.T_e': T_e_update,
        'sources.ei_exchange.Qei_multiplier': 2.0,
    }
)
sim_states, post_processed_outputs, final_i = torax_experimental.run_loop_jit(
    step_fn=step_fn,
    max_steps=max_steps,
    runtime_params_overrides=new_provider,
)
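The idea of addressing nested parameters with dot-separated paths can be sketched with plain frozen dataclasses. This is only a conceptual illustration, not the TORAX implementation: the classes `ToySource`, `ToySources`, `ToyParams` and the helpers `replace_at_path` and `update_from_mapping` are hypothetical names. The point it shows is that each override returns a new object and leaves the original untouched.

```python
import dataclasses
from typing import Any, Mapping


@dataclasses.dataclass(frozen=True)
class ToySource:
  multiplier: float


@dataclasses.dataclass(frozen=True)
class ToySources:
  heating: ToySource


@dataclasses.dataclass(frozen=True)
class ToyParams:
  sources: ToySources
  Ip: float


def replace_at_path(obj: Any, path: str, value: Any) -> Any:
  """Returns a copy of obj with the field at the dot-separated path replaced."""
  head, _, rest = path.partition('.')
  if not rest:
    return dataclasses.replace(obj, **{head: value})
  child = replace_at_path(getattr(obj, head), rest, value)
  return dataclasses.replace(obj, **{head: child})


def update_from_mapping(obj: Any, overrides: Mapping[str, Any]) -> Any:
  """Applies every path -> value override in turn and returns the new object."""
  for path, value in overrides.items():
    obj = replace_at_path(obj, path, value)
  return obj


params = ToyParams(sources=ToySources(heating=ToySource(multiplier=1.0)), Ip=1.0)
new_params = update_from_mapping(
    params,
    {
        'sources.heating.multiplier': 2.0,
        'Ip': 2.0,
    },
)
```

Returning a new object instead of mutating in place fits the functional style that JAX transformations expect.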