tfm.nlp.ops.SamplingModule

Implementation for sampling strategies (go/decoding-tf-nlp).

tfm.nlp.ops.SamplingModule(
 symbols_to_logits_fn,
 vocab_size: int,
 max_decode_length: int,
 eos_id: int,
 padded_decode: bool,
 length_normalization_fn: Optional[Callable[[int, tf.DType], float]] = None,
 top_k=0,
 top_p=1.0,
 sample_temperature=0.0,
 enable_greedy: bool = True,
 dtype: tf.DType = tf.float32,
 decoding_name: Optional[str] = None,
 extra_cache_output: bool = False
)

Methods

generate

generate(
 initial_ids: tf.Tensor,
 initial_cache: Dict[str, tf.Tensor],
 initial_log_probs: Optional[tf.Tensor] = None
) -> Output

Implements the decoding strategy (beam_search or sampling).

Args
initial_ids: Initial ids to pass into symbols_to_logits_fn. An int tensor with shape [batch_size, 1].
initial_cache: Dictionary for caching model outputs from the previous step.
initial_log_probs: Optional initial log probabilities, used when decoding starts from a prefix sequence.

Returns
Tuple of tensors representing:
finished_sequence: shape [batch, max_seq_length]
finished_scores: shape [batch]
first_cache: the cache after the init token
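To make the shape of this contract concrete, here is a minimal pure-Python sketch of a greedy decoding loop. It is not the library's implementation: it uses nested lists instead of tensors, and the callback signature `(ids, step, cache) -> (logits, cache)`, the helper names `greedy_generate` and `toy_logits_fn`, and the `pad_id` argument are all assumptions made for illustration.

```python
import math
from typing import Callable, Dict, List, Tuple


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def greedy_generate(
    symbols_to_logits_fn: Callable,  # assumed: (ids, step, cache) -> (logits, cache)
    initial_ids: List[List[int]],    # shape [batch_size, 1]
    initial_cache: Dict[str, object],
    max_decode_length: int,
    eos_id: int,
    pad_id: int = 0,                 # hypothetical padding id
) -> Tuple[List[List[int]], List[float]]:
    """Greedy decoding sketch: pick the argmax token at every step."""
    batch = len(initial_ids)
    seqs = [list(row) for row in initial_ids]
    scores = [0.0] * batch
    finished = [False] * batch
    cache = dict(initial_cache)
    for step in range(max_decode_length):
        logits, cache = symbols_to_logits_fn(seqs, step, cache)
        for b in range(batch):
            if finished[b]:
                seqs[b].append(pad_id)  # pad rows that already emitted EOS
                continue
            log_probs = log_softmax(logits[b])
            next_id = max(range(len(log_probs)), key=log_probs.__getitem__)
            seqs[b].append(next_id)
            scores[b] += log_probs[next_id]  # accumulate sequence log-probability
            finished[b] = next_id == eos_id
    return seqs, scores


def toy_logits_fn(ids, step, cache):
    """Toy model over a 4-token vocabulary that favors (last_id + 1) % 4."""
    logits = []
    for row in ids:
        favored = (row[-1] + 1) % 4
        logits.append([5.0 if t == favored else 0.0 for t in range(4)])
    new_cache = {**cache, "steps": cache.get("steps", 0) + 1}
    return logits, new_cache


seqs, scores = greedy_generate(
    toy_logits_fn, [[0], [1]], {}, max_decode_length=4, eos_id=3
)
```

In this sketch every returned sequence has length 1 + max_decode_length, and each score is the summed log-probability of the tokens chosen before and including EOS.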

inf

inf()

Returns a value that is close to infinity but still finite in dtype.

This is useful for getting a very large value that still evaluates to zero when multiplied by zero. The floating-point "Inf" value yields NaN when multiplied by zero.

Returns
A very large value.
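The difference matters when a large value is multiplied by a 0/1 mask, for example to suppress selected logits. The following pure-Python sketch illustrates this; `LARGE` is a hypothetical stand-in for the value `inf()` returns (the actual value depends on dtype), and `mask_logits` is an illustrative helper, not part of the library.

```python
import math

# Hypothetical stand-in for the value returned by inf(); the real value depends on dtype.
LARGE = 1e7


def mask_logits(logits, keep, big):
    """Subtracts (1 - keep) * big, so entries with keep == 0 are pushed far down."""
    return [x - (1.0 - k) * big for x, k in zip(logits, keep)]


logits = [2.0, 0.5, -1.0]
keep = [1.0, 0.0, 1.0]

# With a large finite value, kept entries are unchanged because 0 * LARGE == 0.
finite = mask_logits(logits, keep, LARGE)

# With a true infinity, kept entries become NaN because 0 * inf is NaN.
with_inf = mask_logits(logits, keep, math.inf)
```

With the finite value, only the masked entry changes; with a true infinity, every kept entry is corrupted to NaN.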

Last updated 2024-02-02 UTC.