Warning: This project is deprecated. TensorFlow Addons has stopped development, and the project will only provide minimal maintenance releases until May 2024. See the full announcement on GitHub.

tfa.seq2seq.TrainingSampler

A training sampler that simply reads its inputs.

Inherits From: Sampler

tfa.seq2seq.TrainingSampler(
 time_major: bool = False
)

Returned sample_ids are the argmax of the RNN output logits.
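As a rough illustration of that rule, the sketch below picks the highest-scoring class for each batch entry, using plain Python lists in place of a `[batch_size, num_classes]` logits tensor. The helper name `argmax_sample_ids` is hypothetical and not part of the TensorFlow Addons API.

```python
from typing import List, Sequence


def argmax_sample_ids(logits: Sequence[Sequence[float]]) -> List[int]:
    """Return the index of the largest logit for each batch entry.

    Hypothetical helper: mimics an argmax over the last axis of a
    [batch_size, num_classes] logits tensor. In this sketch, ties resolve
    to the lowest index because Python's max keeps the first maximum.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


logits = [
    [0.1, 2.5, -1.0],  # batch entry 0: class 1 scores highest
    [3.0, 0.0, 0.2],   # batch entry 1: class 0 scores highest
]
print(argmax_sample_ids(logits))  # [1, 0]
```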

Args

time_major Python bool. Whether the tensors in inputs are time-major ([max_time, batch_size, ...]). If False (default), they are assumed to be batch-major ([batch_size, max_time, ...]).
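To make the two layouts concrete, this sketch converts a batch-major nested list into time-major order by swapping the first two dimensions. Nested lists stand in for dense, padded tensors, and the helper `to_time_major` is hypothetical; it only illustrates the layout difference that `time_major` describes.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def to_time_major(batch_major: Sequence[Sequence[T]]) -> List[List[T]]:
    """Convert a [batch, time] nested list into [time, batch] layout.

    Assumes every sequence has the same (padded) length, as in a dense tensor.
    """
    return [list(step) for step in zip(*batch_major)]


# Two sequences (batch=2), each with three timesteps (time=3).
batch_major = [["a0", "a1", "a2"],
               ["b0", "b1", "b2"]]
print(to_time_major(batch_major))  # [['a0', 'b0'], ['a1', 'b1'], ['a2', 'b2']]
```

Applying the same transpose twice returns the original layout, which is why a batch-major input can be handled by transposing it once up front.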

Raises

ValueError if sequence_length is not a 1D tensor or mask is not a 2D boolean tensor.
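The two documented error conditions can be sketched as a standalone check. This is an illustration only: `validate_lengths_and_mask` is a hypothetical helper, nested Python lists stand in for tensors, and rank is approximated by nesting depth rather than by an actual tensor rank check.

```python
from typing import Optional, Sequence


def validate_lengths_and_mask(sequence_length: Optional[Sequence] = None,
                              mask: Optional[Sequence] = None) -> None:
    """Hypothetical check mirroring the documented ValueError conditions."""
    if sequence_length is not None:
        # 1D: a flat list of ints (bools excluded), one length per batch entry.
        if not all(isinstance(n, int) and not isinstance(n, bool)
                   for n in sequence_length):
            raise ValueError("sequence_length must be a 1D vector of ints")
    if mask is not None:
        # 2D boolean: a list of rows, each row a list or tuple of bools.
        if not all(isinstance(row, (list, tuple))
                   and all(isinstance(v, bool) for v in row)
                   for row in mask):
            raise ValueError("mask must be a 2D boolean tensor")


validate_lengths_and_mask(sequence_length=[3, 1])      # passes
validate_lengths_and_mask(mask=[[True, True, False]])  # passes
try:
    validate_lengths_and_mask(mask=[[1, 0], [1, 1]])   # ints, not bools
except ValueError as err:
    print(err)  # mask must be a 2D boolean tensor
```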

Attributes

batch_size Batch size of the tensor returned by sample.

Returns a scalar int32 tensor. The return value might not be available before initialize() is called; in that case, a ValueError is raised.

sample_ids_dtype DType of the tensor returned by sample.

Returns a DType. The return value might not be available before initialize() is called.

sample_ids_shape Shape of the tensor returned by sample, excluding the batch dimension.

Returns a TensorShape. The return value might not be available before initialize() is called.
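The following toy class sketches why `batch_size` can fail before `initialize()`: the batch size is only known once inputs have been seen. `SamplerAttributesSketch` is hypothetical, and its `"int32"` dtype and empty per-entry shape are assumptions modeled on argmax sample ids (one scalar id per batch entry), represented with plain Python values rather than TensorFlow types.

```python
from typing import Optional, Sequence, Tuple


class SamplerAttributesSketch:
    """Hypothetical stand-in showing when the sampler's attributes become usable."""

    def __init__(self) -> None:
        self._batch_size: Optional[int] = None  # unknown until initialize()

    def initialize(self, inputs: Sequence) -> None:
        # Batch-major inputs: the outer list is the batch dimension.
        self._batch_size = len(inputs)

    @property
    def batch_size(self) -> int:
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize() was called")
        return self._batch_size

    @property
    def sample_ids_dtype(self) -> str:
        # Assumption: argmax ids are reported as int32.
        return "int32"

    @property
    def sample_ids_shape(self) -> Tuple[int, ...]:
        # Assumption: one scalar id per batch entry, so the per-entry shape is empty.
        return ()


sampler = SamplerAttributesSketch()
try:
    sampler.batch_size
except ValueError as err:
    print(err)  # batch_size accessed before initialize() was called
sampler.initialize([[1, 2, 3], [4, 5, 6]])
print(sampler.batch_size)  # 2
```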

Methods

initialize

initialize(
 inputs, sequence_length=None, mask=None
)

Initializes the TrainingSampler.

Args
inputs A (structure of) input tensors.
sequence_length An int32 vector tensor.
mask A boolean 2D tensor.

Returns
(finished, next_inputs), a tuple of two items. The first item is a boolean vector indicating whether each item in the batch has finished. The second item is the first slice of the input data along the time dimension (usually the second dimension of the input, as with the default batch-major layout).
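A minimal, batch-major sketch of what initialization computes, under stated assumptions: `initialize_sketch` is hypothetical, each timestep input is a single float instead of a feature vector, a missing `sequence_length` defaults to the full padded length, this sketch rejects passing both `sequence_length` and `mask`, and zeros are fed when every entry is already finished.

```python
from typing import List, Optional, Sequence, Tuple


def initialize_sketch(inputs: Sequence[Sequence[float]],
                      sequence_length: Optional[Sequence[int]] = None,
                      mask: Optional[Sequence[Sequence[bool]]] = None,
                      ) -> Tuple[List[bool], List[float]]:
    """Hypothetical, batch-major sketch of TrainingSampler.initialize."""
    if sequence_length is not None and mask is not None:
        raise ValueError("pass either sequence_length or mask, not both")
    if mask is not None:
        # Derive each entry's length by counting True values in its mask row.
        sequence_length = [sum(1 for v in row if v) for row in mask]
    if sequence_length is None:
        # Assumption: without lengths, every entry runs for the full padded length.
        sequence_length = [len(seq) for seq in inputs]
    # An entry is finished immediately if it has no timesteps at all.
    finished = [n == 0 for n in sequence_length]
    if all(finished):
        # Nothing to read: feed zeros rather than a real timestep.
        first_inputs = [0.0] * len(inputs)
    else:
        # First slice along the time dimension: timestep 0 of every entry.
        first_inputs = [seq[0] for seq in inputs]
    return finished, first_inputs


inputs = [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
mask = [[True, True, True], [True, True, False]]
print(initialize_sketch(inputs, mask=mask))  # ([False, False], [1.0, 4.0])
```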

next_inputs

next_inputs(
 time, outputs, state, sample_ids
)

Returns (finished, next_inputs, next_state).
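For a training sampler, the next input comes from the ground-truth sequence rather than from the model's predictions. The sketch below is a hypothetical, batch-major approximation: `next_inputs_sketch` takes the inputs and lengths as explicit arguments (instead of storing them from initialization), drops `outputs` and `sample_ids` because they do not influence what is read, passes the state through unchanged, and returns zeros once every entry has finished.

```python
from typing import Any, List, Sequence, Tuple


def next_inputs_sketch(time: int,
                       inputs: Sequence[Sequence[float]],
                       sequence_length: Sequence[int],
                       state: Any,
                       ) -> Tuple[List[bool], List[float], Any]:
    """Hypothetical, batch-major sketch of TrainingSampler.next_inputs."""
    next_time = time + 1
    # An entry is finished once next_time reaches its length.
    finished = [next_time >= n for n in sequence_length]
    if all(finished):
        # Nothing left to read: feed zeros instead of indexing past the end.
        next_inputs = [0.0] * len(inputs)
    else:
        # Read the ground-truth timestep; finished entries get padding values.
        next_inputs = [seq[next_time] for seq in inputs]
    # The RNN state passes through unchanged.
    return finished, next_inputs, state


inputs = [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
print(next_inputs_sketch(1, inputs, [3, 2], "h"))  # ([False, True], [3.0, 0.0], 'h')
```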

sample

sample(
 time, outputs, state
)

Returns sample_ids.
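To show what "simply reads its inputs" means across a whole sequence, this sketch chains the three steps for a single sequence (teacher forcing): the first input is read at the start, each step's sample id is the argmax of the logits, and the next input is always the ground-truth value, never the sampled id. `toy_cell` and `teacher_forced_decode` are hypothetical; the toy cell keeps a running sum as its state and scores classes 0 to 2 by closeness to that sum.

```python
from typing import List, Sequence, Tuple


def toy_cell(x: float, state: float) -> Tuple[List[float], float]:
    """Hypothetical RNN cell: running-sum state, logits favor the nearest class."""
    new_state = state + x
    logits = [-abs(new_state - k) for k in range(3)]
    return logits, new_state


def teacher_forced_decode(targets: Sequence[float]) -> List[int]:
    state = 0.0
    x = targets[0]  # initialize(): first slice of the inputs
    sample_ids: List[int] = []
    for time in range(len(targets)):
        logits, state = toy_cell(x, state)
        # sample(): argmax of the logits
        sample_ids.append(max(range(len(logits)), key=logits.__getitem__))
        # next_inputs(): read the ground truth, ignoring the sampled id
        if time + 1 < len(targets):
            x = targets[time + 1]
    return sample_ids


print(teacher_forced_decode([0.0, 1.0, 1.0]))  # [0, 1, 2]
```

Because the sampled ids never feed back into the inputs, a wrong prediction at one step cannot derail later steps during training; that is the trade-off this sampler makes compared with samplers that feed predictions back in.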

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.

Last updated 2023-05-25 UTC.