orbit.AbstractTrainer

An abstract class defining the API required for training.

orbit.AbstractTrainer(
 name=None
)
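
To make the required API concrete, here is one possible minimal subclass. It is a sketch only: the constructor arguments (assumed to be a Keras model, a Keras optimizer, and a batched tf.data.Dataset of (features, labels) pairs) and the loss computation are illustrative assumptions, not part of the orbit API.

import orbit
import tensorflow as tf

class MyTrainer(orbit.AbstractTrainer):
  """A minimal concrete trainer with a plain Python inner loop (sketch only)."""

  def __init__(self, model, optimizer, train_dataset, name=None):
    super().__init__(name=name)
    self.model = model
    self.optimizer = optimizer
    self.train_iter = iter(train_dataset)

  def train(self, num_steps):
    loss = None
    for _ in range(int(num_steps)):
      features, labels = next(self.train_iter)
      with tf.GradientTape() as tape:
        logits = self.model(features, training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
      grads = tape.gradient(loss, self.model.trainable_variables)
      self.optimizer.apply_gradients(
          zip(grads, self.model.trainable_variables))
    # The returned dictionary is written to logs and TensorBoard summaries.
    return {'train_loss': loss}

An instance of such a trainer would typically be handed to an orbit.Controller, which drives training by calling train(num_steps) in between checkpointing and summary writes.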

Attributes

name Returns the name of this module as passed or determined in the constructor.

name_scope Returns a tf.name_scope instance for this class.
non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.
submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
list(a.submodules) == [b, c]
True
list(b.submodules) == [c]
True
list(c.submodules) == []
True

trainable_variables Sequence of trainable variables owned by this module and its submodules.

variables Sequence of variables owned by this module and its submodules.

Methods

train

@abc.abstractmethod
train(
 num_steps: tf.Tensor
) -> Optional[Output]

Implements num_steps steps of training.

This method will be called by the Controller to perform the "inner loop" of training. This inner loop amortizes the cost of bookkeeping associated with checkpointing, evaluation, and writing summaries. Additionally, the inner loop can be implemented (if desired) using TensorFlow's looping constructs (e.g. a for loop over a tf.range inside a tf.function), which can be necessary for getting optimal performance when running on TPU. For cases that don't require peak performance, a simple Python loop can be used instead for simplicity.

Args
num_steps The number of training steps to run. Note that it is up to the model what constitutes a "step", which may involve more than one update to model parameters (e.g., if training a GAN).

Returns
Either None, or a dictionary mapping names to Tensors or NumPy values. If a dictionary is returned, it will be written to logs and as TensorBoard summaries. The dictionary may also be nested, which will generate a hierarchy of summary directories.
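
As the description above notes, the inner loop can instead use TensorFlow's looping constructs. The sketch below shows how the train method of the hypothetical MyTrainer shown earlier on this page might be written as a tf.function over tf.range; the same placeholder attributes (model, optimizer, train_iter) are assumed.

@tf.function
def train(self, num_steps):
  # AutoGraph converts this loop over tf.range into a graph-level while loop,
  # so all num_steps updates run inside a single compiled call, which is
  # often needed for peak performance on TPU.
  # Assumes the model and optimizer variables already exist before tracing
  # (e.g. after running one step eagerly), so no variables are created here.
  loss = tf.constant(0.0)
  for _ in tf.range(num_steps):
    features, labels = next(self.train_iter)
    with tf.GradientTape() as tape:
      logits = self.model(features, training=True)
      loss = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(
              labels, logits, from_logits=True))
    grads = tape.gradient(loss, self.model.trainable_variables)
    self.optimizer.apply_gradients(
        zip(grads, self.model.trainable_variables))
  # A nested dictionary produces a matching hierarchy of summary directories.
  return {'losses': {'train_loss': loss}}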

with_name_scope

@classmethod
with_name_scope(
 method
)

Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
numpy=..., dtype=float32)>

Args
method The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.
