Apache Beam RunInference with TensorFlow and TensorFlow Hub


This notebook shows how to use the Apache Beam RunInference transform for TensorFlow with a trained model from TensorFlow Hub. Apache Beam includes built-in support for two TensorFlow model handlers: TFModelHandlerNumpy and TFModelHandlerTensor.

  • Use TFModelHandlerNumpy to run inference on models that expect a NumPy array as input.
  • Use TFModelHandlerTensor to run inference on models that expect a tensor as input.
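Both handlers take the path or URL of a saved model through the model_uri parameter. A minimal sketch of constructing each (the model path below is a hypothetical placeholder, not a real model):

from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor

# Hypothetical saved-model location; substitute your own model path or URL.
MODEL_URI = "gs://your-bucket/your_saved_model"

# Choose the handler that matches the input type your model expects.
numpy_handler = TFModelHandlerNumpy(model_uri=MODEL_URI)
tensor_handler = TFModelHandlerTensor(model_uri=MODEL_URI)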

For more information about using RunInference, see Get started with AI/ML pipelines in the Apache Beam documentation.

Before you begin

First, install TensorFlow. To use RunInference with the TensorFlow model handlers, also install Apache Beam version 2.46 or later.

pip install tensorflow
pip install apache_beam[interactive]==2.46.0

Use TensorFlow Hub's trained model URL

To use a trained model from TensorFlow Hub, pass the model URL to the model_uri parameter of the model handler class.

import tensorflow as tf
import tensorflow_hub as hub
import apache_beam as beam

# URL of the trained model from TensorFlow Hub.
CLASSIFIER_URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"

import numpy as np
import PIL.Image as Image

# The MobileNet v2 classifier expects 224x224 RGB images.
IMAGE_RES = 224

# Download a sample image and resize it to the model's input resolution.
img = tf.keras.utils.get_file(origin='https://storage.googleapis.com/apache-beam-samples/image_captioning/Cat-with-beanie.jpg')
img = Image.open(img).resize((IMAGE_RES, IMAGE_RES))
img
Downloading data from https://storage.googleapis.com/apache-beam-samples/image_captioning/Cat-with-beanie.jpg
1812110/1812110 [==============================] - 0s 0us/step

[Image output: the downloaded cat-with-beanie photo, resized to 224x224]

# Convert the input image to the type and dimensions required by the model.
img = np.array(img) / 255.0
img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)

from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from typing import Iterable

model_handler = TFModelHandlerTensor(model_uri=CLASSIFIER_URL)

class PostProcessor(beam.DoFn):
  """Processes the PredictionResult to get the predicted label.

  Yields the predicted label.
  """
  def setup(self):
    # Download the ImageNet labels once per worker.
    labels_path = tf.keras.utils.get_file(
        'ImageNetLabels.txt',
        'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
    self._imagenet_labels = np.array(open(labels_path).read().splitlines())

  def process(self, element: PredictionResult) -> Iterable[str]:
    # Pick the class with the highest score and look up its label.
    predicted_class = np.argmax(element.inference)
    predicted_class_name = self._imagenet_labels[predicted_class]
    yield "Predicted Label: {}".format(predicted_class_name.title())

with beam.Pipeline() as p:
  _ = (p
       | "Create PCollection" >> beam.Create([img_tensor])
       | "Perform inference" >> RunInference(model_handler)
       | "Post Processing" >> beam.ParDo(PostProcessor())
       | "Print" >> beam.Map(print))
Predicted Label: Tiger Cat
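
Because the preprocessed image is already available as the NumPy array img, the same pipeline can be expressed with TFModelHandlerNumpy instead. The following is a minimal sketch, assuming the hub model also accepts a batched float32 NumPy array:

from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy

# Reuse the normalized image as a float32 NumPy array instead of a tensor.
numpy_model_handler = TFModelHandlerNumpy(model_uri=CLASSIFIER_URL)

with beam.Pipeline() as p:
  _ = (p
       | "Create PCollection" >> beam.Create([img.astype(np.float32)])
       | "Perform inference" >> RunInference(numpy_model_handler)
       | "Post Processing" >> beam.ParDo(PostProcessor())
       | "Print" >> beam.Map(print))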
