Warning: This project is deprecated. TensorFlow Addons has stopped development. The project will only provide minimal maintenance releases until May 2024. See the full announcement here or on GitHub.

Module: tfa.seq2seq


Additional layers for sequence-to-sequence models.

Classes

class AttentionMechanism: Base class for attention mechanisms.

class AttentionWrapper: Wraps another RNN cell with attention.

class AttentionWrapperState: State of a tfa.seq2seq.AttentionWrapper.

class BahdanauAttention: Implements Bahdanau-style (additive) attention.
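
As a rough illustration of additive scoring, the pure-Python sketch below computes score_j = v · tanh(W_q q + W_k h_j) for each memory vector h_j and normalizes the scores with a softmax. It is not the tfa.seq2seq API: the function names, the identity weight matrices, and the vector `v` are hypothetical stand-ins for the learned query layer, memory layer, and attention vector.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply matrix m (rows x cols) by vector x (length cols)."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def softmax(scores: Vector) -> Vector:
    """Softmax with the max subtracted first for numerical stability."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def bahdanau_weights(query: Vector, keys: List[Vector], w_q: Matrix, w_k: Matrix, v: Vector) -> Vector:
    """Additive attention: score_j = v . tanh(W_q q + W_k h_j), then softmax."""
    projected_query = matvec(w_q, query)
    scores = []
    for key in keys:
        projected_key = matvec(w_k, key)
        hidden = [math.tanh(a + b) for a, b in zip(projected_query, projected_key)]
        scores.append(sum(vi * hi for vi, hi in zip(v, hidden)))
    return softmax(scores)


# Hypothetical toy weights; in a trained model these are learned.
identity = [[1.0, 0.0], [0.0, 1.0]]
weights = bahdanau_weights(
    query=[0.0, 0.0],
    keys=[[0.0, 0.0], [1.0, 1.0]],
    w_q=identity,
    w_k=identity,
    v=[1.0, 1.0],
)
print(weights)
```

Because W_k h_j does not depend on the query, an implementation can compute it once per source sequence and reuse it at every decoder step.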

class BahdanauMonotonicAttention: Monotonic attention mechanism with a Bahdanau-style energy function.

class BaseDecoder: An RNN Decoder that is based on a Keras layer.

class BasicDecoder: Basic sampling decoder for training and inference.

class BasicDecoderOutput: Outputs of a tfa.seq2seq.BasicDecoder step.

class BeamSearchDecoder: Beam search decoder.
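
The core idea of beam search can be sketched in a few lines of pure Python: keep the `beam_width` best partial hypotheses by cumulative log-probability, extend each with every token, and keep the best `beam_width` again. This is a simplified illustration under stated assumptions, not how tfa.seq2seq.BeamSearchDecoder is implemented; `beam_search`, `toy_model`, and its probability table are hypothetical, and the sketch omits options such as the decoder's length penalty.

```python
import math
from typing import Callable, List, Tuple

# (token ids, cumulative log-probability)
Hypothesis = Tuple[List[int], float]


def beam_search(
    next_log_probs: Callable[[List[int]], List[float]],
    start_token: int,
    end_token: int,
    beam_width: int,
    max_steps: int,
) -> List[Hypothesis]:
    """Return up to beam_width hypotheses, highest score first."""
    beams: List[Hypothesis] = [([start_token], 0.0)]
    finished: List[Hypothesis] = []
    for _ in range(max_steps):
        # Extend every live hypothesis with every token in the vocabulary.
        candidates: List[Hypothesis] = []
        for tokens, score in beams:
            for token_id, log_p in enumerate(next_log_probs(tokens)):
                candidates.append((tokens + [token_id], score + log_p))
        # Keep only the beam_width best extensions.
        candidates.sort(key=lambda h: h[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            # A hypothesis that emits end_token stops growing.
            if tokens[-1] == end_token:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)
    finished.sort(key=lambda h: h[1], reverse=True)
    return finished[:beam_width]


def toy_model(tokens: List[int]) -> List[float]:
    """Hypothetical next-token distribution over ids 0=<s>, 1=</s>, 2, 3."""
    last = tokens[-1]
    if last == 0:
        probs = [0.0, 0.0, 0.6, 0.4]
    elif last == 2:
        probs = [0.0, 0.3, 0.35, 0.35]
    else:
        probs = [0.0, 0.9, 0.05, 0.05]
    return [math.log(p) if p > 0 else float("-inf") for p in probs]


result = beam_search(toy_model, start_token=0, end_token=1, beam_width=2, max_steps=3)
print(result)
```

With these toy probabilities, a greedy decoder would commit to token 2 first (probability 0.6), while the beam keeps token 3 alive and finds the higher-probability sequence [0, 3, 1] (0.4 × 0.9 = 0.36).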

class BeamSearchDecoderOutput: Outputs of a tfa.seq2seq.BeamSearchDecoder step.

class BeamSearchDecoderState: State of a tfa.seq2seq.BeamSearchDecoder.

class CustomSampler: Base abstract class that allows the user to customize sampling.

class Decoder: An RNN Decoder abstract interface object.

class FinalBeamSearchDecoderOutput: Final outputs returned by the beam search after all decoding is finished.

class GreedyEmbeddingSampler: An inference sampler that takes the argmax of the output distribution.
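
One step of greedy embedding sampling amounts to: take the argmax of the logits, look up that id's embedding as the next input, and mark the sequence finished if the id is the end token. The pure-Python sketch below shows that logic for a single example; `greedy_step` and the three-row embedding table are hypothetical, and the real sampler works on batched tensors.

```python
from typing import List, Tuple


def greedy_step(
    logits: List[float], embedding: List[List[float]], end_token: int
) -> Tuple[int, List[float], bool]:
    """Return (argmax id, its embedding as the next input, finished flag)."""
    # max() returns the first maximal index, so ties go to the lowest id.
    sample_id = max(range(len(logits)), key=lambda i: logits[i])
    next_input = embedding[sample_id]
    finished = sample_id == end_token
    return sample_id, next_input, finished


embedding_table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # 3 token ids, 2-dim embeddings
print(greedy_step([0.1, 2.0, -1.0], embedding_table, end_token=2))
# (1, [1.0, 0.0], False)
```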

class InferenceSampler: An inference sampler that uses a custom sampling function.

class LuongAttention: Implements Luong-style (multiplicative) attention scoring.
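
Luong-style scoring replaces the feed-forward network with a bilinear product: score_j = q · (W h_j), optionally multiplied by a scalar, followed by a softmax. In the pure-Python sketch below, the matrix `w` stands in for the learned memory projection and `scale` for the optional learned scalar; both values, and the name `luong_weights`, are hypothetical and not the library's API.

```python
import math
from typing import List


def luong_weights(
    query: List[float], keys: List[List[float]], w: List[List[float]], scale: float = 1.0
) -> List[float]:
    """Multiplicative attention: score_j = scale * (q . W h_j), then softmax."""
    scores = []
    for key in keys:
        projected_key = [sum(a * b for a, b in zip(row, key)) for row in w]
        scores.append(scale * sum(q * k for q, k in zip(query, projected_key)))
    top = max(scores)  # subtract the max before exp for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


identity = [[1.0, 0.0], [0.0, 1.0]]
weights = luong_weights(query=[1.0, 0.0], keys=[[1.0, 0.0], [0.0, 1.0]], w=identity)
print(weights)
```

With an identity `w`, the score reduces to a plain dot product between query and key.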

class LuongMonotonicAttention: Monotonic attention mechanism with a Luong-style energy function.

class SampleEmbeddingSampler: An inference sampler that randomly samples from the output distribution.
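
Random sampling differs from the greedy case only in how the id is chosen: the logits are divided by a softmax temperature, converted to probabilities, and an id is drawn from that distribution. A pure-Python sketch under these assumptions follows; `tempered_probs`, `sample_step`, and the seeded `random.Random` generator are illustrative choices, not the library's API. Temperatures below 1 sharpen the distribution and temperatures above 1 flatten it.

```python
import math
import random
from typing import List, Tuple


def tempered_probs(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_step(
    logits: List[float],
    embedding: List[List[float]],
    rng: random.Random,
    temperature: float = 1.0,
) -> Tuple[int, List[float]]:
    """Draw one id from the tempered distribution and return it with its embedding."""
    probs = tempered_probs(logits, temperature)
    sample_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return sample_id, embedding[sample_id]


rng = random.Random(0)  # fixed seed so runs are reproducible
embedding_table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
print(tempered_probs([1.0, 0.0], temperature=0.5))
print(sample_step([0.5, 0.2, -0.3], embedding_table, rng, temperature=0.7))
```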

class Sampler: Interface for implementing sampling in seq2seq decoders.

class ScheduledEmbeddingTrainingSampler: A training sampler that adds scheduled sampling.
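
Scheduled sampling replaces some ground-truth inputs with the model's own predictions during training. The sketch below shows the per-example choice in pure Python: each batch entry independently feeds the model's sampled id with probability `sampling_probability` and the ground-truth id otherwise. `choose_next_ids` and `model_sampled_ids` are hypothetical; the latter stands in for ids drawn from the decoder's output distribution. For simplicity the sketch works on ids, whereas the real sampler mixes embedded inputs.

```python
import random
from typing import List


def choose_next_ids(
    ground_truth_ids: List[int],
    model_sampled_ids: List[int],
    sampling_probability: float,
    rng: random.Random,
) -> List[int]:
    """Per batch entry, use the model's sample with probability sampling_probability."""
    # rng.random() is in [0, 1), so probability 0 never samples and 1 always does.
    return [
        sampled if rng.random() < sampling_probability else truth
        for truth, sampled in zip(ground_truth_ids, model_sampled_ids)
    ]


rng = random.Random(0)
truth = [4, 7, 1, 9]
sampled = [5, 2, 8, 3]
print(choose_next_ids(truth, sampled, sampling_probability=0.25, rng=rng))
```

At `sampling_probability = 0` this reduces to plain teacher forcing, as with TrainingSampler.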

class ScheduledOutputTrainingSampler: A training sampler that adds scheduled sampling directly to outputs.

class SequenceLoss: Weighted cross-entropy loss for a sequence of logits.

class TrainingSampler: A training sampler that simply reads its inputs.

Functions

dynamic_decode(...): Runs dynamic decoding with a decoder.
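
The control flow behind dynamic decoding can be sketched as a loop: initialize the decoder, call its step function until every batch entry reports finished (or an iteration cap is hit), and record how many steps each entry ran. The pure-Python sketch below assumes a decoder object with `initialize()` and `step(time, inputs, state)` methods, loosely modeled on the tfa.seq2seq.Decoder interface but simplified; `CountdownDecoder` and `run_dynamic_decode` are hypothetical.

```python
from typing import List, Tuple


class CountdownDecoder:
    """Hypothetical toy decoder: entry b emits starts[b], starts[b] - 1, ..., 1."""

    def __init__(self, starts: List[int]):
        self.starts = starts

    def initialize(self) -> Tuple[List[bool], List[int], List[int]]:
        # Returns (finished, initial_inputs, initial_state).
        finished = [s <= 0 for s in self.starts]
        return finished, list(self.starts), list(self.starts)

    def step(self, time: int, inputs: List[int], state: List[int]):
        # Returns (outputs, next_state, next_inputs, finished).
        outputs = list(state)
        next_state = [s - 1 for s in state]
        finished = [s <= 0 for s in next_state]
        return outputs, next_state, next_state, finished


def run_dynamic_decode(decoder, maximum_iterations=None):
    finished, inputs, state = decoder.initialize()
    lengths = [0] * len(finished)
    outputs_per_step = []
    time = 0
    while not all(finished):
        if maximum_iterations is not None and time >= maximum_iterations:
            break
        outputs, state, inputs, step_finished = decoder.step(time, inputs, state)
        # Entries still running at this step get one more unit of length.
        lengths = [n + (0 if done else 1) for n, done in zip(lengths, finished)]
        # Once an entry has finished, it stays finished.
        finished = [old or new for old, new in zip(finished, step_finished)]
        outputs_per_step.append(outputs)
        time += 1
    return outputs_per_step, state, lengths


outputs, final_state, lengths = run_dynamic_decode(CountdownDecoder([3, 1]))
print(outputs)  # one list per step, one entry per batch element
print(lengths)
```

Here, finished entries keep producing outputs until the whole batch is done; the library function offers an `impute_finished` option that instead copies their state through and zeroes their outputs.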

gather_tree(...): Calculates the full beams from the per-step ids and parent beam ids.
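
Backtracking works from the last step to the first: for a given final beam, read its token, then jump to the parent beam it extended and repeat. The pure-Python sketch below does this for a single batch entry, with `step_ids[t][k]` as the token chosen by beam k at step t and `parent_ids[t][k]` as the beam at step t - 1 that it extended (entries at t = 0 are ignored). `gather_tree_single` is a hypothetical helper; unlike the library function, it ignores sequence lengths and end-token handling.

```python
from typing import List


def gather_tree_single(step_ids: List[List[int]], parent_ids: List[List[int]]) -> List[List[int]]:
    """Return beams[k], the full token sequence that ends in beam k at the last step."""
    max_time = len(step_ids)
    beam_width = len(step_ids[0])
    beams = []
    for k in range(beam_width):
        sequence = [0] * max_time
        beam = k
        # Walk backwards: record this beam's token, then move to its parent.
        for t in range(max_time - 1, -1, -1):
            sequence[t] = step_ids[t][beam]
            beam = parent_ids[t][beam]
        beams.append(sequence)
    return beams


step_ids = [[2, 5], [7, 8], [3, 4]]    # [time][beam]
parent_ids = [[0, 0], [1, 0], [1, 0]]  # [time][beam]
print(gather_tree_single(step_ids, parent_ids))
# [[2, 8, 3], [5, 7, 4]]
```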

gather_tree_from_array(...): Calculates the full beams for a TensorArray.

hardmax(...): Returns batched one-hot vectors, with 1 at the argmax of the last axis and 0 elsewhere.
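
hardmax is the one-hot analogue of softmax: each row becomes 1 at its argmax and 0 elsewhere. The pure-Python sketch below shows the semantics on a batch of rows; `py_hardmax` is a hypothetical name, and its lowest-index tie-breaking is a choice of this sketch rather than documented library behavior.

```python
from typing import List


def py_hardmax(logits: List[List[float]]) -> List[List[float]]:
    """For each row, 1.0 at the index of the largest value and 0.0 elsewhere.

    Ties go to the lowest index (max() returns the first maximal element).
    """
    result = []
    for row in logits:
        best = max(range(len(row)), key=row.__getitem__)
        result.append([1.0 if i == best else 0.0 for i in range(len(row))])
    return result


print(py_hardmax([[0.1, 2.0, -1.0], [3.0, 3.0, 0.0]]))
# [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
```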

monotonic_attention(...): Computes monotonic attention distribution from choosing probabilities.
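
Given per-entry choosing probabilities p_choose[j] for the current output step and the previous step's attention, the expected attention follows the recurrence in the sketch below. tfa's monotonic_attention selects among computation modes via its `mode` argument; this pure-Python version implements only the recursive recurrence for a single example, and `py_monotonic_attention` is a hypothetical name.

```python
from typing import List


def py_monotonic_attention(p_choose: List[float], previous_attention: List[float]) -> List[float]:
    """Expected monotonic attention for one example via the recursive formula.

    q[j] = (1 - p_choose[j - 1]) * q[j - 1] + previous_attention[j], with q[-1] = 0
    attention[j] = p_choose[j] * q[j]
    """
    attention = []
    q = 0.0
    one_minus_prev_p = 1.0  # multiplies q = 0.0 at j = 0, so q[0] = previous_attention[0]
    for p, prev_a in zip(p_choose, previous_attention):
        q = one_minus_prev_p * q + prev_a
        attention.append(p * q)
        one_minus_prev_p = 1.0 - p
    return attention


print(py_monotonic_attention([0.5, 0.5, 0.5], [1.0, 0.0, 0.0]))  # [0.5, 0.25, 0.125]
print(py_monotonic_attention([0.0, 1.0, 1.0], [1.0, 0.0, 0.0]))  # [0.0, 1.0, 0.0]
```

In the first example the weights sum to 0.875; the missing 0.125 is the probability that the mechanism moves past the last memory entry without attending to any of them. With 0/1 choosing probabilities, as in the second example, the result collapses to a one-hot hard alignment.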

safe_cumprod(...): Computes the cumulative product of x in log space, using cumsum, to avoid underflow.
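
The log-space trick replaces a running product with exp(cumsum(log(x))). A pure-Python sketch follows; it assumes inputs are probabilities in [0, 1] and clips them to [smallest positive normal float, 1] before taking the log, so that a zero entry yields a tiny value instead of log(0). The `exclusive` flag, which shifts the product so that the first output is 1, mirrors the option that `tf.math.cumsum` exposes; `py_safe_cumprod` is a hypothetical name.

```python
import math
import sys
from typing import List


def py_safe_cumprod(x: List[float], exclusive: bool = False) -> List[float]:
    """Cumulative product computed as exp(cumsum(log(clip(x))))."""
    tiny = sys.float_info.min  # smallest positive normal float64
    running = 0.0  # running sum of logs
    out = []
    for value in x:
        if exclusive:
            # Exclusive: output i is the product of x[0..i-1] (empty product = 1).
            out.append(math.exp(running))
        running += math.log(min(max(value, tiny), 1.0))
        if not exclusive:
            out.append(math.exp(running))
    return out


print(py_safe_cumprod([0.5, 0.5, 0.5]))
print(py_safe_cumprod([0.5, 0.5, 0.5], exclusive=True))
```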

sequence_loss(...): Computes the weighted cross-entropy loss for a sequence of logits.
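
As we understand the default reduction (averaging across both timesteps and batch), sequence_loss behaves like a weighted mean: multiply each position's cross-entropy by its weight, sum over batch and time, and divide by the sum of the weights, so zero-weight padding positions contribute nothing. The pure-Python sketch below assumes that default and uses nested lists shaped [batch][time][num_classes]; the function names are hypothetical, and the other reduction options of the real function are not covered.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """log(softmax(logits)) computed stably via the log-sum-exp trick."""
    top = max(logits)
    log_total = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_total for x in logits]


def py_sequence_loss(
    logits: List[List[List[float]]],
    targets: List[List[int]],
    weights: List[List[float]],
) -> float:
    """Return sum(weight * crossent) / sum(weights), or 0.0 if all weights are 0."""
    total = 0.0
    weight_sum = 0.0
    for seq_logits, seq_targets, seq_weights in zip(logits, targets, weights):
        for step_logits, target, weight in zip(seq_logits, seq_targets, seq_weights):
            total += weight * -log_softmax(step_logits)[target]
            weight_sum += weight
    return total / weight_sum if weight_sum else 0.0


# One sequence of two steps; the second step is padding (weight 0).
loss = py_sequence_loss(
    logits=[[[0.0, 0.0], [10.0, -10.0]]],
    targets=[[0, 1]],
    weights=[[1.0, 0.0]],
)
print(loss)  # cross-entropy of a uniform 2-way guess: log(2)
```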

tile_batch(...): Tiles the batch dimension of a (possibly nested structure of) tensor(s).
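
tile_batch repeats each batch entry contiguously, so [a, b] with multiplier 2 becomes [a, a, b, b] rather than the [a, b, a, b] that tiling the whole batch would give. That layout keeps each example's copies adjacent, which is why the BeamSearchDecoder documentation asks for encoder outputs to be tiled with tile_batch instead of tf.tile. The sketch below shows the semantics on nested Python lists, dicts, and tuples; it is an illustration, not the library function, `py_tile_batch` is a hypothetical name, and it treats any other value as a list whose first axis is the batch.

```python
from typing import Any


def py_tile_batch(structure: Any, multiplier: int) -> Any:
    """Repeat each batch entry `multiplier` times in place: [a, b] -> [a, a, b, b]."""
    if isinstance(structure, dict):
        return {key: py_tile_batch(value, multiplier) for key, value in structure.items()}
    if isinstance(structure, tuple):
        return tuple(py_tile_batch(value, multiplier) for value in structure)
    # Leaf: a list whose first axis is the batch. Repeated entries share
    # references rather than being copied.
    return [entry for entry in structure for _ in range(multiplier)]


encoder_state = {"memory": [[1, 2], [3, 4]], "lengths": [5, 7]}
print(py_tile_batch(encoder_state, 2))
# {'memory': [[1, 2], [1, 2], [3, 4], [3, 4]], 'lengths': [5, 5, 7, 7]}
```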

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License.

Last updated 2023-07-12 UTC.