Object Detection

View on TensorFlow.org Run in Google Colab View on GitHub Download notebook See TF Hub models

This Colab demonstrates use of a TF-Hub module trained to perform object detection.

Setup

Imports and function definitions

# For running inference on the TF-Hub module.
import tensorflow as tf
import tensorflow_hub as hub

# For downloading the image.
import matplotlib.pyplot as plt
import tempfile
from six.moves.urllib.request import urlopen
from six import BytesIO

# For drawing onto the image.
import numpy as np
from PIL import Image
from PIL import ImageColor
from PIL import ImageDraw
from PIL import ImageFont
from PIL import ImageOps

# For measuring the inference time.
import time

# Print the TensorFlow version.
print(tf.__version__)

# Check available GPU devices.
print("The following GPU devices are available: %s" % tf.test.gpu_device_name())

2.16.1
The following GPU devices are available:
2024-03-09 13:48:48.238338: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Example use

Helper functions for downloading images and for visualization.

Visualization code adapted from TF object detection API for the simplest required functionality.

def display_image(image):
    """Show `image` on a new 20x15 matplotlib figure with the grid turned off.

    Args:
        image: Image data accepted by `plt.imshow`; callers in this file pass
            a PIL image or a numpy array.
    """
    # A fresh figure is created for every call; the handle itself is not needed.
    plt.figure(figsize=(20, 15))
    plt.grid(False)
    plt.imshow(image)
def download_and_resize_image(url, new_width=256, new_height=256,
                              display=False):
    """Download an image, fit it to the requested size, and save it as a JPEG.

    Args:
        url: Address of the image to fetch.
        new_width: Width of the saved image in pixels.
        new_height: Height of the saved image in pixels.
        display: If true, show the resized image via `display_image`.

    Returns:
        Path of the temporary ".jpg" file that was written. The file is not
        removed by this function.
    """
    # Only the path is needed. Close the handle immediately so that no file
    # descriptor is leaked; `delete=False` keeps the file on disk.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
        filename = tmp_file.name

    response = urlopen(url)
    try:
        image_data = BytesIO(response.read())
    finally:
        # Release the connection even if reading fails.
        response.close()

    pil_image = Image.open(image_data)
    # ImageOps.fit resizes and crops so the result is exactly the target size.
    pil_image = ImageOps.fit(pil_image, (new_width, new_height), Image.LANCZOS)
    # JPEG cannot store every PIL mode, so save an RGB copy.
    pil_image_rgb = pil_image.convert("RGB")
    pil_image_rgb.save(filename, format="JPEG", quality=90)
    print("Image downloaded to %s." % filename)
    if display:
        display_image(pil_image)
    return filename
def draw_bounding_box_on_image(image,
                               ymin,
                               xmin,
                               ymax,
                               xmax,
                               color,
                               font,
                               thickness=4,
                               display_str_list=()):
    """Adds a bounding box, with optional text labels, to an image.

    The box is drawn directly onto `image`; nothing is returned.

    Args:
        image: PIL image to draw on.
        ymin: Top edge of the box as a fraction of the image height.
        xmin: Left edge of the box as a fraction of the image width.
        ymax: Bottom edge of the box as a fraction of the image height.
        xmax: Right edge of the box as a fraction of the image width.
        color: Fill color for the box outline and the label backgrounds.
        font: PIL font used to measure and render the labels.
        thickness: Width of the box outline.
        display_str_list: Strings to render as labels, stacked top to bottom
            in list order.
    """
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size
    # Scale the fractional coordinates to pixel positions.
    (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                  ymin * im_height, ymax * im_height)
    draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
               (left, top)],
              width=thickness,
              fill=color)

    # If the stacked labels would not fit between the top of the image and the
    # top of the box, place them just inside the box's top edge instead of
    # above it.
    display_str_heights = [font.getbbox(ds)[3] for ds in display_str_list]
    # Each display_str has a top and bottom margin of 0.05x.
    total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)

    if top > total_display_str_height:
        text_bottom = top
    else:
        text_bottom = top + total_display_str_height

    # Reverse list and print from bottom to top.
    for display_str in display_str_list[::-1]:
        bbox = font.getbbox(display_str)
        text_width, text_height = bbox[2], bbox[3]
        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                        (left + text_width, text_bottom)],
                       fill=color)
        draw.text((left + margin, text_bottom - text_height - margin),
                  display_str,
                  fill="black",
                  font=font)
        # Move up by the full height of the label background (text plus both
        # margins) so the next label sits directly above it without overlap.
        text_bottom -= text_height + 2 * margin
def draw_boxes(image, boxes, class_names, scores, max_boxes=10, min_score=0.1):
    """Overlay labeled boxes on an image with formatted scores and label names.

    Only the first `max_boxes` entries are considered, and of those only the
    ones whose score is at least `min_score` are drawn.

    Args:
        image: Numpy image array; it is modified in place when a box is drawn.
        boxes: Array whose rows are (ymin, xmin, ymax, xmax), passed through
            to `draw_bounding_box_on_image`.
        class_names: Per-box class names as bytes (decoded as ASCII).
        scores: Per-box scores; shown in the label as a whole percentage.
        max_boxes: Number of leading entries to examine.
        min_score: Minimum score for a box to be drawn.

    Returns:
        The same `image` array that was passed in.
    """
    colors = list(ImageColor.colormap.values())

    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf",
            25)
    except IOError:
        print("Font not found, using default font.")
        font = ImageFont.load_default()

    # The PIL copy is created lazily and shared by all boxes, so the array is
    # converted once and copied back once instead of once per box. When no box
    # qualifies, `image` is left untouched.
    image_pil = None
    for i in range(min(boxes.shape[0], max_boxes)):
        if scores[i] < min_score:
            continue
        if image_pil is None:
            image_pil = Image.fromarray(np.uint8(image)).convert("RGB")
        ymin, xmin, ymax, xmax = tuple(boxes[i])
        display_str = "{}: {}%".format(class_names[i].decode("ascii"),
                                       int(100 * scores[i]))
        # NOTE: hash() of bytes is salted per process unless PYTHONHASHSEED is
        # set, so the color chosen for a class can differ between runs.
        color = colors[hash(class_names[i]) % len(colors)]
        draw_bounding_box_on_image(
            image_pil,
            ymin,
            xmin,
            ymax,
            xmax,
            color,
            font,
            display_str_list=[display_str])

    if image_pil is not None:
        np.copyto(image, np.array(image_pil))
    return image

