orbit.StandardTrainer

Implements standard functionality on top of the AbstractTrainer API.

Inherits From: AbstractTrainer

orbit.StandardTrainer(
 train_dataset,
 options: Optional[orbit.StandardTrainerOptions] = None
)

This class structures the training "inner loop" roughly as follows:

train_loop_begin()
for _ in range(num_steps):
 train_step(train_iterator)
return train_loop_end()

Calls to train_loop_begin and train_loop_end are always made in eager mode, while the loop and train_step may be implemented with tf.while_loop and/or tf.function, as determined by the options passed to __init__.
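The control flow above can be modeled in plain Python. This is a framework-free sketch, not orbit's implementation: `ToyTrainer` and `run_inner_loop` are hypothetical names, and the eager versus tf.function distinction is left out entirely.

```python
class ToyTrainer:
    """Hypothetical trainer exposing the three hooks the inner loop calls."""

    def __init__(self):
        self.steps_run = 0
        self.total = 0

    def train_loop_begin(self):
        # Reset per-loop state before any step runs.
        self.steps_run = 0
        self.total = 0

    def train_step(self, iterator):
        # One "step": consume a single element from the iterator.
        self.total += next(iterator)
        self.steps_run += 1

    def train_loop_end(self):
        # Summarize the loop; this is the value the loop returns.
        return {"steps": self.steps_run, "sum": self.total}


def run_inner_loop(trainer, iterator, num_steps):
    """Mirror the begin / num_steps x step / end structure."""
    trainer.train_loop_begin()
    for _ in range(num_steps):
        trainer.train_step(iterator)
    return trainer.train_loop_end()


it = iter(range(10))
print(run_inner_loop(ToyTrainer(), it, 3))  # {'steps': 3, 'sum': 3}
```

Passing the iterator in, rather than creating it inside the loop, means consecutive calls continue through the same data instead of restarting it.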

Args

train_dataset: A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
options: An orbit.StandardTrainerOptions instance.

Attributes

name: Returns the name of this module as passed or determined in the constructor.

name_scope: Returns a tf.name_scope instance for this class.
non_trainable_variables: Sequence of non-trainable variables owned by this module and its submodules.
submodules: Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

>>> import tensorflow as tf
>>> a = tf.Module()
>>> b = tf.Module()
>>> c = tf.Module()
>>> a.b = b
>>> b.c = c
>>> list(a.submodules) == [b, c]
True
>>> list(b.submodules) == [c]
True
>>> list(c.submodules) == []
True

train_dataset: The current training dataset.
trainable_variables: Sequence of trainable variables owned by this module and its submodules.

variables: Sequence of variables owned by this module and its submodules.

Methods

create_train_loop_fn

create_train_loop_fn()

Creates a training loop from the current step function and options.

Returns
The train loop function, i.e. a wrapper that runs multiple train steps.
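Conceptually, the returned loop function takes an iterator and a step count and calls the step function repeatedly. A plain-Python sketch under that assumption; `make_loop_fn` is a hypothetical name, and orbit's actual loop may additionally be wrapped in tf.function and/or tf.while_loop depending on the StandardTrainerOptions in use.

```python
def make_loop_fn(step_fn):
    """Return a function that runs step_fn num_steps times on one iterator."""

    def loop_fn(iterator, num_steps):
        for _ in range(num_steps):
            step_fn(iterator)

    return loop_fn


seen = []
loop_fn = make_loop_fn(lambda it: seen.append(next(it)))
loop_fn(iter("abcdef"), 4)
print(seen)  # ['a', 'b', 'c', 'd']
```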

train

train(
 num_steps: tf.Tensor
) -> Optional[runner.Output]

Implements num_steps steps of training.

Args
num_steps: The number of training steps to run. This corresponds directly to the number of calls made to train_step.

Returns
The output of train_loop_end.
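The call order inside train can be sketched as follows. For illustration, the sketch assumes the iterator is created lazily on the first call and reused afterwards, and that train_loop_begin runs before that creation (as noted under train_loop_begin). `SketchTrainer` is hypothetical and omits options, distribution strategies and tf.function.

```python
class SketchTrainer:
    """Hypothetical, framework-free stand-in for a StandardTrainer subclass."""

    def __init__(self, train_dataset):
        self.train_dataset = train_dataset
        self._train_iter = None  # created lazily, on the first train() call
        self.steps_this_loop = 0
        self.log = []

    def train_loop_begin(self):
        self.log.append("begin")
        self.steps_this_loop = 0

    def train_step(self, iterator):
        self.log.append(("step", next(iterator)))
        self.steps_this_loop += 1

    def train_loop_end(self):
        self.log.append("end")
        return {"steps": self.steps_this_loop}

    def train(self, num_steps):
        self.train_loop_begin()
        if self._train_iter is None:
            # Created after train_loop_begin, then reused by later calls.
            self._train_iter = iter(self.train_dataset)
            self.log.append("iterator created")
        for _ in range(num_steps):
            self.train_step(self._train_iter)
        return self.train_loop_end()


trainer = SketchTrainer([1, 2, 3, 4])
print(trainer.train(2))  # {'steps': 2}
print(trainer.train(1))  # {'steps': 1}
print(trainer.log)
# ['begin', 'iterator created', ('step', 1), ('step', 2), 'end',
#  'begin', ('step', 3), 'end']
```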

train_loop_begin

train_loop_begin()

Called once at the beginning of the training loop.

This method is always called in eager mode, and is a good place to reset metrics that accumulate values over multiple steps of training.

Note that this method is called before dataset iterator creation.
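For example, a running-mean loss that should describe only the current loop is reset here, updated in each train_step, and read in train_loop_end. The sketch uses a hypothetical plain-Python `MeanMetric` class; in TensorFlow code a Keras metric would typically play this role.

```python
class MeanMetric:
    """Hypothetical running-mean accumulator."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


loss = MeanMetric()
for loop_losses in ([4.0, 2.0], [1.0, 1.0, 1.0]):
    loss.reset()              # what train_loop_begin would do
    for v in loop_losses:     # what successive train_step calls would do
        loss.update(v)
    print(loss.result())      # 3.0, then 1.0
```

Without the reset, the second loop would report the mean over all five values (1.8) instead of its own 1.0.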

train_loop_end

train_loop_end() -> Optional[runner.Output]

Called once at the end of the training loop.

This method is always called in eager mode, and is a good place to get metric results. The value returned from this function will be returned as-is from the train method implementation provided by StandardTrainer.

Returns
The function may return a dictionary of Tensors, which will be written to logs and as TensorBoard summaries. The dictionary can also be nested, yielding a hierarchy of summary directories.
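A nested return value groups related results. To see which leaf falls under which level of the hierarchy, the hypothetical helper `flatten_logs` below joins the nesting path with `/`; it only illustrates the grouping and is not how orbit writes summaries. The metric names and values are made up.

```python
def flatten_logs(logs, prefix=""):
    """Flatten a nested dict of results into 'parent/child' keys."""
    flat = {}
    for key, value in logs.items():
        name = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_logs(value, name))
        else:
            flat[name] = value
    return flat


logs = {"loss": 0.25, "task_a": {"accuracy": 0.9}, "task_b": {"accuracy": 0.8}}
print(flatten_logs(logs))
# {'loss': 0.25, 'task_a/accuracy': 0.9, 'task_b/accuracy': 0.8}
```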

train_step

@abc.abstractmethod
train_step(
 iterator
)

Implements one step of training.

What a "step" consists of is up to the implementer. When using distribution strategies, the call to this method takes place in the "cross-replica context" for generality, to allow e.g. multiple iterator dequeues and calls to strategy.run.

Note that if use_tf_function=True, all code inside train_step should be compatible with tf.function tracing (in particular, avoid any state modifications involving self). In some cases, code that is not compatible with tf.function can be moved to train_loop_begin or train_loop_end, which always execute eagerly.

Args
iterator: A tf.nest-compatible structure of tf.data.Iterator or DistributedIterator. The structure of this input matches the structure of train_dataset as passed to __init__.
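Because iterator has the same structure as train_dataset and a "step" is whatever the implementer defines, a single train_step may read from several iterators and dequeue more than once. A plain-Python sketch with lists standing in for datasets; the dataset names are hypothetical, and a real distributed step would run the per-replica computation through strategy.run.

```python
def train_step(iterator):
    """Hypothetical step over an iterator nest shaped like
    {"labeled": ..., "unlabeled": ...}."""
    # Several dequeues per step are allowed: here two labeled batches
    # (e.g. for gradient accumulation) and one unlabeled batch.
    labeled = [next(iterator["labeled"]), next(iterator["labeled"])]
    unlabeled = next(iterator["unlabeled"])
    return labeled, unlabeled


train_dataset = {"labeled": [1, 2, 3, 4], "unlabeled": ["u1", "u2"]}
iterator = {name: iter(ds) for name, ds in train_dataset.items()}
print(train_step(iterator))  # ([1, 2], 'u1')
print(train_step(iterator))  # ([3, 4], 'u2')
```

The step only returns values rather than mutating shared state, which keeps it in line with the tf.function guidance above.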

with_name_scope

@classmethod
with_name_scope(
 method
)

Decorator to automatically enter the module name scope.

import tensorflow as tf

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    # Create the weight lazily on the first call, sized from the input.
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names included the module name:

>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array(..., dtype=float32)>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
numpy=array(..., dtype=float32)>

Args
method: The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.

Last updated 2025-04-18 UTC.