Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Support add_loss (works currently for torch and tf, does NOT for jax) #542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
han-ol wants to merge 5 commits into dev
base: dev
Choose a base branch
Loading
from add-loss

Conversation

@han-ol
Copy link
Collaborator

@han-ol han-ol commented Jul 22, 2025
edited
Loading

This PR seeks to address #541.

It looks to me like we need to tweak JAXApproximator.stateless_compute_metrics for this to work in jax as well.

The other backends are already covered with just the changes in the initial commit.

EDIT: You can find an example in https://github.com/bayesflow-org/bayesflow/blob/add-loss/examples/Custom_losses_with_add_loss.ipynb

Copy link

codecov bot commented Jul 22, 2025
edited
Loading

Codecov Report

Attention: Patch coverage is 75.00000% with 3 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...low/approximators/model_comparison_approximator.py 50.00% 3 Missing ⚠️
Files with missing lines Coverage Δ
bayesflow/approximators/continuous_approximator.py 91.45% <100.00%> (+0.22%) ⬆️
...low/approximators/model_comparison_approximator.py 83.90% <50.00%> (-1.30%) ⬇️

@han-ol han-ol requested review from LarsKue and vpratz and removed request for LarsKue July 22, 2025 12:17
Copy link
Collaborator Author

han-ol commented Jul 22, 2025

I added tests and a minimal example notebook.

Tests are passing on torch and tensorflow, but fail on jax.

@LarsKue since you are the architect of the stateless_compute_metrics, could you look into how we can make this work for jax?

The final section of the keras guide on custom training loops in jax proves that this can be rather straight forward, but I am unsure how to implement it in our case: https://keras.io/guides/writing_a_custom_training_loop_in_jax/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Reviewers

@vpratz vpratz Awaiting requested review from vpratz

@LarsKue LarsKue Awaiting requested review from LarsKue

Assignees

No one assigned

Labels

None yet

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

2 participants

AltStyle によって変換されたページ (->オリジナル) /