orbit.utils.OptionalSummariesFunction


Wrapper that provides versions of a function with and without summaries.

orbit.utils.OptionalSummariesFunction(
    function, **tf_function_kwargs
)

This is a utility class for implementing optimized summary recording via a two-function approach, which is particularly important on TPUs. Two tf.function versions of a given function are created: one with soft device placement enabled (for use on steps that require summary writing), and one with summary writing and soft device placement entirely disabled (for use on all other steps). This removes any performance impact of summaries on steps where they aren't recorded (b/148418718).
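A minimal sketch of the two-function idea, written only against public TensorFlow APIs (tf.function, tf.summary.record_if, tf.summary.should_record_summaries, tf.config.set_soft_device_placement). The helper names `make_summary_variants` and `soft_device_placement` are hypothetical, and this is an illustration under those assumptions, not Orbit's actual implementation:

```python
import contextlib
import functools

import tensorflow as tf


@contextlib.contextmanager
def soft_device_placement(enabled):
  """Hypothetical helper: sets soft device placement, then restores it."""
  original = tf.config.get_soft_device_placement()
  try:
    tf.config.set_soft_device_placement(enabled)
    yield
  finally:
    tf.config.set_soft_device_placement(original)


def make_summary_variants(function, **tf_function_kwargs):
  """Hypothetical helper returning (with_summaries, without_summaries)."""
  traced_with = tf.function(function, **tf_function_kwargs)

  def function_without_summaries(*args, **kwargs):
    # Recording is off while this version is traced, so summary writes
    # are skipped when its graph is built.
    with tf.summary.record_if(False):
      return function(*args, **kwargs)

  traced_without = tf.function(function_without_summaries, **tf_function_kwargs)

  @functools.wraps(function)
  def with_summaries(*args, **kwargs):
    # Follows whatever tf.summary.record_if predicate is active; soft
    # placement lets ops that cannot run on the TPU be placed elsewhere.
    with soft_device_placement(True):
      return traced_with(*args, **kwargs)

  @functools.wraps(function)
  def without_summaries(*args, **kwargs):
    # No summaries and no soft placement on the fast path.
    with soft_device_placement(False):
      return traced_without(*args, **kwargs)

  return with_summaries, without_summaries
```

With a default summary writer active, a function that returns `tf.summary.should_record_summaries()` yields `True` through the first variant and `False` through the second, because each variant is traced into its own graph.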

OptionalSummariesFunction can be used as a base class to implement summary optimizations for a function with a specific signature. For example, to implement efficient TPU summaries for a standard train() method (as in orbit.AbstractTrainer):

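import orbit
import tensorflow as tf


class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction):
  """Implements a two-program approach for summaries on TPU."""

  def __call__(self, num_steps):
    output = None  # Avoids an unbound name if num_steps is 0 and nothing records.
    if tf.summary.should_record_summaries():
      # Run one step with summaries enabled...
      output = self.with_summaries(tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      # ...then run the remaining steps with summaries disabled.
      output = self.without_summaries(num_steps)
    return output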
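The control flow above spends exactly one step in the summary-enabled program when summaries are due and hands the remaining steps to the summary-free program. A pure-Python trace of that split, using a hypothetical `StubSummariesFunction` whose two versions only log their step counts (no TensorFlow or Orbit required):

```python
class StubSummariesFunction:
  """Hypothetical stand-in whose two versions only log their step counts."""

  def __init__(self, record):
    self.record = record  # Plays the role of tf.summary.should_record_summaries().
    self.calls = []

  def with_summaries(self, num_steps):
    self.calls.append(("with_summaries", num_steps))
    return "with_summaries"

  def without_summaries(self, num_steps):
    self.calls.append(("without_summaries", num_steps))
    return "without_summaries"

  def __call__(self, num_steps):
    # Same control flow as TrainFunctionWithSummaries.__call__.
    output = None
    if self.record:
      output = self.with_summaries(1)
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(num_steps)
    return output


recording = StubSummariesFunction(record=True)
recording(5)
print(recording.calls)  # [('with_summaries', 1), ('without_summaries', 4)]
```

When summaries are not due, all five steps go to `without_summaries` in a single call.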

TrainFunctionWithSummaries can be used directly or to implement a decorator:

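import functools

def train_function_with_summaries(function=None, **kwargs):
  if function is not None:
    # Used as a bare decorator: wrap the function immediately.
    return TrainFunctionWithSummaries(function, **kwargs)
  # Used with keyword arguments: return a decorator that waits for the function.
  return functools.partial(TrainFunctionWithSummaries, **kwargs)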

The decorator can be applied directly to train() methods:

@train_function_with_summaries
def train(self, num_steps):
  ...
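The `function=None` plus `functools.partial` shape lets the decorator be applied bare or with keyword arguments that are forwarded to the wrapper (and, for TrainFunctionWithSummaries, on to tf.function). A self-contained sketch of both forms, using a hypothetical `RecordingWrapper` in place of TrainFunctionWithSummaries so it runs without TensorFlow; `jit_compile` is only stored here, not acted on:

```python
import functools


class RecordingWrapper:
  """Hypothetical stand-in for TrainFunctionWithSummaries."""

  def __init__(self, function, **tf_function_kwargs):
    functools.update_wrapper(self, function)
    self.function = function
    self.tf_function_kwargs = tf_function_kwargs

  def __call__(self, *args, **kwargs):
    return self.function(*args, **kwargs)


def wrap_with_summaries(function=None, **kwargs):
  """Supports both @wrap_with_summaries and @wrap_with_summaries(**kwargs)."""
  if function is not None:
    # Bare form: Python passes the decorated function directly.
    return RecordingWrapper(function, **kwargs)
  # Keyword form: return a callable that will receive the function next.
  return functools.partial(RecordingWrapper, **kwargs)


@wrap_with_summaries
def add_one(x):
  return x + 1


@wrap_with_summaries(jit_compile=True)
def double(x):
  return 2 * x
```

In both cases the decorated name ends up bound to a wrapper instance; only the keyword form carries extra arguments.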

A similar approach can be implemented for functions with different signatures.

This wrapper properly handles instance methods (see __get__).
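Because the wrapper is a class instance rather than a plain function, Python does not bind it to `self` automatically; a `__get__` method (the descriptor protocol) is what makes `obj.train` work. A minimal pure-Python sketch of that idea, assuming the wrapper binds both of its stored versions to the instance; `TwoVersionMethod` and `Counter` are hypothetical names, and Orbit's real `__get__` may differ in detail:

```python
import copy
import functools


class TwoVersionMethod:
  """Hypothetical wrapper holding two versions of one function."""

  def __init__(self, function):
    functools.update_wrapper(self, function)
    # Stand-ins for the with/without-summaries tf.function versions.
    self.with_summaries = function
    self.without_summaries = function

  def __get__(self, instance, owner):
    if instance is None:
      return self  # Accessed on the class itself: return the wrapper as-is.
    bound = copy.copy(self)
    # Bind both versions so `self` is supplied automatically on each call.
    bound.with_summaries = self.with_summaries.__get__(instance, owner)
    bound.without_summaries = self.without_summaries.__get__(instance, owner)
    return bound


class Counter:

  def __init__(self):
    self.steps = 0

  @TwoVersionMethod
  def train(self, num_steps):
    self.steps += num_steps
    return self.steps


counter = Counter()
counter.train.with_summaries(1)
counter.train.without_summaries(4)
print(counter.steps)  # 5
```

Copying the wrapper per access keeps the class-level object unbound, so different instances never share bound versions.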

Args

function: The underlying function to wrap.
**tf_function_kwargs: Additional arguments to pass to tf.function.

Attributes

with_summaries: A wrapped version of the underlying function with summaries enabled (using whatever the active predicate is for tf.summary.record_if), and placed inside a "soft device placement" context to enable summary recording on TPU.
without_summaries: A wrapped version of the underlying function with all summary recording disabled.

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Some content is licensed under the numpy license.

Last updated 2025-04-18 UTC.