text.MaskValuesChooser

Assigns values to the items chosen for masking.

text.MaskValuesChooser(
 vocab_size, mask_token, mask_token_rate=0.8, random_token_rate=0.1
)

MaskValuesChooser encapsulates the logic for deciding the value to assign to items that were chosen for masking. The default implementation behaves as follows:

For mask_token_rate of the time, replace the item with the [MASK] token:

my dog is hairy -> my dog is [MASK]

For random_token_rate of the time, replace the item with a random word:

my dog is hairy -> my dog is apple

For 1 - mask_token_rate - random_token_rate of the time, keep the item unchanged:

my dog is hairy -> my dog is hairy

The default behavior is consistent with the "Masked LM and Masking Procedure" described in BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (https://arxiv.org/pdf/1810.04805.pdf).
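The three-way rule above can be sketched in plain Python. This is a hypothetical illustration, not the library's implementation: the function name choose_mask_values is made up, it works on a flat Python list instead of a RaggedTensor, and it assumes a random replacement is drawn uniformly from [0, vocab_size). A single uniform draw per item decides the branch, so the outcomes occur with probabilities mask_token_rate, random_token_rate, and 1 - mask_token_rate - random_token_rate.

```python
import random


def choose_mask_values(ids, vocab_size, mask_token,
                       mask_token_rate=0.8, random_token_rate=0.1, rng=None):
    """Hypothetical sketch of the default masking rule on a flat list of ids."""
    # Same constraints as the documented arguments.
    for rate in (mask_token_rate, random_token_rate):
        if not 0.0 <= rate <= 1.0:
            raise ValueError("rates must be between 0 and 1")
    if mask_token_rate + random_token_rate > 1.0:
        raise ValueError("mask_token_rate + random_token_rate must be <= 1")

    rng = rng if rng is not None else random.Random()
    values = []
    for token_id in ids:
        u = rng.random()  # one uniform draw in [0, 1) per selected item
        if u < mask_token_rate:
            values.append(mask_token)  # "my dog is [MASK]"
        elif u < mask_token_rate + random_token_rate:
            # Assumption: the random word is drawn uniformly from [0, vocab_size).
            values.append(rng.randrange(vocab_size))  # "my dog is apple"
        else:
            values.append(token_id)  # "my dog is hairy" (unchanged)
    return values


# Usage: ids 10..14 were selected for masking; 3 plays the role of [MASK].
print(choose_mask_values([10, 11, 12, 13, 14], vocab_size=100, mask_token=3,
                         rng=random.Random(0)))
```

Drawing one number per item, rather than two independent coin flips, keeps the three cases mutually exclusive and is why the two rates must sum to at most 1.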

Users may further customize this behavior by subclassing MaskValuesChooser and overriding get_mask_values().

Args

vocab_size The size of the vocabulary.
mask_token The id of the mask token.
mask_token_rate (optional) A float between 0 and 1 which indicates how often the mask_token is substituted for tokens selected for masking. Default is 0.8. NOTE: mask_token_rate + random_token_rate <= 1.
random_token_rate (optional) A float between 0 and 1 which indicates how often a random token is substituted for tokens selected for masking. Default is 0.1. NOTE: mask_token_rate + random_token_rate <= 1.

Attributes

mask_token

random_token_rate

vocab_size

Methods

get_mask_values

get_mask_values(
 masked_lm_ids
)

Get the values used for masking, random injection or no-op.

Args
masked_lm_ids a RaggedTensor of n dimensions and dtype int32 or int64 whose values are the ids of items that have been selected for masking.

Returns
a RaggedTensor with the same dtype and shape as masked_lm_ids, whose values are either the mask token, a randomly injected token, or the original value.

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.

Last updated 2025-04-11 UTC.