Modified for Viewer
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
saikishor
commented
Oct 2, 2025
@btaba
btaba
left a comment
Thanks for the contribution! Added a bunch of comments
brax/training/acting.py
Outdated
can this whole block just be
if render_fn:
    io_callback(render_fn, None, state)
This may get rid of the fixed overhead in your main post; JAX should be ignoring this whole block when render_fn is not set.
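A minimal sketch of the suggestion above, assuming a scalar stands in for the brax state and that make_step is a hypothetical helper (neither is from the PR). Because the `if render_fn:` check runs in Python at trace time, no callback is staged into the compiled program when render_fn is None.

```python
import jax
from jax.experimental import io_callback

calls = []  # filled by the host-side callback


def make_step(render_fn=None):
    """Hypothetical helper: builds a step function with an optional render hook."""

    def step(x):
        # Plain Python check, evaluated while tracing. With render_fn=None the
        # io_callback is never staged, so the compiled program has no callback.
        if render_fn:
            io_callback(render_fn, None, x)
        return x * 2.0

    return step


plain_step = jax.jit(make_step())
render_step = jax.jit(make_step(lambda x: calls.append(float(x))))

# Inspect the traced programs: only the second one should stage a callback.
plain_jaxpr = str(jax.make_jaxpr(make_step())(1.0))
render_jaxpr = str(jax.make_jaxpr(make_step(lambda x: None))(1.0))

plain_out = plain_step(1.0)
render_out = render_step(3.0)
jax.effects_barrier()  # wait for pending host callbacks to finish
```

The trade-off is that the choice is baked in at trace time, so switching rendering on or off later means building and compiling a new step function.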
Thank you for the review! should_render is a JAX boolean Array, rather than a check for whether render_fn is None, because users need to toggle rendering on and off during training without re-JIT. This enables real-time visualization that can be disabled mid-training to restore full training speed.
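A rough sketch of this approach, not the PR's actual code: should_render is passed as a traced boolean array and jax.lax.cond selects between the io_callback and a no-op. The scalar state, the rendered list, and trace_count are illustrative assumptions; trace_count only shows that toggling the flag does not retrace.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

rendered = []    # filled by the host-side callback
trace_count = 0  # incremented each time `step` is traced


def render_fn(x):
    # Hypothetical stand-in for a real viewer update.
    rendered.append(float(x))


@jax.jit
def step(x, should_render):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    # should_render is a traced value, so both branches are compiled once
    # and the choice between them is made at run time.
    jax.lax.cond(
        should_render,
        lambda v: io_callback(render_fn, None, v),
        lambda v: None,
        x,
    )
    return x + 1.0


x = jnp.float32(0.0)
x = step(x, jnp.array(True))   # renders 0.0
x = step(x, jnp.array(False))  # rendering off, same compiled program
x = step(x, jnp.array(True))   # renders 2.0
jax.effects_barrier()          # wait for pending host callbacks to finish
```

Since the conditional is part of the compiled program either way, this is also where a small fixed cost can remain when the flag is False.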
brax/training/acting.py
Outdated
you're not passing render_fn anyways, not sure you really need should_render
The render_fn argument to train() has been removed.
brax/training/agents/ppo/train.py
Outdated
dtype=bool
removed line #241
brax/training/agents/ppo/train.py
Outdated
not sure what ViewerWrapper is, maybe update the comment?
Changed to: # If the environment exposes a render_fn, use it for real-time rendering during training.
Since render_fn and should_render were removed from the train() args, render_fn is now an attribute of the environment (environment.render_fn), which can be provided by an external wrapper.
again, you maybe can get away without the bool should_render
so you're ignoring the arg to train(), just delete all the args
Deleted the args.
Changed jax.bool_ to bool; the render io_callback now checks the environment attribute "should_render".
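One way to read the revised design, as a sketch under stated assumptions: ViewerEnv and make_step are hypothetical names, a scalar stands in for the brax state, and the real wrapper may differ. The training code picks up render_fn from the environment if it exists, and the host-side callback consults a plain Python bool, so toggling needs no retracing.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback


class ViewerEnv:
    """Hypothetical external wrapper exposing render_fn and should_render."""

    def __init__(self):
        self.should_render = False  # plain Python bool, toggled by the user
        self.frames = []

    def render_fn(self, x):
        # Runs on the host; the flag is read here, outside the compiled code.
        if self.should_render:
            self.frames.append(float(x))


def make_step(env):
    # Look the hook up once; environments without render_fn get no callback.
    render_fn = getattr(env, "render_fn", None)

    @jax.jit
    def step(x):
        if render_fn is not None:
            io_callback(render_fn, None, x)
        return x + 1.0

    return step


env = ViewerEnv()
step = make_step(env)

x = jnp.float32(0.0)
x = step(x)            # callback fires, but should_render is False
jax.effects_barrier()  # make sure the callback ran before toggling
env.should_render = True
x = step(x)            # callback fires and records 1.0
jax.effects_barrier()
```

In this variant the callback still crosses from device to host on every step, even when should_render is False; only the rendering work itself is skipped.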
This PR adds an interface for real-time rendering in PPO.train and acting by adding a user-defined callback executed via jax.experimental.io_callback. The final goal is to provide a real-time viewer alongside the notebook HTML viewer.
Key Changes
Two new optional parameters are added:
render_fn: A Python callable that accepts a brax.State to handle the rendering logic.
should_render: A boolean JAX Array used to conditionally trigger the callback.
Performance Impact
Adding the new callback introduces a minor performance overhead. This overhead exists even when rendering is disabled (i.e., should_render is False). The JIT compiler must account for the conditional logic required for the io_callback, which slightly alters the compiled execution path.
The benchmarks below were run on an Apple M1 Max and an NVIDIA 2080.
[Figure: ppo_comparison_consolidated]