Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 25ab5de

Browse files
refactoring and enabling better method to set pixel value
1 parent cf6646b commit 25ab5de

File tree

3 files changed

+94
-109
lines changed

3 files changed

+94
-109
lines changed
1.07 MB
Binary file not shown.

‎dataset/label.pbtxt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
item {
22
id: 1
3-
name: 'speaker'
3+
name: 'speaker076'
44
}
55
item {
66
id: 2
7-
name: 'cup'
7+
name: 'cup026'
88
}

‎extra/create_mask_rcnn_tf_record.py

Lines changed: 92 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
r"""Convert the Oxford pet dataset to TFRecord for object_detection.
16+
r"""Convert your custom dataset to TFRecord for object_detection.
1717
18-
See: O. M. Parkhi, A. Vedaldi, A. Zisserman, C. V. Jawahar
19-
Cats and Dogs
20-
IEEE Conference on Computer Vision and Pattern Recognition, 2012
21-
http://www.robots.ox.ac.uk/~vgg/data/pets/
18+
Base of this script is create_pet_tf_record.py
19+
provided by tensorflow repository on github
20+
create_pet_tf_record.py could be found under
21+
tensorflow/models/research/object_detection/dataset_tools
2222
2323
Example usage:
24-
python object_detection/dataset_tools/create_pet_tf_record.py \
25-
--data_dir=/home/user/pet \
26-
--output_dir=/home/user/pet/output
24+
python object_detection/dataset_tools/create_mask_rcnn_tf_record.py \
25+
--data_dir=/Users/xyz/myProject/dataset --masks_dir=Annotations \
26+
--images_dir=JPEGImages --output_dir=/Users/xyz/myProject/dataset \
27+
--label_map_path=/Users/xyz/myProject/dataset/label.pbtxt
2728
"""
2829

2930
import hashlib
@@ -34,7 +35,6 @@
3435
import re
3536

3637
import contextlib2
37-
from lxml import etree
3838
import numpy as np
3939
import PIL.Image
4040
import tensorflow as tf
@@ -45,32 +45,28 @@
4545

4646
flags = tf.app.flags
4747
flags.DEFINE_string('data_dir', '', 'Path to root directory to dataset.')
48+
flags.DEFINE_string('images_dir', 'JPEGImages', 'Name of the directory contatining images')
49+
flags.DEFINE_string('masks_dir', 'Annotations', 'Name of the directory contatining masks')
4850
flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
49-
flags.DEFINE_string('image_dir', 'JPEGImages', 'Name of the directory contatining images')
50-
flags.DEFINE_string('annotations_dir', 'Annotations', 'Name of the directory contatining Annotations')
5151
flags.DEFINE_string('label_map_path', '', 'Path to label map proto')
5252
flags.DEFINE_integer('num_shards', 1, 'Number of TFRecord shards')
5353
FLAGS = flags.FLAGS
5454

55-
# mask_pixel: dictionary mapping each class name to the pixel value that marks that class's mask
56-
# change as per your classes and labeling
57-
mask_pixel = {'speaker':76, 'cup':26}
55+
def image_to_tf_example(img_path,
56+
mask_path,
57+
label_map_dict,
58+
filename):
59+
"""Convert image and mask to tf.Example proto.
5860
59-
def dict_to_tf_example(filename,
60-
mask_path,
61-
label_map_dict,
62-
img_path):
63-
"""Convert XML derived dict to tf.Example proto.
64-
65-
Notice that this function normalizes the bounding box coordinates provided
66-
by the raw data.
61+
Note: this function doesn't give correct output if an image contains
62+
more than one object of the same class
6763
6864
Args:
69-
filename: name of the image
65+
img_path: String specifying subdirectory within the
66+
dataset directory holding the actual image data.
7067
mask_path: String path to PNG encoded mask.
7168
label_map_dict: A map from string label names to integers ids.
72-
image_subdirectory: String specifying subdirectory within the
73-
dataset directory holding the actual image data.
69+
filename: name of the image
7470
7571
7672
Returns:
@@ -82,8 +78,8 @@ def dict_to_tf_example(filename,
8278
with tf.gfile.GFile(img_path, 'rb') as fid:
8379
encoded_jpg = fid.read()
8480
encoded_jpg_io = io.BytesIO(encoded_jpg)
85-
image = PIL.Image.open(encoded_jpg_io)
86-
width = np.asarray(image).shape[1]
81+
image = PIL.Image.open(encoded_jpg_io)
82+
width = np.asarray(image).shape[1]
8783
height = np.asarray(image).shape[0]
8884
if image.format != 'JPEG':
8985
raise ValueError('Image format not JPEG')
@@ -103,15 +99,13 @@ def dict_to_tf_example(filename,
10399
ymaxs = []
104100
classes = []
105101
classes_text = []
106-
truncated = []
107-
poses = []
108-
difficult_obj = []
109-
masks = []
110-
111-
for k in list(mask_pixel.keys()):
112-
class_name = k
113-
nonbackground_indices_x = np.any(mask_np == mask_pixel[class_name], axis=0)
114-
nonbackground_indices_y = np.any(mask_np == mask_pixel[class_name], axis=1)
102+
encoded_mask_png_list = []
103+
104+
for key in label_map_dict.keys():
105+
class_name = key
106+
pixel_val = int(label_map_dict[class_name][1])
107+
nonbackground_indices_x = np.any(mask_np == pixel_val, axis=0)
108+
nonbackground_indices_y = np.any(mask_np == pixel_val, axis=1)
115109
nonzero_x_indices = np.where(nonbackground_indices_x)
116110
nonzero_y_indices = np.where(nonbackground_indices_y)
117111

@@ -128,103 +122,94 @@ def dict_to_tf_example(filename,
128122
ymaxs.append(ymax / height)
129123

130124
classes_text.append(class_name.encode('utf8'))
131-
classes.append(label_map_dict[class_name])
125+
classes.append(label_map_dict[class_name][0])
132126

133-
mask_remapped = (mask_np == mask_pixel[class_name]).astype(np.uint8)
134-
masks.append(mask_remapped)
127+
mask_remapped = (mask_np == pixel_val).astype(np.uint8)
128+
img = PIL.Image.fromarray(mask_remapped)
129+
output = io.BytesIO()
130+
img.save(output, format='PNG')
131+
encoded_mask_png_list.append(output.getvalue())
135132

136133
feature_dict = {
137-
'image/height': dataset_util.int64_feature(height),
138-
'image/width': dataset_util.int64_feature(width),
139-
'image/filename': dataset_util.bytes_feature(
140-
filename.encode('utf8')),
141-
'image/source_id': dataset_util.bytes_feature(
142-
filename.encode('utf8')),
143-
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
144-
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
145-
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
146-
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
147-
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
148-
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
149-
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
150-
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
134+
'image/height': dataset_util.int64_feature(height),
135+
'image/width': dataset_util.int64_feature(width),
136+
'image/filename': dataset_util.bytes_feature(filename.encode('utf8')),
137+
'image/source_id': dataset_util.bytes_feature(filename.encode('utf8')),
138+
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
139+
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
140+
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
141+
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
142+
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
143+
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
144+
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
145+
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
151146
'image/object/class/label': dataset_util.int64_list_feature(classes),
152-
'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
153-
'image/object/truncated': dataset_util.int64_list_feature(truncated),
154-
'image/object/view': dataset_util.bytes_list_feature(poses),
155-
}
147+
'image/object/mask': (dataset_util.bytes_list_feature(encoded_mask_png_list))}
148+
tf_data = tf.train.Example(features=tf.train.Features(feature=feature_dict))
149+
return tf_data
156150

157-
encoded_mask_png_list = []
158-
for mask in masks:
159-
img = PIL.Image.fromarray(mask)
160-
output = io.BytesIO()
161-
img.save(output, format='PNG')
162-
encoded_mask_png_list.append(output.getvalue())
163-
feature_dict['image/object/mask'] = (dataset_util.bytes_list_feature(encoded_mask_png_list))
164-
165-
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
166-
return example
167151

168-
169-
defcreate_tf_record(output_filename,
170-
num_shards,
152+
defcreate_tf_record(image_dir_path,
153+
mask_dir_path,
154+
output_dir_path,
171155
label_map_dict,
172-
annotations_dir,
173-
image_dir,
174-
examples):
175-
"""Creates a TFRecord file from examples.
156+
images_filename,
157+
num_shards):
158+
"""Creates a TFRecord file from data.
176159
177160
Args:
178-
output_filename: Path to where output file is saved.
179-
num_shards: Number of shards for output file.
161+
image_dir_path: Directory where image files are stored.
162+
mask_dir_path: Directory where annotation files are stored.
163+
output_dir_path: Path to where output file is saved.
180164
label_map_dict: The label map dictionary.
181-
annotations_dir: Directory where annotation files are stored.
182-
image_dir: Directory where image files are stored.
183-
examples: Examples to parse and save to tf record.
165+
images_filename: Examples to parse and save to tf record.
166+
num_shards: Number of shards for output file.
184167
"""
185168
with contextlib2.ExitStack() as tf_record_close_stack:
169+
output_filename = os.path.join(output_dir_path, 'data')
186170
output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
187171
tf_record_close_stack, output_filename, num_shards)
188-
for idx, example in enumerate(examples):
172+
for idx, filename in enumerate(images_filename):
189173
if idx % 100 == 0:
190-
logging.info('On image %d of %d', idx, len(examples))
191-
mask_path = os.path.join(annotations_dir, example + '.png')
192-
image_path = os.path.join(image_dir, example + '.jpg')
193-
174+
logging.info('On image %d of %d', idx, len(images_filename))
175+
mask_path = os.path.join(mask_dir_path, filename + '.png')
176+
image_path = os.path.join(image_dir_path, filename + '.jpg')
194177
try:
195-
tf_example = dict_to_tf_example(example,
178+
tf_example = image_to_tf_example(image_path,
196179
mask_path,
197180
label_map_dict,
198-
image_path)
181+
filename)
199182
if tf_example:
200183
shard_idx = idx % num_shards
201184
output_tfrecords[shard_idx].write(tf_example.SerializeToString())
202-
print("done")
185+
logging.info('done')
203186
except ValueError:
204-
logging.warning('Invalid example: %s, ignoring.', xml_path)
187+
logging.warning('Invalid example: %s, ignoring.', image_path)
205188

206189
def main(_):
207-
data_dir = FLAGS.data_dir
208-
train_output_path = FLAGS.output_dir
209-
image_dir = os.path.join(data_dir, FLAGS.image_dir)
210-
annotations_dir = os.path.join(data_dir, FLAGS.annotations_dir)
211-
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
190+
data_dir_path = FLAGS.data_dir
191+
images_dir_path = os.path.join(data_dir_path, FLAGS.images_dir)
192+
masks_dir_path = os.path.join(data_dir_path, FLAGS.masks_dir)
193+
tfrecord_dir_path = FLAGS.output_dir
194+
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
195+
for key in label_map_dict.keys():
196+
label_map_dict[key] = [label_map_dict[key], key[-3:]]
197+
label_map_dict[key[:-3]] = label_map_dict.pop(key)
212198

213199
logging.info('Reading from dataset.')
214-
examples_list = os.listdir(image_dir)
215-
for el in examples_list:
216-
if el[-3:] !='jpg':
217-
del examples_list[examples_list.index(el)]
218-
for el in examples_list:
219-
examples_list[examples_list.index(el)] = el[0:-4]
220-
221-
create_tf_record(train_output_path,
222-
FLAGS.num_shards,
223-
label_map_dict,
224-
annotations_dir,
225-
image_dir,
226-
examples_list)
227-
200+
images_filename = os.listdir(images_dir_path)
201+
for filename in images_filename:
202+
if filename[-3:] !='jpg':
203+
del images_filename[images_filename.index(filename)]
204+
for filename in images_filename:
205+
images_filename[images_filename.index(filename)] = filename[0:-4]
206+
207+
create_tf_record(images_dir_path,
208+
masks_dir_path,
209+
tfrecord_dir_path,
210+
label_map_dict,
211+
images_filename,
212+
FLAGS.num_shards)
228213

229214
if __name__ == '__main__':
230215
tf.app.run()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /