Simulation-based inference in JAX
Sbijax is a Python library for neural simulation-based inference and
approximate Bayesian computation using JAX.
It implements recent methods, such as Simulated-annealing ABC,
Surjective Neural Likelihood Estimation, Neural Approximate Sufficient Statistics
and Consistency model posterior estimation, as well as methods for computing model
diagnostics and visualizing posterior distributions.
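At its simplest, approximate Bayesian computation draws parameters from the prior, simulates data for each draw, and keeps the draws whose simulations land close to the observation. The sketch below does exactly that in plain JAX. It is an illustration under stated assumptions, not part of sbijax and not one of the methods listed above: sample_prior, simulate and rejection_abc are hypothetical names, the distance is Euclidean, and the tolerance is set implicitly by keeping the n_keep closest simulations. The toy model (a standard normal prior on a two-dimensional theta and additive Gaussian noise with scale 0.1) mirrors the neural likelihood estimation example below.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def sample_prior(key, n):
    # hypothetical helper: theta ~ N(0, I_2), like the toy prior in the example
    return jr.normal(key, (n, 2))


def simulate(key, theta):
    # hypothetical helper: y = theta + 0.1 * eps with eps ~ N(0, I_2)
    return theta + 0.1 * jr.normal(key, theta.shape)


def rejection_abc(key, y_obs, n_sims=100_000, n_keep=500):
    """Plain rejection ABC (hypothetical helper): keep the parameter draws
    whose simulated data lie closest to the observation."""
    prior_key, sim_key = jr.split(key)
    theta = sample_prior(prior_key, n_sims)
    # one independent simulator key per parameter draw, vectorised with vmap
    sim_keys = jr.split(sim_key, n_sims)
    y = jax.vmap(simulate)(sim_keys, theta)
    dist = jnp.linalg.norm(y - y_obs, axis=-1)
    # the n_keep smallest distances implicitly define the tolerance epsilon
    idx = jnp.argsort(dist)[:n_keep]
    return theta[idx]


y_observed = jnp.array([-1.0, 1.0])
samples = rejection_abc(jr.PRNGKey(0), y_observed)
print(samples.mean(axis=0))
```

For this conjugate toy model the exact posterior mean is y_observed / 1.01, so the printed mean should be close to [-0.99, 0.99]. Shrinking n_keep relative to n_sims tightens the tolerance at the cost of fewer accepted samples.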
Sbijax implements a slim object-oriented API with functional elements stemming from
JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm.
For example, you can define a neural likelihood estimation method and generate posterior samples like this:
from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd


# prior model: theta is a two-dimensional standard normal
def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior


# simulator: adds Gaussian noise with scale 0.1 to theta
def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


# inferential algorithm: neural likelihood estimation with a masked autoregressive flow
fns = prior_fn, simulator_fn
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
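Because the toy model in this example is conjugate (a N(0, 1) prior on each coordinate of theta and Gaussian noise with scale 0.1), its posterior is known in closed form: per coordinate, the posterior precision is 1 + 1/0.1² = 101, so the posterior mean is y / 1.01 and the standard deviation is 1/√101. The sketch below computes these reference values in plain JAX; gaussian_posterior is a hypothetical helper, not part of sbijax.

```python
import jax.numpy as jnp


def gaussian_posterior(y, prior_scale=1.0, noise_scale=0.1):
    """Hypothetical helper: exact posterior of theta for the conjugate model
    theta ~ N(0, prior_scale^2), y | theta ~ N(theta, noise_scale^2),
    applied independently to each coordinate."""
    prior_prec = 1.0 / prior_scale**2
    lik_prec = 1.0 / noise_scale**2
    post_var = 1.0 / (prior_prec + lik_prec)
    post_mean = post_var * lik_prec * y
    return post_mean, jnp.sqrt(post_var)


y_observed = jnp.array([-1.0, 1.0])
mean, std = gaussian_posterior(y_observed)
print(mean, std)  # mean ≈ [-0.990, 0.990], std ≈ 0.0995
```

If the fitted NLE model is working, its posterior samples for theta should have roughly this mean and spread, which makes the toy problem a convenient sanity check.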
More self-contained examples can be found in the examples directory of the repository.
Documentation can be found here.
Make sure to have a working JAX installation. Depending on whether you want to use a CPU, GPU or TPU,
please follow the official JAX installation instructions.
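Once JAX is installed, a quick way to check which backend it picked up is the snippet below. It uses only standard JAX calls (jax.default_backend, jax.devices, jax.jit) and does not depend on sbijax; sum_of_squares is just an illustrative name.

```python
import jax
import jax.numpy as jnp

# Which backend JAX selected (e.g. "cpu" or "gpu") and which devices it can see
print(jax.default_backend())
print(jax.devices())


# A tiny jitted computation to confirm that compilation works end to end
@jax.jit
def sum_of_squares(x):
    return jnp.sum(x**2)


print(sum_of_squares(jnp.arange(4.0)))  # 0 + 1 + 4 + 9 = 14.0
```

If the backend reports "cpu" although you installed a GPU or TPU build, the accelerator-specific JAX installation is usually the place to look.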
To install from PyPI, just call the following on the command line:
pip install sbijax
To install the latest release from GitHub, use:
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".
In order to contribute:
- clone sbijax and install hatch via pip install hatch,
- create a new branch locally with git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
- implement your contribution and ideally a test case,
- test it by calling make tests, make lints and make format on the (Unix) command line,
- submit a PR.
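As an illustration of what "ideally a test case" can look like, here is a small pytest-style test for a simulator function. It is a sketch under assumptions: the test name is hypothetical, the simulator is a plain-JAX variant of the one in the example (it draws noise with jax.random.normal instead of tfd.Normal), and the repository's actual test layout and fixtures may differ.

```python
import jax.numpy as jnp
import jax.random as jr


def simulator_fn(seed, theta):
    # plain-JAX variant of the example simulator: y = theta + 0.1 * eps, eps ~ N(0, I)
    return theta["theta"] + 0.1 * jr.normal(seed, theta["theta"].shape)


def test_simulator_shape_and_determinism():
    # hypothetical test name; pytest collects functions prefixed with "test_"
    theta = {"theta": jnp.array([-1.0, 1.0])}
    key = jr.PRNGKey(0)
    y1 = simulator_fn(key, theta)
    y2 = simulator_fn(key, theta)
    # output has the same shape as the parameter it was simulated from
    assert y1.shape == theta["theta"].shape
    # the same PRNG key must reproduce the same draw
    assert jnp.array_equal(y1, y2)
    # a different key should give a different draw
    assert not jnp.array_equal(y1, simulator_fn(jr.PRNGKey(1), theta))
```

Checking shapes and key-determinism is cheap and catches a common class of bugs in JAX code, where forgetting to split or thread PRNG keys silently reuses randomness.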
If you find our work relevant to your research, please consider citing:
@article{dirmeier2024simulation,
title={Simulation-based inference with the Python Package sbijax},
author={Dirmeier, Simon and Ulzega, Simone and Mira, Antonietta and Albert, Carlo},
journal={arXiv preprint arXiv:2409.19435},
year={2024}
}
Note
The API of the package is heavily inspired by the excellent PyTorch-based sbi package.
Simon Dirmeier sfyrbnd @ pm me