Commit 4fe6034: Updated examples
Parent: caf04f5
24 files changed, +1680 −507 lines

docs/examples/plot_object_detection_simple.py renamed to docs/examples/plot_object_detection_checkpoint.py
92 additions, 99 deletions
@@ -1,70 +1,54 @@
 #!/usr/bin/env python
 # coding: utf-8
 """
-Object Detection Test
-=====================
+Object Detection From TF2 Checkpoint
+====================================
 """
 
 # %%
-# This demo will take you through the steps of running an "out-of-the-box" detection model on a
-# collection of images.
-
-# %%
-# Create the data directory
-# ~~~~~~~~~~~~~~~~~~~~~~~~~
-# The snippet shown below will create the ``data`` directory where all our data will be stored. The
-# code will create a directory structure as shown below:
-#
-# .. code-block:: bash
-#
-#    data
-#    ├── images
-#    └── models
-#
-# where the ``images`` folder will contain the downloaded test images, while ``models`` will
-# contain the downloaded models.
-import os
-
-DATA_DIR = os.path.join(os.getcwd(), 'data')
-IMAGES_DIR = os.path.join(DATA_DIR, 'images')
-MODELS_DIR = os.path.join(DATA_DIR, 'models')
-for dir in [DATA_DIR, IMAGES_DIR, MODELS_DIR]:
-    if not os.path.exists(dir):
-        os.mkdir(dir)
+# This demo will take you through the steps of running an "out-of-the-box" TensorFlow 2 compatible
+# detection model on a collection of images. More specifically, in this example we will be using
+# the `Checkpoint Format <https://www.tensorflow.org/guide/checkpoint>`__ to load the model.
 
 # %%
 # Download the test images
 # ~~~~~~~~~~~~~~~~~~~~~~~~
 # First we will download the images that we will use throughout this tutorial. The code snippet
 # shown below will download the test images from the `TensorFlow Model Garden <https://github.com/tensorflow/models/tree/master/research/object_detection/test_images>`_
 # and save them inside the ``data/images`` folder.
-import urllib.request
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # Suppress TensorFlow logging (1)
+import pathlib
+import tensorflow as tf
 
-IMAGE_FILENAMES = ['image1.jpg', 'image2.jpg']
-IMAGES_DOWNLOAD_BASE = \
-    'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/test_images/'
+tf.get_logger().setLevel('ERROR')           # Suppress TensorFlow logging (2)
 
-for image_filename in IMAGE_FILENAMES:
+# Enable GPU dynamic memory allocation
+gpus = tf.config.experimental.list_physical_devices('GPU')
+for gpu in gpus:
+    tf.config.experimental.set_memory_growth(gpu, True)
 
-    image_path = os.path.join(IMAGES_DIR, image_filename)
+def download_images():
+    base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/test_images/'
+    filenames = ['image1.jpg', 'image2.jpg']
+    image_paths = []
+    for filename in filenames:
+        image_path = tf.keras.utils.get_file(fname=filename,
+                                             origin=base_url + filename,
+                                             untar=False)
+        image_path = pathlib.Path(image_path)
+        image_paths.append(str(image_path))
+    return image_paths
 
-    # Download image
-    if not os.path.exists(image_path):
-        print('Downloading {}... '.format(image_filename), end='')
-        urllib.request.urlretrieve(IMAGES_DOWNLOAD_BASE + image_filename, image_path)
-        print('Done')
+IMAGE_PATHS = download_images()
 
 
 # %%
 # Download the model
 # ~~~~~~~~~~~~~~~~~~
-# The code snippet shown below is used to download the object detection model checkpoint file,
-# as well as the labels file (.pbtxt) which contains a list of strings used to add the correct
-# label to each detection (e.g. person). Once downloaded the files will be stored under the
-# ``data/models`` folder.
-#
-# The particular detection algorithm we will use is the `CenterNet HourGlass104 1024x1024`. More
-# models can be found in the `TensorFlow 2 Detection Model Zoo <https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md>`_.
+# The code snippet shown below is used to download the pre-trained object detection model we shall
+# use to perform inference. The particular detection algorithm we will use is the
+# `CenterNet HourGlass104 1024x1024`. More models can be found in the `TensorFlow 2 Detection Model Zoo <https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md>`_.
 # To use a different model you will need the URL name of the specific model. This can be done as
 # follows:
 #
@@ -76,62 +60,63 @@
 #
 # For example, the download link for the model used below is: ``download.tensorflow.org/models/object_detection/tf2/20200711/centernet_hg104_1024x1024_coco17_tpu-32.tar.gz``
 
-import tarfile
-
 # Download and extract model
+def download_model(model_name, model_date):
+    base_url = 'http://download.tensorflow.org/models/object_detection/tf2/'
+    model_file = model_name + '.tar.gz'
+    model_dir = tf.keras.utils.get_file(fname=model_name,
+                                        origin=base_url + model_date + '/' + model_file,
+                                        untar=True)
+    return str(model_dir)
+
 MODEL_DATE = '20200711'
 MODEL_NAME = 'centernet_hg104_1024x1024_coco17_tpu-32'
-MODEL_TAR_FILENAME = MODEL_NAME + '.tar.gz'
-MODELS_DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/tf2/'
-MODEL_DOWNLOAD_LINK = MODELS_DOWNLOAD_BASE + MODEL_DATE + '/' + MODEL_TAR_FILENAME
-PATH_TO_MODEL_TAR = os.path.join(MODELS_DIR, MODEL_TAR_FILENAME)
-PATH_TO_CKPT = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, 'checkpoint/'))
-PATH_TO_CFG = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, 'pipeline.config'))
-if not os.path.exists(PATH_TO_CKPT):
-    print('Downloading model. This may take a while... ', end='')
-    urllib.request.urlretrieve(MODEL_DOWNLOAD_LINK, PATH_TO_MODEL_TAR)
-    tar_file = tarfile.open(PATH_TO_MODEL_TAR)
-    tar_file.extractall(MODELS_DIR)
-    tar_file.close()
-    os.remove(PATH_TO_MODEL_TAR)
-    print('Done')
+PATH_TO_MODEL_DIR = download_model(MODEL_NAME, MODEL_DATE)
+
+# %%
+# Download the labels
+# ~~~~~~~~~~~~~~~~~~~
+# The code snippet shown below is used to download the labels file (.pbtxt) which contains a list
+# of strings used to add the correct label to each detection (e.g. person). Since the pre-trained
+# model we will use has been trained on the COCO dataset, we will need to download the labels file
+# corresponding to this dataset, named ``mscoco_label_map.pbtxt``. A full list of the labels files
+# included in the TensorFlow Model Garden can be found `here <https://github.com/tensorflow/models/tree/master/research/object_detection/data>`__.
 
 # Download labels file
+def download_labels(filename):
+    base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/'
+    label_dir = tf.keras.utils.get_file(fname=filename,
+                                        origin=base_url + filename,
+                                        untar=False)
+    label_dir = pathlib.Path(label_dir)
+    return str(label_dir)
+
 LABEL_FILENAME = 'mscoco_label_map.pbtxt'
-LABELS_DOWNLOAD_BASE = \
-    'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/'
-PATH_TO_LABELS = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, LABEL_FILENAME))
-if not os.path.exists(PATH_TO_LABELS):
-    print('Downloading label file... ', end='')
-    urllib.request.urlretrieve(LABELS_DOWNLOAD_BASE + LABEL_FILENAME, PATH_TO_LABELS)
-    print('Done')
+PATH_TO_LABELS = download_labels(LABEL_FILENAME)
 
 # %%
 # Load the model
 # ~~~~~~~~~~~~~~
 # Next we load the downloaded model
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # Suppress TensorFlow logging (1)
-import tensorflow as tf
+import time
 from object_detection.utils import label_map_util
 from object_detection.utils import config_util
 from object_detection.utils import visualization_utils as viz_utils
 from object_detection.builders import model_builder
 
-tf.get_logger().setLevel('ERROR')           # Suppress TensorFlow logging (2)
+PATH_TO_CFG = PATH_TO_MODEL_DIR + "/pipeline.config"
+PATH_TO_CKPT = PATH_TO_MODEL_DIR + "/checkpoint"
 
-# Enable GPU dynamic memory allocation
-gpus = tf.config.experimental.list_physical_devices('GPU')
-for gpu in gpus:
-    tf.config.experimental.set_memory_growth(gpu, True)
+print('Loading model... ', end='')
+start_time = time.time()
 
 # Load pipeline config and build a detection model
 configs = config_util.get_configs_from_pipeline_file(PATH_TO_CFG)
 model_config = configs['model']
 detection_model = model_builder.build(model_config=model_config, is_training=False)
 
 # Restore checkpoint
-ckpt = tf.compat.v2.train.Checkpoint(
-    model=detection_model)
+ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
 ckpt.restore(os.path.join(PATH_TO_CKPT, 'ckpt-0')).expect_partial()
 
 @tf.function
@@ -142,8 +127,11 @@ def detect_fn(image):
     prediction_dict = detection_model.predict(image, shapes)
     detections = detection_model.postprocess(prediction_dict, shapes)
 
-    return detections, prediction_dict, tf.reshape(shapes, [-1])
+    return detections
 
+end_time = time.time()
+elapsed_time = end_time - start_time
+print('Done! Took {} seconds'.format(elapsed_time))
 
 # %%
 # Load label map data (for plotting)
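
Note: after this hunk, ``detect_fn`` returns only the post-processed ``detections`` dictionary; the prediction dict and reshaped shapes are no longer passed back to the caller. The sketch below is a hypothetical smoke test (not part of the commit) that runs the rewritten ``detect_fn`` once on a dummy batch, assuming ``detection_model`` has been built and restored as in the hunks above; the dummy image size is an arbitrary assumption.

# Hypothetical smoke test (not in the commit): call detect_fn once on a dummy
# batch to trigger the tf.function trace and inspect the returned dictionary.
import numpy as np
import tensorflow as tf

dummy_image = np.zeros((1, 640, 480, 3), dtype=np.float32)   # arbitrary size; preprocess() resizes
dummy_detections = detect_fn(tf.convert_to_tensor(dummy_image))

# Every entry is batched along axis 0; 'num_detections' reports how many of the
# (padded) boxes are valid for each image in the batch, as the inference loop
# later in this diff relies on.
print(sorted(dummy_detections.keys()))
print(int(dummy_detections['num_detections'][0]))
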
@@ -172,7 +160,6 @@ def detect_fn(image):
 # * Print out `detections['detection_boxes']` and try to match the box locations to the boxes in the image. Notice that coordinates are given in normalized form (i.e., in the interval [0, 1]).
 # * Set ``min_score_thresh`` to other values (between 0 and 1) to allow more detections in or to filter out more detections.
 import numpy as np
-from six import BytesIO
 from PIL import Image
 import matplotlib.pyplot as plt
 import warnings
@@ -191,18 +178,13 @@ def load_image_into_numpy_array(path):
     Returns:
       uint8 numpy array with shape (img_height, img_width, 3)
     """
-    img_data = tf.io.gfile.GFile(path, 'rb').read()
-    image = Image.open(BytesIO(img_data))
-    (im_width, im_height) = image.size
-    return np.array(image.getdata()).reshape(
-        (im_height, im_width, 3)).astype(np.uint8)
+    return np.array(Image.open(path))
 
 
-for image_filename in IMAGE_FILENAMES:
+for image_path in IMAGE_PATHS:
 
-    print('Running inference for {}... '.format(image_filename), end='')
+    print('Running inference for {}... '.format(image_path), end='')
 
-    image_path = os.path.join(IMAGES_DIR, image_filename)
     image_np = load_image_into_numpy_array(image_path)
 
     # Things to try:
@@ -213,23 +195,34 @@ def load_image_into_numpy_array(path):
     # image_np = np.tile(
     #     np.mean(image_np, 2, keepdims=True), (1, 1, 3)).astype(np.uint8)
 
-    input_tensor = tf.convert_to_tensor(
-        np.expand_dims(image_np, 0), dtype=tf.float32)
-    detections, predictions_dict, shapes = detect_fn(input_tensor)
+    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
+
+    detections = detect_fn(input_tensor)
+
+    # All outputs are batch tensors.
+    # Convert to numpy arrays, and take index [0] to remove the batch dimension.
+    # We're only interested in the first num_detections.
+    num_detections = int(detections.pop('num_detections'))
+    detections = {key: value[0, :num_detections].numpy()
+                  for key, value in detections.items()}
+    detections['num_detections'] = num_detections
+
+    # detection_classes should be ints.
+    detections['detection_classes'] = detections['detection_classes'].astype(np.int64)
 
     label_id_offset = 1
     image_np_with_detections = image_np.copy()
 
     viz_utils.visualize_boxes_and_labels_on_image_array(
-        image_np_with_detections,
-        detections['detection_boxes'][0].numpy(),
-        (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
-        detections['detection_scores'][0].numpy(),
-        category_index,
-        use_normalized_coordinates=True,
-        max_boxes_to_draw=200,
-        min_score_thresh=.30,
-        agnostic_mode=False)
+          image_np_with_detections,
+          detections['detection_boxes'],
+          detections['detection_classes'] + label_id_offset,
+          detections['detection_scores'],
+          category_index,
+          use_normalized_coordinates=True,
+          max_boxes_to_draw=200,
+          min_score_thresh=.30,
+          agnostic_mode=False)
 
     plt.figure()
     plt.imshow(image_np_with_detections)
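
Note: the loop shown in this hunk only draws each result into a Matplotlib figure, so on a headless machine (a CI runner or a remote GPU box, for example) nothing is displayed. A small, hypothetical addition such as the sketch below, reusing ``image_path`` and ``image_np_with_detections`` from the loop body, writes each visualization to disk instead; the output filename scheme is an assumption and is not part of the commit.

# Hypothetical follow-up (not in the commit): persist each visualized frame
# to disk so results can be inspected without a display server.
import os
import matplotlib.pyplot as plt

out_name = os.path.splitext(os.path.basename(image_path))[0] + '_detections.png'
plt.figure(figsize=(12, 8))
plt.imshow(image_np_with_detections)
plt.axis('off')
plt.savefig(out_name, bbox_inches='tight')   # assumed naming: <image>_detections.png
plt.close()
print('Saved {}'.format(out_name))
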
