tfm.vision.MaskRCNNTask

A single-replica view of the training procedure.

Inherits From: Task

Main aliases

tfm.vision.maskrcnn.MaskRCNNTask

tfm.vision.MaskRCNNTask(
 params, logging_dir: Optional[str] = None, name: Optional[str] = None
)

The Mask R-CNN task provides the artifacts for training and evaluation procedures: loading and iterating over datasets, initializing the model, calculating the loss, post-processing, and computing customized metrics with reduction.

Args

params the task configuration instance, which can be a dataclass, ConfigDict, namedtuple, etc.
logging_dir a string pointing to where the model, summaries, etc. will be saved. Additional files may also be written to this directory.
name the task name.

Attributes

logging_dir

task_config

Methods

aggregate_logs

aggregate_logs(
 state: Optional[Any] = None, step_outputs: Optional[Dict[str, Any]] = None
) -> Optional[Any]

Optional aggregation over logs returned from a validation step.
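The aggregation contract (take the previous `state` and one step's `step_outputs`, return the updated state) can be illustrated with a framework-free sketch. It assumes the state is a plain dict of lists and that each step's outputs are a dict of scalars; the real MaskRCNNTask may carry a different state object, so this shows only the shape of the contract, not the library's implementation.

```python
from typing import Any, Dict, Optional


def aggregate_logs(state: Optional[Dict[str, list]] = None,
                   step_outputs: Optional[Dict[str, Any]] = None) -> Dict[str, list]:
    """Hypothetical stand-in: append each step's outputs to a running state."""
    if state is None:
        state = {}  # First validation step: start from an empty state.
    for key, value in (step_outputs or {}).items():
        state.setdefault(key, []).append(value)
    return state


# Simulate three validation steps, threading the state through each call.
state = None
for outputs in ({"num_detections": 3}, {"num_detections": 5}, {"num_detections": 4}):
    state = aggregate_logs(state, outputs)
```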

build_inputs

build_inputs(
 params: tfm.vision.configs.maskrcnn.DataConfig,
 input_context: Optional[tf.distribute.InputContext] = None,
 dataset_fn: Optional[dataset_fn_lib.PossibleDatasetType] = None
) -> tf.data.Dataset

Builds input dataset.
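When a distribution strategy supplies an `input_context`, input is commonly split across input pipelines and the global batch is divided among replicas. The sketch below reproduces those two calculations in plain Python, assuming the element selection of `tf.data.Dataset.shard(num_shards, index)` and the divisibility rule of `tf.distribute.InputContext.get_per_replica_batch_size`. Whether `build_inputs` shards by file or by example depends on the data config and is not modeled here; the helper names are hypothetical.

```python
from typing import List, Sequence


def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    # Hypothetical helper mirroring InputContext.get_per_replica_batch_size,
    # which rejects a global batch that does not split evenly across replicas.
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_replicas} replicas")
    return global_batch_size // num_replicas


def shard(examples: Sequence[int], num_shards: int, index: int) -> List[int]:
    # Same selection as tf.data.Dataset.shard(num_shards, index):
    # keep every num_shards-th element, starting at position `index`.
    return list(examples[index::num_shards])


examples = list(range(10))
pipeline_0 = shard(examples, num_shards=2, index=0)
pipeline_1 = shard(examples, num_shards=2, index=1)
batch = per_replica_batch_size(64, num_replicas=8)
```

Each pipeline sees a disjoint subset of the data, and together the pipelines cover every example exactly once.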

build_losses

build_losses(
 outputs: Mapping[str, Any],
 labels: Mapping[str, Any],
 aux_losses: Optional[Any] = None
) -> Dict[str, tf.Tensor]

Builds Mask R-CNN losses.
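Mask R-CNN trains several heads at once: the region proposal network (RPN), the box classification/regression head, and the mask head. A minimal sketch, assuming the overall loss is a weighted sum of per-head terms plus any `aux_losses` (for example, regularization). The component names, weights, and output keys below are illustrative assumptions, not the exact keys or defaults returned by `build_losses`.

```python
from typing import Dict, Optional, Sequence


def combine_losses(components: Dict[str, float],
                   weights: Dict[str, float],
                   aux_losses: Optional[Sequence[float]] = None) -> Dict[str, float]:
    """Hypothetical sketch: weighted sum of per-head losses plus auxiliary terms."""
    # Weighted sum of the per-head losses.
    model_loss = sum(weights[name] * value for name, value in components.items())
    # Auxiliary losses (e.g. regularization) are added unweighted in this sketch.
    total_loss = model_loss + sum(aux_losses or [])
    return {**components, "model_loss": model_loss, "total_loss": total_loss}


losses = combine_losses(
    components={"rpn_score_loss": 0.5, "rpn_box_loss": 0.25,
                "frcnn_cls_loss": 1.0, "frcnn_box_loss": 0.5, "mask_loss": 0.75},
    weights={"rpn_score_loss": 1.0, "rpn_box_loss": 2.0,
             "frcnn_cls_loss": 1.0, "frcnn_box_loss": 1.0, "mask_loss": 1.0},
    aux_losses=[0.1],
)
```

Returning the individual components alongside the totals lets each head's loss be logged separately.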

build_metrics

build_metrics(
 training: bool = True
)

Builds detection metrics.

build_model

build_model()

Builds Mask R-CNN model.

create_optimizer

@classmethod
create_optimizer(
 optimizer_config: tfm.optimization.OptimizationConfig,
 runtime_config: Optional[tfm.core.base_task.RuntimeConfig] = None,
 dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig] = None
)

Creates a TF optimizer from the given configurations.

Args
optimizer_config the parameters of the optimization settings.
runtime_config the parameters of the runtime.
dp_config the parameters of differential privacy.

Returns
A tf.optimizers.Optimizer object.

inference_step

inference_step(
 inputs, model: tf.keras.Model
)

Performs the forward step.

With distribution strategies, this method runs on devices.

Args
inputs a dictionary of input tensors.
model the keras.Model.

Returns
Model outputs.

initialize

initialize(
 model: tf.keras.Model
)

Loads pretrained checkpoint.

process_compiled_metrics

process_compiled_metrics(
 compiled_metrics, labels, model_outputs
)

Processes and updates compiled_metrics.

Called when using the compile/fit API.

Args
compiled_metrics the compiled metrics (model.compiled_metrics).
labels a tensor or a nested structure of tensors.
model_outputs a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model.

process_metrics

process_metrics(
 metrics, labels, model_outputs, **kwargs
)

Processes and updates metrics.

Called when using a custom training loop.

Args
metrics a nested structure of metrics objects, as returned by self.build_metrics.
labels a tensor or a nested structure of tensors.
model_outputs a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model.
**kwargs other args.
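In a custom training loop, metrics are stateful objects that are updated batch by batch and read once at the end. A minimal sketch of that pattern, assuming Keras-style `update_state`/`result` methods and a flat list standing in for the "nested structure". `MeanAbsoluteError` here is a hypothetical stand-in, not one of the detection metrics this task builds.

```python
from typing import List, Sequence


class MeanAbsoluteError:
    """Hypothetical Keras-style metric: accumulates state across update_state calls."""

    def __init__(self, name: str = "mae"):
        self.name = name
        self._total = 0.0
        self._count = 0

    def update_state(self, y_true: Sequence[float], y_pred: Sequence[float]) -> None:
        # Accumulate the running sum of absolute errors and the example count.
        for t, p in zip(y_true, y_pred):
            self._total += abs(t - p)
            self._count += 1

    def result(self) -> float:
        return self._total / self._count if self._count else 0.0


def process_metrics(metrics: List[MeanAbsoluteError], labels, model_outputs, **kwargs) -> None:
    # Update every metric with this batch; results are read later, not here.
    for metric in metrics:
        metric.update_state(labels, model_outputs)


metrics = [MeanAbsoluteError()]
process_metrics(metrics, labels=[1.0, 2.0], model_outputs=[1.5, 2.0])
process_metrics(metrics, labels=[3.0], model_outputs=[2.0])
```

Because the state accumulates across calls, `result()` reflects all batches seen so far rather than only the last one.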

reduce_aggregated_logs

reduce_aggregated_logs(
 aggregated_logs: Dict[str, Any], global_step: Optional[tf.Tensor] = None
) -> Dict[str, tf.Tensor]

Optional reduction of aggregated logs over validation steps.
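After every validation step has been aggregated, the reduction turns the accumulated state into final scalar results. This sketch assumes the aggregated state is a dict of per-step values and simply averages them. A detection task would more plausibly compute dataset-level metrics such as average precision at this point, which is not reproduced here.

```python
from typing import Dict, List, Optional


def reduce_aggregated_logs(aggregated_logs: Dict[str, List[float]],
                           global_step: Optional[int] = None) -> Dict[str, float]:
    """Hypothetical reduction: average each accumulated quantity over steps."""
    # global_step is accepted for signature parity but unused in this sketch.
    return {key: sum(values) / len(values)
            for key, values in aggregated_logs.items() if values}


summary = reduce_aggregated_logs({"num_detections": [3, 5, 4], "recall": [0.5, 0.75]})
```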

train_step

train_step(
 inputs: Tuple[Any, Any],
 model: tf.keras.Model,
 optimizer: tf.keras.optimizers.Optimizer,
 metrics: Optional[List[Any]] = None
)

Performs the forward and backward passes.

Args
inputs a tuple of (features, labels), matching the Tuple[Any, Any] annotation.
model the model, forward pass definition.
optimizer the optimizer for this training step.
metrics a nested structure of metrics objects.

Returns
A dictionary of logs.
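A training step runs a forward pass to compute the loss, a backward pass to compute gradients, and an optimizer update, then returns a dict of logs. The sketch below shows those three stages on a toy 1-D linear model with a hand-derived gradient and plain gradient descent, assuming `inputs` is a (features, labels) pair as the `Tuple[Any, Any]` annotation suggests. A TensorFlow implementation would typically record the forward pass under `tf.GradientTape` and call `optimizer.apply_gradients`. The model and update rule here are illustrative only.

```python
from typing import Dict, List, Tuple


def train_step(inputs: Tuple[List[float], List[float]],
               weights: Dict[str, float],
               learning_rate: float = 0.1) -> Dict[str, float]:
    """Hypothetical train step for a 1-D linear model y = w * x."""
    features, labels = inputs
    w = weights["w"]
    n = len(features)
    # Forward pass: predictions and mean squared error.
    errors = [w * x - y for x, y in zip(features, labels)]
    loss = sum(e * e for e in errors) / n
    # Backward pass: analytic gradient d(loss)/dw = mean(2 * (w*x - y) * x).
    grad = sum(2 * e * x for e, x in zip(errors, features)) / n
    # Optimizer update: plain gradient descent, applied in place.
    weights["w"] = w - learning_rate * grad
    return {"loss": loss}


weights = {"w": 0.0}
logs = train_step(([1.0, 2.0], [2.0, 4.0]), weights)
```

The returned loss is the value before the update, which is the usual convention for per-step training logs.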

validation_step

validation_step(
 inputs: Tuple[Any, Any],
 model: tf.keras.Model,
 metrics: Optional[List[Any]] = None
)

Validation step.

Args
inputs a tuple of (features, labels), matching the Tuple[Any, Any] annotation.
model the keras.Model.
metrics a nested structure of metrics objects.

Returns
A dictionary of logs.
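Unlike the training step, a validation step runs only the forward pass and leaves the model's weights untouched. A minimal sketch, assuming `inputs` is a (features, labels) pair and that each metric is a plain function of (labels, outputs). The model, loss, and `max_abs_error` metric are hypothetical placeholders, not the task's detection metrics.

```python
from typing import Callable, Dict, List, Optional, Tuple


def validation_step(inputs: Tuple[List[float], List[float]],
                    model: Callable[[float], float],
                    metrics: Optional[List[Callable]] = None) -> Dict[str, float]:
    """Hypothetical validation step: forward pass only, no weight update."""
    features, labels = inputs
    outputs = [model(x) for x in features]
    # Mean squared error as the logged loss.
    loss = sum((o - y) ** 2 for o, y in zip(outputs, labels)) / len(labels)
    logs = {"loss": loss}
    # Each metric function contributes one log entry keyed by its name.
    for metric in metrics or []:
        logs[metric.__name__] = metric(labels, outputs)
    return logs


def max_abs_error(y_true: List[float], y_pred: List[float]) -> float:
    return max(abs(t - p) for t, p in zip(y_true, y_pred))


logs = validation_step(([1.0, 2.0], [2.0, 5.0]), model=lambda x: 2.0 * x,
                       metrics=[max_abs_error])
```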

Class Variables

loss 'loss'

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.

Last updated 2024-02-02 UTC.