sbijax


Simulation-based inference in JAX

About

Sbijax is a Python library for neural simulation-based inference and approximate Bayesian computation using JAX. It implements recent methods such as simulated-annealing ABC, surjective neural likelihood estimation, neural approximate sufficient statistics, and consistency model posterior estimation, as well as methods for computing model diagnostics and visualizing posterior distributions.

Caution

⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

Examples

Sbijax implements a slim object-oriented API with functional elements stemming from JAX. All a user needs to define are a prior model, a simulator function, and an inference algorithm. For example, you can define a neural likelihood estimation method and generate posterior samples like this:

from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd

def prior_fn():
    # prior over the two-dimensional parameter vector theta
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior

def simulator_fn(seed, theta):
    # simulator: the observation is theta plus Gaussian noise
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y

fns = prior_fn, simulator_fn
# neural likelihood estimation with a masked autoregressive flow
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
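
Before wiring the model into sbijax, it can help to sanity-check the prior and simulator on their own. The following is a small standalone sketch that uses only JAX and TensorFlow Probability, not sbijax itself, and assumes the definitions from the snippet above are in scope:

# Standalone check of the prior/simulator pair defined above; this does
# not touch sbijax, only JAX and TensorFlow Probability.
key = jr.PRNGKey(0)
prior_key, sim_key = jr.split(key)

theta = prior_fn().sample(seed=prior_key)  # dict with key "theta", shape (2,)
y = simulator_fn(sim_key, theta)           # simulated observation, shape (2,)
print(theta["theta"], y)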

More self-contained examples can be found in the examples folder.

Documentation

Documentation can be found here.

Installation

Make sure to have a working JAX installation. Depending on whether you want to use a CPU, GPU, or TPU backend, please follow these instructions.

To install from PyPI, just call the following on the command line:

pip install sbijax

To install the latest GitHub release, use:

pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
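
As an optional post-install check, the following sketch (not part of the sbijax API) confirms that JAX sees a compute backend and that the package resolved correctly:

# Optional sanity check after installation: lists the available JAX
# devices and prints the installed sbijax version. Uses only JAX and
# the Python standard library.
import jax
from importlib.metadata import version

print(jax.devices())      # e.g. [CpuDevice(id=0)] on a CPU-only install
print(version("sbijax"))  # installed sbijax version string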

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

  1. Clone sbijax and install hatch via pip install hatch,
  2. create a new branch locally via git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
  3. implement your contribution and ideally a test case,
  4. test it by calling make tests, make lints and make format on the (Unix) command line,
  5. submit a PR 🙂

Citing sbijax

If you find our work relevant to your research, please consider citing:

@article{dirmeier2024simulation,
  title={Simulation-based inference with the Python Package sbijax},
  author={Dirmeier, Simon and Ulzega, Simone and Mira, Antonietta and Albert, Carlo},
  journal={arXiv preprint arXiv:2409.19435},
  year={2024}
}

Acknowledgements

Note

πŸ“ The API of the package is heavily inspired by the excellent Pytorch-based sbi package.

Author

Simon Dirmeier sfyrbnd @ pm me
