Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 383f9b2

Browse files
add create_mask_rcnn_tf_record_cv.py
add create_mask_rcnn_tf_record_cv.py, according to discussions in issue #13. Add sample image as well
1 parent b89745c commit 383f9b2

File tree

3 files changed

+246
-0
lines changed

3 files changed

+246
-0
lines changed
29.5 KB
Loading[フレーム]

‎extra/create_mask_rcnn_tf_record_cv.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
r"""Convert the Oxford pet dataset to TFRecord for object_detection.
17+
18+
See: O. M. Parkhi, A. Vedaldi, A. Zisserman, C. V. Jawahar
19+
Cats and Dogs
20+
IEEE Conference on Computer Vision and Pattern Recognition, 2012
21+
http://www.robots.ox.ac.uk/~vgg/data/pets/
22+
23+
Example usage:
24+
python object_detection/dataset_tools/create_pet_tf_record.py \
25+
--data_dir=/home/user/pet \
26+
--output_dir=/home/user/pet/output
27+
"""
28+
29+
import hashlib
30+
import io
31+
import logging
32+
import os
33+
import random
34+
import re
35+
import cv2
36+
37+
import contextlib2
38+
from lxml import etree
39+
import numpy as np
40+
import PIL.Image
41+
import tensorflow as tf
42+
43+
from object_detection.dataset_tools import tf_record_creation_util
44+
from object_detection.utils import dataset_util
45+
from object_detection.utils import label_map_util
46+
47+
flags = tf.app.flags
48+
flags.DEFINE_string('data_dir', '', 'Path to root directory to dataset.')
49+
flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
50+
flags.DEFINE_string('image_dir', 'JPEGImages', 'Name of the directory contatining images')
51+
flags.DEFINE_string('annotations_dir', 'Annotations', 'Name of the directory contatining Annotations')
52+
flags.DEFINE_string('label_map_path', '', 'Path to label map proto')
53+
flags.DEFINE_integer('num_shards', 1, 'Number of TFRecord shards')
54+
FLAGS = flags.FLAGS
55+
56+
# mask_pixel: dictionary containing class name and value for pixels belog to mask of each class
57+
# change as per your classes and labeling
58+
mask_pixel = {'balloon':[119,76,194,117,84]}
59+
60+
def dict_to_tf_example(filename,
61+
mask_path,
62+
label_map_dict,
63+
img_path):
64+
"""Convert XML derived dict to tf.Example proto.
65+
66+
Notice that this function normalizes the bounding box coordinates provided
67+
by the raw data.
68+
69+
Args:
70+
filename: name of the image
71+
mask_path: String path to PNG encoded mask.
72+
label_map_dict: A map from string label names to integers ids.
73+
image_subdirectory: String specifying subdirectory within the
74+
dataset directory holding the actual image data.
75+
76+
77+
Returns:
78+
example: The converted tf.Example.
79+
80+
Raises:
81+
ValueError: if the image pointed to by filename is not a valid JPEG
82+
"""
83+
with tf.gfile.GFile(img_path, 'rb') as fid:
84+
encoded_jpg = fid.read()
85+
encoded_jpg_io = io.BytesIO(encoded_jpg)
86+
image = PIL.Image.open(encoded_jpg_io)
87+
width = np.asarray(image).shape[1]
88+
height = np.asarray(image).shape[0]
89+
if image.format != 'JPEG':
90+
raise ValueError('Image format not JPEG')
91+
key = hashlib.sha256(encoded_jpg).hexdigest()
92+
93+
with tf.gfile.GFile(mask_path, 'rb') as fid:
94+
encoded_mask_png = fid.read()
95+
encoded_png_io = io.BytesIO(encoded_mask_png)
96+
mask = PIL.Image.open(encoded_png_io)
97+
mask_np = np.asarray(mask.convert('L'))
98+
if mask.format != 'PNG':
99+
raise ValueError('Mask format not PNG')
100+
101+
xmins = []
102+
ymins = []
103+
xmaxs = []
104+
ymaxs = []
105+
classes = []
106+
classes_text = []
107+
truncated = []
108+
poses = []
109+
difficult_obj = []
110+
masks = []
111+
112+
cv2.imshow("origin", mask_np)
113+
cv2.imwrite('origin.png', mask_np)
114+
115+
for k in list(mask_pixel.keys()):
116+
class_name = k
117+
118+
pixel_vals = mask_pixel[class_name]
119+
120+
for pixel_val in pixel_vals:
121+
print('for pixel val#:', pixel_val)
122+
123+
mask_copy = mask_np.copy()
124+
mask_copy[mask_np == pixel_val] = 255
125+
ret,thresh = cv2.threshold(mask_copy, 254, 255, cv2.THRESH_BINARY)
126+
(_, conts, _) = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
127+
128+
index = 0
129+
if conts != None:
130+
for c in conts:
131+
#rect = cv2.boundingRect(c)
132+
x, y, w, h = cv2.boundingRect(c)
133+
xmin = float(x)
134+
xmax = float(x+w)
135+
ymin = float(y)
136+
ymax = float(y+h)
137+
xmins.append(xmin / width)
138+
ymins.append(ymin / height)
139+
xmaxs.append(xmax / width)
140+
ymaxs.append(ymax / height)
141+
print(filename, 'bounding box for', class_name, xmin, xmax, ymin, ymax)
142+
143+
classes_text.append(class_name.encode('utf8'))
144+
classes.append(label_map_dict[class_name])
145+
146+
mask_np_black = mask_np*0
147+
cv2.drawContours(mask_np_black, [c], -1, (255,255,255), cv2.FILLED)
148+
149+
mask_remapped = (mask_np_black == 255).astype(np.uint8)
150+
masks.append(mask_remapped)
151+
152+
feature_dict = {
153+
'image/height': dataset_util.int64_feature(height),
154+
'image/width': dataset_util.int64_feature(width),
155+
'image/filename': dataset_util.bytes_feature(
156+
filename.encode('utf8')),
157+
'image/source_id': dataset_util.bytes_feature(
158+
filename.encode('utf8')),
159+
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
160+
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
161+
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
162+
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
163+
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
164+
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
165+
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
166+
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
167+
'image/object/class/label': dataset_util.int64_list_feature(classes),
168+
'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
169+
'image/object/truncated': dataset_util.int64_list_feature(truncated),
170+
'image/object/view': dataset_util.bytes_list_feature(poses),
171+
}
172+
173+
encoded_mask_png_list = []
174+
for mask in masks:
175+
img = PIL.Image.fromarray(mask)
176+
output = io.BytesIO()
177+
img.save(output, format='PNG')
178+
encoded_mask_png_list.append(output.getvalue())
179+
feature_dict['image/object/mask'] = (dataset_util.bytes_list_feature(encoded_mask_png_list))
180+
181+
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
182+
return example
183+
184+
185+
def create_tf_record(output_filename,
186+
num_shards,
187+
label_map_dict,
188+
annotations_dir,
189+
image_dir,
190+
examples):
191+
"""Creates a TFRecord file from examples.
192+
193+
Args:
194+
output_filename: Path to where output file is saved.
195+
num_shards: Number of shards for output file.
196+
label_map_dict: The label map dictionary.
197+
annotations_dir: Directory where annotation files are stored.
198+
image_dir: Directory where image files are stored.
199+
examples: Examples to parse and save to tf record.
200+
"""
201+
with contextlib2.ExitStack() as tf_record_close_stack:
202+
output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
203+
tf_record_close_stack, output_filename, num_shards)
204+
for idx, example in enumerate(examples):
205+
if idx % 100 == 0:
206+
logging.info('On image %d of %d', idx, len(examples))
207+
mask_path = os.path.join(annotations_dir, example + '.png')
208+
image_path = os.path.join(image_dir, example + '.jpg')
209+
210+
try:
211+
tf_example = dict_to_tf_example(example,
212+
mask_path,
213+
label_map_dict,
214+
image_path)
215+
if tf_example:
216+
shard_idx = idx % num_shards
217+
output_tfrecords[shard_idx].write(tf_example.SerializeToString())
218+
print("done")
219+
except ValueError:
220+
logging.warning('Invalid example: %s, ignoring.', xml_path)
221+
222+
def main(_):
223+
data_dir = FLAGS.data_dir
224+
train_output_path = FLAGS.output_dir
225+
image_dir = os.path.join(data_dir, FLAGS.image_dir)
226+
annotations_dir = os.path.join(data_dir, FLAGS.annotations_dir)
227+
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
228+
229+
logging.info('Reading from dataset.')
230+
examples_list = os.listdir(image_dir)
231+
for el in examples_list:
232+
if el[-3:] !='jpg':
233+
del examples_list[examples_list.index(el)]
234+
for el in examples_list:
235+
examples_list[examples_list.index(el)] = el[0:-4]
236+
237+
create_tf_record(train_output_path,
238+
FLAGS.num_shards,
239+
label_map_dict,
240+
annotations_dir,
241+
image_dir,
242+
examples_list)
243+
244+
245+
if __name__ == '__main__':
246+
tf.app.run()
695 KB
Loading[フレーム]

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /