Ceyron/trainax

Learning Methodologies for Autoregressive Neural Emulators.

Convenience abstractions using optax to train neural networks to autoregressively emulate time-dependent problems. They take care of trajectory subsampling and offer a wide range of training methodologies (varying the unrolling length and including differentiable physics).

Installation

pip install trainax

Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.

Documentation

The documentation is available at fkoehler.site/trainax.

Quickstart

Train a linear convolution with kernel size 2 (no bias) to become an emulator for the 1D advection problem.

import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # pip install optax
import trainax as tx

CFL = -0.75

# Reference trajectories of the 1D advection equation on a periodic domain
ref_data = tx.sample_data.advection_1d_periodic(
    cfl=CFL,
    key=jax.random.PRNGKey(0),
)

# The emulator: a single-channel linear convolution, kernel size 2, no bias
linear_conv_kernel_2 = eqx.nn.Conv1d(
    1, 1, 2,
    padding="SAME", padding_mode="CIRCULAR", use_bias=False,
    key=jax.random.PRNGKey(73),
)

# One supervised trainer each for 1, 5, and 20 unrolled training steps
sup_1_trainer, sup_5_trainer, sup_20_trainer = (
    tx.trainer.SupervisedTrainer(
        ref_data,
        num_rollout_steps=r,
        optimizer=optax.adam(1e-2),
        num_training_steps=1000,
        batch_size=32,
    )
    for r in (1, 5, 20)
)

# Each trainer returns the trained network and its loss history
sup_1_conv, sup_1_loss_history = sup_1_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_5_conv, sup_5_loss_history = sup_5_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_20_conv, sup_20_loss_history = sup_20_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)

# First-order upwind (FOU) stencil for this CFL number
FOU_STENCIL = jnp.array([1 + CFL, -CFL])
print(jnp.linalg.norm(sup_1_conv.weight - FOU_STENCIL))  # 0.033
print(jnp.linalg.norm(sup_5_conv.weight - FOU_STENCIL))  # 0.025
print(jnp.linalg.norm(sup_20_conv.weight - FOU_STENCIL))  # 0.017

Increasing the number of supervised unrolling steps during training brings the learned stencil closer to the numerical first-order upwind (FOU) stencil.
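For intuition on the reference stencil `[1 + CFL, -CFL]`, here is a minimal NumPy sketch (not part of trainax) that applies one FOU step to a periodic signal. The function name `fou_step` is hypothetical, and the sketch assumes the second tap multiplies the right neighbor $u_{i+1}$, which is the upwind side when the CFL number is negative.

```python
import numpy as np

def fou_step(u: np.ndarray, cfl: float) -> np.ndarray:
    """One first-order upwind (FOU) step for 1D periodic advection with cfl < 0.

    Applies the two-tap stencil [1 + cfl, -cfl] as a cross-correlation,
    assuming the second tap multiplies the right neighbor u_{i+1}.
    """
    stencil = np.array([1.0 + cfl, -cfl])
    right_neighbor = np.roll(u, -1, axis=-1)  # u_{i+1} with periodic wrap-around
    return stencil[0] * u + stencil[1] * right_neighbor

CFL = -0.75
x = np.linspace(0.0, 1.0, 64, endpoint=False)
u0 = 1.0 + np.sin(2 * np.pi * x)
u1 = fou_step(u0, CFL)

# The weights sum to one, so the discrete mass is conserved ...
print(np.isclose(u1.sum(), u0.sum()))  # True
# ... and for -1 <= cfl <= 0 both weights are nonnegative, so no new maxima appear.
print(u1.max() <= u0.max() + 1e-12)  # True
```

Since each FOU step is a convex combination of neighboring values, the scheme is stable for $-1 \le \text{CFL} \le 0$ but numerically diffusive.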

Background

After the discretization of space and time, the simulation of a time-dependent partial differential equation amounts to the repeated application of a simulation operator $\mathcal{P}_h$. Here, we are interested in imitating (emulating) this physical/numerical operator with a neural network $f_\theta$. This repository provides an abstract implementation of the many ways a learning problem can be framed to inject "knowledge" from $\mathcal{P}_h$ into $f_\theta$.

Assume we have a distribution of initial conditions $\mathcal{Q}$ from which we sample $S$ initial states, $u^{[0]} \sim \mathcal{Q}$. We can save them in an array of shape $(S, C, *N)$ (with $C$ channels and an arbitrary number of spatial axes, collectively denoted $*N$) and apply $\mathcal{P}_h$ repeatedly for $T$ steps to obtain the training trajectories of shape $(S, T+1, C, *N)$.
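A minimal NumPy sketch of producing such training trajectories. The names `rollout` and `advection_stepper` are hypothetical, white noise stands in for $\mathcal{Q}$, and an FOU advection step applied to the whole batch at once stands in for $\mathcal{P}_h$; trainax itself works with JAX arrays.

```python
import numpy as np

def rollout(stepper, u0: np.ndarray, num_steps: int) -> np.ndarray:
    """Apply `stepper` repeatedly and stack all states along a new time axis.

    u0 has shape (S, C, *N); the result has shape (S, T+1, C, *N) with T = num_steps.
    """
    states = [u0]
    for _ in range(num_steps):
        states.append(stepper(states[-1]))
    return np.stack(states, axis=1)

def advection_stepper(u: np.ndarray, cfl: float = -0.75) -> np.ndarray:
    # Hypothetical stand-in for P_h: FOU stencil along the last (spatial) axis.
    return (1.0 + cfl) * u - cfl * np.roll(u, -1, axis=-1)

rng = np.random.default_rng(0)
S, C, N, T = 8, 1, 32, 50
u0 = rng.standard_normal((S, C, N))  # S initial states u^[0] ~ Q (here: white noise)
trj = rollout(advection_stepper, u0, T)
print(trj.shape)  # (8, 51, 1, 32)
```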

For a one-step supervised learning task, we substack the training trajectories into windows of size $2$ and merge the two leftover batch axes to get a data array of shape $(S \cdot T, 2, C, *N)$ that can be used in a supervised learning scenario

$$ L(\theta) = \mathbb{E}_{(u^{[0]}, u^{[1]}) \sim \mathcal{Q}} \left[ l\left( f_\theta(u^{[0]}), u^{[1]} \right) \right] $$

where $l$ is a time-level loss. In the simplest case, $l$ is the mean squared error (MSE).
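The windowing and the loss above can be sketched in a few lines of NumPy. This is an illustration, not trainax's API: `substack_trajectories`, `one_step_supervised_loss`, and the FOU stand-in for $\mathcal{P}_h$ are hypothetical, the expectation is replaced by a sample mean over all windows, and no gradient is taken (in trainax, training is done with JAX and optax). A window of size $n+1$ would presumably provide the data for $n$-step unrolled training.

```python
import numpy as np

def substack_trajectories(trj: np.ndarray, window: int) -> np.ndarray:
    """Cut trajectories of shape (S, T+1, C, *N) into sliding windows of `window`
    consecutive time levels (stride 1) and merge the sample and window axes.

    Returns shape (S * (T + 2 - window), window, C, *N).
    """
    S, T_plus_1 = trj.shape[:2]
    num_windows = T_plus_1 - window + 1
    windows = np.stack(
        [trj[:, i : i + window] for i in range(num_windows)], axis=1
    )  # (S, num_windows, window, C, *N)
    return windows.reshape(S * num_windows, window, *trj.shape[2:])

def mse(pred: np.ndarray, ref: np.ndarray) -> float:
    # Time-level loss l: mean squared error over all batch, channel, and spatial entries.
    return float(np.mean((pred - ref) ** 2))

def one_step_supervised_loss(emulator, batch: np.ndarray) -> float:
    """Sample-mean estimate of L(theta) on pairs of shape (B, 2, C, *N)."""
    u_prev, u_next = batch[:, 0], batch[:, 1]
    return mse(emulator(u_prev), u_next)

# Reference trajectories from an FOU advection step (stand-in for P_h)
CFL = -0.75

def fou(u):
    return (1.0 + CFL) * u - CFL * np.roll(u, -1, axis=-1)

rng = np.random.default_rng(0)
S, T, C, N = 3, 10, 1, 16
states = [rng.standard_normal((S, C, N))]
for _ in range(T):
    states.append(fou(states[-1]))
trj = np.stack(states, axis=1)  # (S, T+1, C, N)

data = substack_trajectories(trj, window=2)
print(data.shape)  # (30, 2, 1, 16)
print(one_step_supervised_loss(fou, data))  # 0.0 (emulator equals the reference)
print(one_step_supervised_loss(lambda u: u, data) > 0)  # True (identity emulator)
```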

Trainax supports much more than one-step supervised learning: training with unrolled steps, including the reference simulator $\mathcal{P}_h$ in training, training on residuum conditions instead of resolved reference states, and cutting or modifying the gradient flow.
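To make the unrolled variants more concrete, here is a minimal NumPy sketch of two loss formulations: an $n$-step unrolled supervised loss, and a "diverted chain" loss in which $\mathcal{P}_h$ supplies one-step targets branching off the emulator's own rollout. This is an illustrative reading of these terms, not trainax's implementation; the function names, the averaging over steps, and the absence of gradient cutting are assumptions, so consult the documentation for the exact definitions.

```python
import numpy as np

def mse(pred: np.ndarray, ref: np.ndarray) -> float:
    return float(np.mean((pred - ref) ** 2))

def unrolled_supervised_loss(emulator, window: np.ndarray) -> float:
    """Roll the emulator out from window[:, 0] for n = window.shape[1] - 1 steps
    and compare each predicted state to the reference state at that step."""
    n = window.shape[1] - 1
    pred = window[:, 0]
    loss = 0.0
    for k in range(1, n + 1):
        pred = emulator(pred)  # in a differentiable setting, gradients flow through the chain
        loss += mse(pred, window[:, k])
    return loss / n  # averaged over steps (an illustrative choice)

def diverted_chain_loss(emulator, reference_stepper, u0: np.ndarray, n: int) -> float:
    """Roll the emulator out for n steps; at each step the target is the
    reference simulator applied to the emulator's own previous state."""
    pred = u0
    loss = 0.0
    for _ in range(n):
        target = reference_stepper(pred)  # P_h branches off the emulator trajectory
        pred = emulator(pred)
        loss += mse(pred, target)
    return loss / n

CFL = -0.75

def fou(u):
    # Stand-in for P_h: FOU advection step on the last axis
    return (1.0 + CFL) * u - CFL * np.roll(u, -1, axis=-1)

def identity(u):
    return u

rng = np.random.default_rng(0)
u0 = rng.standard_normal((4, 1, 32))
window = np.stack([u0, fou(u0), fou(fou(u0)), fou(fou(fou(u0)))], axis=1)  # n = 3

print(unrolled_supervised_loss(fou, window))  # 0.0
print(diverted_chain_loss(fou, fou, u0, n=3))  # 0.0
print(unrolled_supervised_loss(identity, window) > 0)  # True
```

The supervised variant needs precomputed reference windows, whereas the diverted chain calls $\mathcal{P}_h$ during training, which is where a differentiable simulator becomes relevant.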

Features

  • Wide collection of unrolled training methodologies:
    • Supervised
    • Diverted Chain
    • Mix Chain
    • Residuum
  • Based on JAX:
    • Automatic differentiation in forward and reverse mode
    • Automatic vectorization
    • Backend-agnostic code (run on CPU, GPU, and TPU)
  • Built on top of, and compatible with, Equinox
  • Batch-Parallel Training
  • Collection of Callbacks
  • Composability

Citation

This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180), accepted at NeurIPS 2024. If you find it useful for your research, please consider citing it:

@article{koehler2024apebench,
 title={Apebench: A benchmark for autoregressive neural emulators of pdes},
 author={Koehler, Felix and Niedermayr, Simon and Westermann, R{\"u}diger and Thuerey, Nils},
 journal={Advances in Neural Information Processing Systems},
 volume={37},
 pages={120252--120310},
 year={2024}
}

(Feel free to also give the project a star on GitHub if you like it.)

Here you can find the APEBench benchmark suite.

Funding

The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.

License

MIT, see here


fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler
