Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 4915c84

Browse files
fixed broken url in notebook, added trained model and file to convert from coco to tfrecord
1 parent d9c1e66 commit 4915c84

File tree

4 files changed

+264
-2
lines changed

4 files changed

+264
-2
lines changed

‎.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.idea

‎Tensorflow_2_Object_Detection_Train_model.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434
"metadata": {},
3535
"source": [
3636
"<table align=\"left\"><td>\n",
37-
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/TannerGilbert/Tutorials/blob/master/Tensorflow-Object-Detection-API-Train-Model/Tensorflow_2_Object_Detection_Train_model\">\n",
37+
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/TannerGilbert/Tensorflow-Object-Detection-API-Train-Model/blob/master/Tensorflow_2_Object_Detection_Train_model.ipynb\">\n",
3838
" <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab\n",
3939
" </a>\n",
4040
"</td><td>\n",
41-
" <a target=\"_blank\" href=\"https://github.com/TannerGilbert/Tutorials/blob/master/Tensorflow-Object-Detection-API-Train-Model/Tensorflow_2_Object_Detection_Train_model\">\n",
41+
" <a target=\"_blank\" href=\"https://github.com/TannerGilbert/Tensorflow-Object-Detection-API-Train-Model/blob/master/Tensorflow_2_Object_Detection_Train_model.ipynb\">\n",
4242
" <img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
4343
"</td></table>"
4444
]

‎create_coco_tf_record.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
r"""Convert raw COCO dataset to TFRecord for object_detection.
17+
18+
Please note that this tool writes one unsharded TFRecord file per split
(train.record and test.record) into the output directory.
19+
20+
Example usage:
21+
python create_coco_tf_record.py --logtostderr \
22+
--train_image_dir="${TRAIN_IMAGE_DIR}" \
23+
--test_image_dir="${TEST_IMAGE_DIR}" \
24+
--train_annotations_file="${TRAIN_ANNOTATIONS_FILE}" \
25+
--test_annotations_file="${TEST_ANNOTATIONS_FILE}" \
26+
--output_dir="${OUTPUT_DIR}"
27+
"""
28+
from __future__ import absolute_import
29+
from __future__ import division
30+
from __future__ import print_function
31+
32+
import hashlib
33+
import io
34+
import json
35+
import os
36+
import contextlib2
37+
import numpy as np
38+
import PIL.Image
39+
40+
from pycocotools import mask
41+
import tensorflow as tf
42+
43+
from object_detection.dataset_tools import tf_record_creation_util
44+
from object_detection.utils import dataset_util
45+
from object_detection.utils import label_map_util
46+
47+
48+
# Command-line flags. All definitions go through the single `flags` alias so
# the module refers to the flags package in one consistent way.
flags = tf.app.flags

flags.DEFINE_boolean(
    'include_masks',
    False,
    'Whether to include instance segmentations masks '
    '(PNG encoded) in the result. default: False.')
flags.DEFINE_string(
    'train_image_dir', '', 'Training image directory.')
flags.DEFINE_string(
    'test_image_dir', '', 'Test image directory.')
flags.DEFINE_string(
    'train_annotations_file', '', 'Training annotations JSON file.')
flags.DEFINE_string(
    'test_annotations_file', '', 'Test-dev annotations JSON file.')
flags.DEFINE_string(
    'output_dir', '/tmp/', 'Output data directory.')

# Parsed flag values; populated when tf.app.run() parses the command line.
FLAGS = flags.FLAGS

# Emit the INFO-level progress messages logged during conversion.
tf.logging.set_verbosity(tf.logging.INFO)
65+
66+
67+
def create_tf_example(image,
                      annotations_list,
                      image_dir,
                      category_index,
                      include_masks=False):
  """Converts one COCO image and its annotations to a tf.Example proto.

  Args:
    image: dict with keys:
      [u'license', u'file_name', u'coco_url', u'height', u'width',
      u'date_captured', u'flickr_url', u'id']
      Only 'height', 'width', 'file_name' and 'id' are read here.
    annotations_list:
      list of dicts with keys:
      [u'segmentation', u'area', u'iscrowd', u'image_id',
      u'bbox', u'category_id', u'id']
      Notice that bounding box coordinates in the official COCO dataset are
      given as [x, y, width, height] tuples using absolute coordinates where
      x, y represent the top-left (0-indexed) corner.  This function converts
      to the format expected by the Tensorflow Object Detection API (which is
      [ymin, xmin, ymax, xmax] with coordinates normalized relative to image
      size).
    image_dir: directory containing the image files.
    category_index: a dict containing COCO category information keyed
      by the 'id' field of each category.  See the
      label_map_util.create_category_index function.
    include_masks: Whether to include instance segmentations masks
      (PNG encoded) in the result. default: False.

  Returns:
    key: hex SHA-256 digest of the raw image file bytes.
    example: The converted tf.Example.
    num_annotations_skipped: Number of (invalid) annotations that were ignored.

  Raises:
    No exception is raised explicitly by this function. Errors from reading
    the file or from PIL.Image.open (when the bytes cannot be opened as an
    image) propagate to the caller. NOTE(review): the image format is not
    checked, yet 'image/format' is always written as 'jpeg' -- confirm that
    inputs are JPEG files.
  """
  image_height = image['height']
  image_width = image['width']
  filename = image['file_name']
  image_id = image['id']

  full_path = os.path.join(image_dir, filename)
  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  # Kept purely as a sanity check that the bytes open as an image. The result
  # is intentionally not bound to `image`, which would shadow the metadata
  # dict passed in by the caller.
  PIL.Image.open(io.BytesIO(encoded_jpg))
  key = hashlib.sha256(encoded_jpg).hexdigest()

  xmin = []
  xmax = []
  ymin = []
  ymax = []
  is_crowd = []
  category_names = []
  category_ids = []
  area = []
  encoded_mask_png = []
  num_annotations_skipped = 0
  for object_annotations in annotations_list:
    (x, y, width, height) = tuple(object_annotations['bbox'])
    # Drop degenerate boxes and boxes that extend past the right/bottom edge.
    if width <= 0 or height <= 0:
      num_annotations_skipped += 1
      continue
    if x + width > image_width or y + height > image_height:
      num_annotations_skipped += 1
      continue

    # Absolute [x, y, w, h] -> corner coordinates normalized by image size.
    xmin.append(float(x) / image_width)
    xmax.append(float(x + width) / image_width)
    ymin.append(float(y) / image_height)
    ymax.append(float(y + height) / image_height)
    is_crowd.append(object_annotations['iscrowd'])
    category_id = int(object_annotations['category_id'])
    category_ids.append(category_id)
    category_names.append(category_index[category_id]['name'].encode('utf8'))
    area.append(object_annotations['area'])

    if include_masks:
      run_len_encoding = mask.frPyObjects(object_annotations['segmentation'],
                                          image_height, image_width)
      binary_mask = mask.decode(run_len_encoding)
      if not object_annotations['iscrowd']:
        # Collapse the last axis of the decoded mask into a single 2-D mask.
        binary_mask = np.amax(binary_mask, axis=2)
      pil_image = PIL.Image.fromarray(binary_mask)
      output_io = io.BytesIO()
      pil_image.save(output_io, format='PNG')
      encoded_mask_png.append(output_io.getvalue())

  feature_dict = {
      'image/height':
          dataset_util.int64_feature(image_height),
      'image/width':
          dataset_util.int64_feature(image_width),
      'image/filename':
          dataset_util.bytes_feature(filename.encode('utf8')),
      'image/source_id':
          dataset_util.bytes_feature(str(image_id).encode('utf8')),
      'image/key/sha256':
          dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded':
          dataset_util.bytes_feature(encoded_jpg),
      'image/format':
          dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin':
          dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax':
          dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin':
          dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax':
          dataset_util.float_list_feature(ymax),
      'image/object/class/text':
          dataset_util.bytes_list_feature(category_names),
      # The numeric ids were previously collected but never written out.
      # Emitting them lets readers use labels without a text lookup.
      'image/object/class/label':
          dataset_util.int64_list_feature(category_ids),
      'image/object/is_crowd':
          dataset_util.int64_list_feature(is_crowd),
      'image/object/area':
          dataset_util.float_list_feature(area),
  }
  if include_masks:
    feature_dict['image/object/mask'] = (
        dataset_util.bytes_list_feature(encoded_mask_png))
  example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
  return key, example, num_annotations_skipped
186+
187+
188+
def _create_tf_record_from_coco_annotations(
    annotations_file, image_dir, output_path, include_masks):
  """Loads a COCO annotation json file and converts it to tf.Record format.

  Every entry of the 'images' list in the JSON file is converted with
  create_tf_example and written to a single TFRecord file. Images with no
  annotations are still written, with empty annotation lists.

  Args:
    annotations_file: JSON file containing bounding box annotations.
    image_dir: Directory containing the image files.
    output_path: Path to output tf.Record file.
    include_masks: Whether to include instance segmentations masks
      (PNG encoded) in the result. default: False.
  """
  # Parse the annotations before touching the output path, so that a bad
  # annotations file does not leave an empty record file behind.
  with tf.gfile.GFile(annotations_file, 'r') as fid:
    groundtruth_data = json.load(fid)

  images = groundtruth_data['images']
  category_index = label_map_util.create_category_index(
      groundtruth_data['categories'])

  # Group annotations by the id of the image they belong to.
  annotations_index = {}
  if 'annotations' in groundtruth_data:
    tf.logging.info(
        'Found groundtruth annotations. Building annotations index.')
    for annotation in groundtruth_data['annotations']:
      annotations_index.setdefault(annotation['image_id'], []).append(
          annotation)

  missing_annotation_count = 0
  for image in images:
    image_id = image['id']
    if image_id not in annotations_index:
      missing_annotation_count += 1
      annotations_index[image_id] = []
  tf.logging.info('%d images are missing annotations.',
                  missing_annotation_count)

  total_num_annotations_skipped = 0
  # The context manager closes the writer even if a conversion fails midway;
  # previously it was never closed explicitly.
  with tf.python_io.TFRecordWriter(output_path) as output_tfrecords:
    for idx, image in enumerate(images):
      if idx % 100 == 0:
        tf.logging.info('On image %d of %d', idx, len(images))
      annotations_list = annotations_index[image['id']]
      _, tf_example, num_annotations_skipped = create_tf_example(
          image, annotations_list, image_dir, category_index, include_masks)
      total_num_annotations_skipped += num_annotations_skipped
      output_tfrecords.write(tf_example.SerializeToString())
  tf.logging.info('Finished writing, skipped %d annotations.',
                  total_num_annotations_skipped)
235+
236+
237+
def main(_):
  """Converts the train and test COCO splits to TFRecord files.

  Reads the image directories and annotation files from FLAGS, creates
  FLAGS.output_dir if it does not exist, and writes 'train.record' and
  'test.record' into it.

  Args:
    _: Unused positional argument supplied by tf.app.run.

  Raises:
    ValueError: if any of the four required path flags is empty.
  """
  # Explicit checks instead of `assert`, which is stripped under `python -O`
  # and would let an empty path through.
  required_flags = (
      'train_image_dir',
      'test_image_dir',
      'train_annotations_file',
      'test_annotations_file',
  )
  for flag_name in required_flags:
    if not getattr(FLAGS, flag_name):
      raise ValueError('`%s` missing.' % flag_name)

  if not tf.gfile.IsDirectory(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  # (annotations file, image directory, output record name) for each split.
  splits = (
      (FLAGS.train_annotations_file, FLAGS.train_image_dir, 'train.record'),
      (FLAGS.test_annotations_file, FLAGS.test_image_dir, 'test.record'),
  )
  for annotations_file, image_dir, record_name in splits:
    _create_tf_record_from_coco_annotations(
        annotations_file,
        image_dir,
        os.path.join(FLAGS.output_dir, record_name),
        FLAGS.include_masks)
258+
259+
260+
# Script entry point: tf.app.run parses the command-line flags and then
# invokes main.
if __name__ == '__main__':
  tf.app.run(main=main)

‎training/saved_model.pb

21.1 MB
Binary file not shown.

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /