TORAX: Tokamak transport simulation in JAX
TORAX is a differentiable tokamak core transport simulator aimed at fast and accurate forward modelling, pulse design, trajectory optimization, and controller design workflows. TORAX is written in Python using the JAX library.
Python facilitates integration into various workflows and coupling to additional physics models. TORAX is easy to install, and JAX can seamlessly execute on multiple backends, including CPU and GPU.
JAX provides just-in-time compilation for fast runtimes. JAX's automatic differentiation enables gradient-based nonlinear PDE solvers and simulation sensitivity analysis without the need to manually derive Jacobians.
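As a minimal sketch of how autodiff removes hand-derived Jacobians from a nonlinear PDE solver (this is not TORAX's solver or API), the code below takes one implicit Euler step of a 1D nonlinear diffusion equation, dT/dt = d/dx(chi(T) dT/dx), using Newton's method. The diffusivity chi(T) = 1 + T^2, the grid, the time step, the boundary values, and the names residual, newton_step and implicit_step are all illustrative assumptions. The only physics code written by hand is the residual; jax.jacfwd supplies its Jacobian and jax.jit compiles each Newton update.

```python
import jax
import jax.numpy as jnp

# Use float64 so the Newton iterations can converge to tight tolerances.
jax.config.update("jax_enable_x64", True)

N_INT = 19                  # number of interior grid points (assumed)
DX = 1.0 / (N_INT + 1)      # uniform grid spacing on x in [0, 1]
DT = 1e-3                   # time step (assumed)
T_LEFT, T_RIGHT = 1.0, 0.0  # fixed (Dirichlet) boundary values (assumed)


def chi(T):
    # Hypothetical nonlinear diffusivity; makes the implicit update nonlinear.
    return 1.0 + T**2


def residual(T_new, T_old):
    """Implicit Euler residual for dT/dt = d/dx(chi(T) dT/dx) at interior points."""
    T = jnp.concatenate([jnp.array([T_LEFT]), T_new, jnp.array([T_RIGHT])])
    chi_face = 0.5 * (chi(T[1:]) + chi(T[:-1]))  # diffusivity on cell faces
    flux = -chi_face * (T[1:] - T[:-1]) / DX     # diffusive flux on faces
    div_flux = (flux[1:] - flux[:-1]) / DX       # flux divergence at interior points
    return T_new - T_old + DT * div_flux


@jax.jit
def newton_step(T_new, T_old):
    R = residual(T_new, T_old)
    # dR/dT_new comes from forward-mode autodiff; no hand-derived Jacobian.
    J = jax.jacfwd(residual)(T_new, T_old)
    return T_new - jnp.linalg.solve(J, R)


def implicit_step(T_old, iters=8):
    T = T_old  # initial Newton guess: the previous time step
    for _ in range(iters):
        T = newton_step(T, T_old)
    return T


T0 = jnp.zeros(N_INT)
T1 = implicit_step(T0)
print(jnp.max(jnp.abs(residual(T1, T0))))  # should be near machine precision
```

Because the whole step is built from JAX operations, the same transformations (jax.grad, jax.jacfwd) can in principle be applied through implicit_step with respect to inputs or model parameters, which is the basis for the sensitivity analysis mentioned above.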
Coupling to ML surrogate models for fast and accurate simulation is greatly facilitated by JAX's inherent support for neural network development and inference.
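To illustrate why this coupling is convenient, here is a minimal plain-JAX sketch in which a small neural network stands in for a transport-coefficient model. The MLP architecture, its random (untrained) weights, its inputs (local T and dT/dx), and the names init_mlp, mlp, surrogate_chi and heat_flux are all hypothetical. This is not the interface TORAX uses for its ML surrogates.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    # Hypothetical MLP with random, untrained weights (illustration only).
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        W = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params


def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def surrogate_chi(params, T, dTdx):
    # Hypothetical surrogate: maps local (T, dT/dx) to a diffusivity.
    features = jnp.stack([T, dTdx], axis=-1)              # shape (n_faces, 2)
    return jax.nn.softplus(mlp(params, features))[..., 0]  # softplus keeps chi > 0


def heat_flux(params, T, dx):
    T_face = 0.5 * (T[1:] + T[:-1])  # temperature on cell faces
    dTdx = (T[1:] - T[:-1]) / dx     # gradient on cell faces
    return -surrogate_chi(params, T_face, dTdx) * dTdx


params = init_mlp(jax.random.PRNGKey(0), [2, 16, 16, 1])
T = jnp.linspace(1.0, 0.0, 21)  # assumed profile, decreasing outward
DX = 1.0 / 20

# The surrogate is compiled together with the flux calculation...
flux = jax.jit(heat_flux)(params, T, DX)
# ...and gradients flow through it, e.g. sensitivity of the total flux to T.
dflux_dT = jax.grad(lambda T: jnp.sum(heat_flux(params, T, DX)))(T)
```

Since the network is just another JAX function, no separate inference runtime or finite-difference wrapper is needed: jax.jit fuses it with the surrounding physics, and jax.grad differentiates through it like any other term.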