Apply module

Load a public image from Open Images v4, save locally, and display.

# By Heiko Gorski, Source: https://commons.wikimedia.org/wiki/File:Naxos_Taverna.jpg
image_url = "https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg"
# Fetch the image, fit it to 1280x856, save it locally, and display it.
downloaded_image_path = download_and_resize_image(
    image_url, new_width=1280, new_height=856, display=True)
Image downloaded to /tmpfs/tmp/tmpxk3tpk5k.jpg.

png

Pick an object detection module and apply it to the downloaded image. Modules:

  • FasterRCNN+InceptionResNet V2: high accuracy,
  • ssd+mobilenet V2: small and fast.
# TF Hub handle of the FasterRCNN+InceptionResNet V2 module (Open Images v4).
module_handle = "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"
# Load the module and keep its 'default' signature as the callable detector.
detector = hub.load(module_handle).signatures['default']
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
def load_img(path):
    """Read the file at `path` and decode it as a 3-channel JPEG tensor."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_jpeg(raw_bytes, channels=3)
def run_detector(detector, path):
    """Run `detector` on the JPEG at `path`, report timing, and show the result.

    Prints the number of returned detections and the wall-clock time of the
    detector call, then displays the image with the detected boxes drawn on it.

    Args:
        detector: Callable that takes the batched float32 image tensor and
            returns a dict of tensors.
        path: Location of the JPEG file to analyse.
    """
    img = load_img(path)

    # Convert to float32 and add a leading batch axis.
    batched_input = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]

    # Time only the detector call itself.
    started = time.time()
    raw_result = detector(batched_input)
    elapsed = time.time() - started

    # Convert every output tensor to a numpy array.
    result = {}
    for key, value in raw_result.items():
        result[key] = value.numpy()

    print("Found %d objects." % len(result["detection_scores"]))
    print("Inference time: ", elapsed)

    image_with_boxes = draw_boxes(
        img.numpy(),
        result["detection_boxes"],
        result["detection_class_entities"],
        result["detection_scores"])
    display_image(image_with_boxes)
# Detect objects in the image downloaded above and display the annotated result.
run_detector(detector, downloaded_image_path)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709992227.984468 76125 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: -2484 } dim { size: -2485 } dim { size: -2486 } dim { size: 1088 } } } inputs { dtype: DT_FLOAT shape { dim { size: -105 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -105 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } value { dtype: DT_INT32 tensor_shape { dim { size: 2 } } int_val: 17 } } device { type: "CPU" vendor: "GenuineIntel" model: "111" frequency: 2299 num_cores: 32 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 262144 l3_cache_size: 47185920 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -105 } dim { size: 17 } dim { size: 17 } dim { size: 1088 } } }
Found 100 objects.
Inference time: 39.93564057350159

png

More images

Perform inference on some additional images with time tracking.

# Additional Wikimedia Commons images to run the detector on; each URL is
# preceded by a comment giving its source page.
image_urls = [
 # Source: https://commons.wikimedia.org/wiki/File:The_Coleoptera_of_the_British_islands_(Plate_125)_(8592917784).jpg
 "https://upload.wikimedia.org/wikipedia/commons/1/1b/The_Coleoptera_of_the_British_islands_%28Plate_125%29_%288592917784%29.jpg",
 # By Américo Toledano, Source: https://commons.wikimedia.org/wiki/File:Biblioteca_Maim%C3%B3nides,_Campus_Universitario_de_Rabanales_007.jpg
 "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0d/Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg/1024px-Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg",
 # Source: https://commons.wikimedia.org/wiki/File:The_smaller_British_birds_(8053836633).jpg
 "https://upload.wikimedia.org/wikipedia/commons/0/09/The_smaller_British_birds_%288053836633%29.jpg",
 ]
def detect_img(image_url):
    """Download the image at `image_url`, run the detector on it, and time it.

    The image is fitted to 640x480 before detection. The time printed here
    spans the download, resize, detection, and drawing steps, even though it
    is reported under the same "Inference time:" label that `run_detector`
    uses for the detector call alone.

    Args:
        image_url: Address of the image to process.
    """
    started = time.time()
    run_detector(detector, download_and_resize_image(image_url, 640, 480))
    finished = time.time()
    print("Inference time:", finished - started)
# Run detection on the first additional image.
detect_img(image_urls[0])
Image downloaded to /tmpfs/tmp/tmp1ym56ptn.jpg.
Found 100 objects.
Inference time: 2.6752500534057617
Inference time: 2.8970775604248047

png

# Run detection on the second additional image.
detect_img(image_urls[1])
Image downloaded to /tmpfs/tmp/tmpkkemjnhv.jpg.
Found 100 objects.
Inference time: 2.8261632919311523
Inference time: 3.0437986850738525

png

# Run detection on the third additional image.
detect_img(image_urls[2])
Image downloaded to /tmpfs/tmp/tmpkpi72oyk.jpg.
Found 100 objects.
Inference time: 2.6092212200164795
Inference time: 2.888493776321411

png

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.

Last updated 2024-03-09 UTC.