- 
  Notifications
 
You must be signed in to change notification settings  - Fork 78
 
Support add_loss (works currently for torch and tf, does NOT for jax) #542
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
 Codecov ReportAttention: Patch coverage is  
 
  | 
 
I added tests and a minimal example notebook.
Tests are passing on torch and tensorflow, but fail on jax.
@LarsKue since you are the architect of the stateless_compute_metrics, could you look into how we can make this work for jax?
The final section of the keras guide on custom training loops in jax proves that this can be rather straight forward, but I am unsure how to implement it in our case: https://keras.io/guides/writing_a_custom_training_loop_in_jax/
Uh oh!
There was an error while loading. Please reload this page.
This PR seeks to address #541.
It looks to me like we need to tweak
JAXApproximator.stateless_compute_metricsfor this to work in jax as well.The other backends are already covered with just the changes in the initial commit.
EDIT: You can find an example in https://github.com/bayesflow-org/bayesflow/blob/add-loss/examples/Custom_losses_with_add_loss.ipynb