Learning Methodologies for Autoregressive Neural Emulators.
Convenience abstractions built on optax for training neural networks to
autoregressively emulate time-dependent problems. Trainax takes care of
trajectory subsampling and offers a wide range of training methodologies
(varying the unrolling length and optionally including differentiable physics).
pip install trainax
Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.
The documentation is available at fkoehler.site/trainax.
Train a kernel size 2 linear convolution (no bias) to become an emulator for the 1D advection problem.
```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # pip install optax
import trainax as tx

CFL = -0.75

ref_data = tx.sample_data.advection_1d_periodic(
    cfl=CFL,
    key=jax.random.PRNGKey(0),
)

linear_conv_kernel_2 = eqx.nn.Conv1d(
    1, 1, 2,
    padding="SAME", padding_mode="CIRCULAR", use_bias=False,
    key=jax.random.PRNGKey(73)
)

sup_1_trainer, sup_5_trainer, sup_20_trainer = (
    tx.trainer.SupervisedTrainer(
        ref_data,
        num_rollout_steps=r,
        optimizer=optax.adam(1e-2),
        num_training_steps=1000,
        batch_size=32,
    )
    for r in (1, 5, 20)
)

sup_1_conv, sup_1_loss_history = sup_1_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_5_conv, sup_5_loss_history = sup_5_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_20_conv, sup_20_loss_history = sup_20_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)

FOU_STENCIL = jnp.array([1+CFL, -CFL])

print(jnp.linalg.norm(sup_1_conv.weight - FOU_STENCIL))   # 0.033
print(jnp.linalg.norm(sup_5_conv.weight - FOU_STENCIL))   # 0.025
print(jnp.linalg.norm(sup_20_conv.weight - FOU_STENCIL))  # 0.017
```
Increasing the number of supervised unrolling steps during training brings the learned stencil closer to the numerical first-order upwind (FOU) stencil.
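To make the target of this comparison concrete, here is a minimal sketch of a first-order upwind update in plain JAX, independent of Trainax. It assumes that the two stencil weights `[1 + CFL, -CFL]` act on the pair $(u_i, u_{i+1})$ of a periodic grid, which is the upwind direction for a negative CFL number; how this aligns with the padding of `eqx.nn.Conv1d` is not verified here. The helper name `fou_step` and the sine initial state are illustrative choices.

```python
import jax.numpy as jnp

CFL = -0.75  # same value as in the quickstart

# First-order upwind (FOU) weights for (u_i, u_{i+1}); they sum to one
FOU_STENCIL = jnp.array([1 + CFL, -CFL])


def fou_step(u, stencil=FOU_STENCIL):
    """Advance a periodic 1D state of shape (num_points,) by one time step."""
    # jnp.roll(u, -1)[i] == u[i + 1], with periodic wrap-around (assumed layout)
    return stencil[0] * u + stencil[1] * jnp.roll(u, -1)


grid = jnp.linspace(0.0, 1.0, 32, endpoint=False)
u_0 = jnp.sin(2 * jnp.pi * grid)  # illustrative initial state
u_1 = fou_step(u_0)
```

Because both weights are non-negative and sum to one for CFL numbers between -1 and 0, each new value is a convex combination of two old values, so constant states are preserved and the periodic sum of the state is conserved.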
After the discretization of space and time, the simulation of a time-dependent partial differential equation amounts to the repeated application of a simulation operator $\mathcal{P}_h$. Here, we are interested in imitating/emulating this physical/numerical operator with a neural network $f_\theta$. This repository is concerned with an abstract implementation of all ways we can frame a learning problem to inject "knowledge" from $\mathcal{P}_h$ into $f_\theta$.
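In symbols, the reference rollout and the emulation goal can be sketched as follows. The discrete state $u_h$ and the bracketed time index $[t]$ are notational assumptions introduced here for illustration:

$$
u_h^{[t+1]} = \mathcal{P}_h\left(u_h^{[t]}\right), \qquad f_\theta\left(u_h^{[t]}\right) \approx \mathcal{P}_h\left(u_h^{[t]}\right).
$$

Applying $f_\theta$ repeatedly to its own output then produces an emulated trajectory, and a good emulator keeps $f_\theta^{t}\left(u_h^{[0]}\right)$ close to $u_h^{[t]}$ over many steps $t$.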
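Assume we have a distribution of initial conditions from which we sample initial states. Repeatedly applying $\mathcal{P}_h$ to each of them yields the training trajectories.
For a one-step supervised learning task, we substack each training trajectory
into windows of size 2, so that every window holds an input state and the
reference state one time step later.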
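The substacking step can be sketched in plain JAX as below. The array layout is an assumption made for illustration: `S` sampled trajectories with `T + 1` snapshots each, `C` channels, and `N` spatial points, stored as `(S, T + 1, C, N)`. The helper `substack_windows` is hypothetical and not part of the Trainax API.

```python
import jax.numpy as jnp


def substack_windows(trajectories, window_size=2):
    """Slice (S, T+1, C, N) trajectories into (S*num_windows, window_size, C, N)."""
    num_samples, num_snapshots = trajectories.shape[:2]
    num_windows = num_snapshots - window_size + 1
    # Collect every run of `window_size` consecutive snapshots along the time axis
    windows = jnp.stack(
        [trajectories[:, i : i + window_size] for i in range(num_windows)],
        axis=1,
    )  # shape (S, num_windows, window_size, C, N)
    # Merge the sample and window axes into a single batch axis
    return windows.reshape(
        (num_samples * num_windows, window_size) + trajectories.shape[2:]
    )


# Toy data: S=2 trajectories, T+1=5 snapshots, C=1 channel, N=3 points
toy_trajectories = jnp.arange(2 * 5 * 1 * 3, dtype=jnp.float32).reshape(2, 5, 1, 3)
one_step_windows = substack_windows(toy_trajectories)
inputs, targets = one_step_windows[:, 0], one_step_windows[:, 1]
```

With a window size of 2, each window yields one input-target pair. A larger `window_size` would give the longer windows needed for unrolled training.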
Trainax supports far more than one-step supervised learning, e.g., training
with unrolled steps or including the reference simulator $\mathcal{P}_h$ in training.
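As an illustration of what training with unrolled steps means, the following sketch computes a supervised rollout loss and its gradient in plain JAX. It is not the Trainax implementation: the toy two-weight `emulator`, the averaging of per-step mean squared errors, and the reference values are all assumptions chosen to keep the example small.

```python
import jax
import jax.numpy as jnp


def emulator(params, u):
    """Toy two-point linear stencil acting on a periodic state of shape (N,)."""
    return params[0] * u + params[1] * jnp.roll(u, -1)


def rollout_loss(params, window):
    """Mean squared error averaged over an autoregressive rollout.

    `window` has shape (num_rollout_steps + 1, N): initial state, then targets.
    """

    def step(state, target):
        prediction = emulator(params, state)
        # The prediction (not the target) is fed back in: autoregressive unrolling
        return prediction, jnp.mean((prediction - target) ** 2)

    _, step_losses = jax.lax.scan(step, window[0], window[1:])
    return jnp.mean(step_losses)


# Build a reference window of 5 steps with a known stencil (assumed values)
true_params = jnp.array([0.25, 0.75])
grid = jnp.linspace(0.0, 1.0, 32, endpoint=False)
states = [jnp.sin(2 * jnp.pi * grid)]
for _ in range(5):
    states.append(emulator(true_params, states[-1]))
window = jnp.stack(states)

# Loss and gradient for a deliberately wrong initial guess
guess = jnp.array([0.5, 0.5])
loss, grads = jax.value_and_grad(rollout_loss)(guess, window)
```

Because gradients flow through every step of the scan, errors that only show up after several applications of the emulator also shape the parameter update. One-step training does not see those errors.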
- Wide collection of unrolled training methodologies:
  - Supervised
  - Diverted Chain
  - Mix Chain
  - Residuum
- Based on JAX:
  - One of the best Automatic Differentiation engines (forward & reverse)
  - Automatic vectorization
  - Backend-agnostic code (run on CPU, GPU, and TPU)
- Built on top of and compatible with Equinox
- Batch-Parallel Training
- Collection of Callbacks
- Composability
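The JAX features in the list above can be tried out with a few lines of plain JAX. The sketch below is a generic illustration and does not use Trainax: the function `stencil_apply` and the array shapes are hypothetical. It vectorizes a single-sample function over a batch with `jax.vmap` and computes the same Jacobian in forward and reverse mode.

```python
import jax
import jax.numpy as jnp


def stencil_apply(weights, u):
    """Apply a two-point periodic stencil to a single state of shape (N,)."""
    return weights[0] * u + weights[1] * jnp.roll(u, -1)


weights = jnp.array([0.25, 0.75])
batch = jax.random.normal(jax.random.PRNGKey(0), (4, 16))  # (batch, N)

# Automatic vectorization: map the single-sample function over the batch axis
batched_apply = jax.vmap(stencil_apply, in_axes=(None, 0))
batch_out = batched_apply(weights, batch)

# Forward- and reverse-mode differentiation yield the same Jacobian w.r.t. weights
jac_forward = jax.jacfwd(stencil_apply)(weights, batch[0])
jac_reverse = jax.jacrev(stencil_apply)(weights, batch[0])
```

Both Jacobians have shape `(16, 2)`. Forward mode tends to be cheaper when there are few inputs, reverse mode when there are few outputs, which is why reverse mode is the usual choice for scalar training losses.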
This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180), accepted at NeurIPS 2024. If you find it useful for your research, please consider citing it:
```bibtex
@article{koehler2024apebench,
  title={Apebench: A benchmark for autoregressive neural emulators of pdes},
  author={Koehler, Felix and Niedermayr, Simon and Westermann, R{\"u}diger and Thuerey, Nils},
  journal={Advances in Neural Information Processing Systems},
  volume={37},
  pages={120252--120310},
  year={2024}
}
```
(Feel free to also give the project a star on GitHub if you like it.)
Here you can find the APEBench benchmark suite.
The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.
MIT, see here
fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler