13
13
# limitations under the License.
14
14
# ==============================================================================
15
15
16
- r"""Convert the Oxford pet dataset to TFRecord for object_detection.
16
+ r"""Convert your custom dataset to TFRecord for object_detection.
17
17
18
- See: O. M. Parkhi, A. Vedaldi, A. Zisserman, C. V. Jawahar
19
- Cats and Dogs
20
- IEEE Conference on Computer Vision and Pattern Recognition, 2012
21
- http://www.robots.ox.ac.uk/~vgg/data/pets/
18
+ Base of this script is create_pet_tf_record.py
19
+ provided by tensorflow repository on github
20
+ create_pet_tf_record.py could be found under
21
+ tensorflow/models/research/object_detection/dataset_tools
22
22
23
23
Example usage:
24
- python object_detection/dataset_tools/create_pet_tf_record.py \
25
- --data_dir=/home/user/pet \
26
- --output_dir=/home/user/pet/output
24
+ Python object_detection/dataset_tools/create_mask_rcnn_tf_record.py
25
+ --data_dir=/Users/xyz/myProject/dataset --masks_dir=Annotations
26
+ --images_dir=JPEGImages --output_dir=/Users/xyz/myProject/dataset
27
+ --label_map_path=/Users/xyz/myProject/dataset/label.pbtxt
27
28
"""
28
29
29
30
import hashlib
34
35
import re
35
36
36
37
import contextlib2
37
- from lxml import etree
38
38
import numpy as np
39
39
import PIL .Image
40
40
import tensorflow as tf
45
45
46
46
flags = tf .app .flags
47
47
flags .DEFINE_string ('data_dir' , '' , 'Path to root directory to dataset.' )
48
+ flags .DEFINE_string ('images_dir' , 'JPEGImages' , 'Name of the directory contatining images' )
49
+ flags .DEFINE_string ('masks_dir' , 'Annotations' , 'Name of the directory contatining masks' )
48
50
flags .DEFINE_string ('output_dir' , '' , 'Path to directory to output TFRecords.' )
49
- flags .DEFINE_string ('image_dir' , 'JPEGImages' , 'Name of the directory contatining images' )
50
- flags .DEFINE_string ('annotations_dir' , 'Annotations' , 'Name of the directory contatining Annotations' )
51
51
flags .DEFINE_string ('label_map_path' , '' , 'Path to label map proto' )
52
52
flags .DEFINE_integer ('num_shards' , 1 , 'Number of TFRecord shards' )
53
53
FLAGS = flags .FLAGS
54
54
55
- # mask_pixel: dictionary containing class name and value for pixels belog to mask of each class
56
- # change as per your classes and labeling
57
- mask_pixel = {'speaker' :76 , 'cup' :26 }
55
+ def image_to_tf_example (img_path ,
56
+ mask_path ,
57
+ label_map_dict ,
58
+ filename ):
59
+ """Convert image and mask to tf.Example proto.
58
60
59
- def dict_to_tf_example (filename ,
60
- mask_path ,
61
- label_map_dict ,
62
- img_path ):
63
- """Convert XML derived dict to tf.Example proto.
64
-
65
- Notice that this function normalizes the bounding box coordinates provided
66
- by the raw data.
61
+ Note: that this function doesnt give correct output if an image contains
62
+ more than one object from same class
67
63
68
64
Args:
69
- filename: name of the image
65
+ img_path: String specifying subdirectory within the
66
+ dataset directory holding the actual image data.
70
67
mask_path: String path to PNG encoded mask.
71
68
label_map_dict: A map from string label names to integers ids.
72
- image_subdirectory: String specifying subdirectory within the
73
- dataset directory holding the actual image data.
69
+ filename: name of the image
74
70
75
71
76
72
Returns:
@@ -82,8 +78,8 @@ def dict_to_tf_example(filename,
82
78
with tf .gfile .GFile (img_path , 'rb' ) as fid :
83
79
encoded_jpg = fid .read ()
84
80
encoded_jpg_io = io .BytesIO (encoded_jpg )
85
- image = PIL .Image .open (encoded_jpg_io )
86
- width = np .asarray (image ).shape [1 ]
81
+ image = PIL .Image .open (encoded_jpg_io )
82
+ width = np .asarray (image ).shape [1 ]
87
83
height = np .asarray (image ).shape [0 ]
88
84
if image .format != 'JPEG' :
89
85
raise ValueError ('Image format not JPEG' )
@@ -103,15 +99,13 @@ def dict_to_tf_example(filename,
103
99
ymaxs = []
104
100
classes = []
105
101
classes_text = []
106
- truncated = []
107
- poses = []
108
- difficult_obj = []
109
- masks = []
110
-
111
- for k in list (mask_pixel .keys ()):
112
- class_name = k
113
- nonbackground_indices_x = np .any (mask_np == mask_pixel [class_name ], axis = 0 )
114
- nonbackground_indices_y = np .any (mask_np == mask_pixel [class_name ], axis = 1 )
102
+ encoded_mask_png_list = []
103
+
104
+ for key in label_map_dict .keys ():
105
+ class_name = key
106
+ pixel_val = int (label_map_dict [class_name ][1 ])
107
+ nonbackground_indices_x = np .any (mask_np == pixel_val , axis = 0 )
108
+ nonbackground_indices_y = np .any (mask_np == pixel_val , axis = 1 )
115
109
nonzero_x_indices = np .where (nonbackground_indices_x )
116
110
nonzero_y_indices = np .where (nonbackground_indices_y )
117
111
@@ -128,103 +122,94 @@ def dict_to_tf_example(filename,
128
122
ymaxs .append (ymax / height )
129
123
130
124
classes_text .append (class_name .encode ('utf8' ))
131
- classes .append (label_map_dict [class_name ])
125
+ classes .append (label_map_dict [class_name ][ 0 ] )
132
126
133
- mask_remapped = (mask_np == mask_pixel [class_name ]).astype (np .uint8 )
134
- masks .append (mask_remapped )
127
+ mask_remapped = (mask_np == pixel_val ).astype (np .uint8 )
128
+ img = PIL .Image .fromarray (mask_remapped )
129
+ output = io .BytesIO ()
130
+ img .save (output , format = 'PNG' )
131
+ encoded_mask_png_list .append (output .getvalue ())
135
132
136
133
feature_dict = {
137
- 'image/height' : dataset_util .int64_feature (height ),
138
- 'image/width' : dataset_util .int64_feature (width ),
139
- 'image/filename' : dataset_util .bytes_feature (
140
- filename .encode ('utf8' )),
141
- 'image/source_id' : dataset_util .bytes_feature (
142
- filename .encode ('utf8' )),
143
- 'image/key/sha256' : dataset_util .bytes_feature (key .encode ('utf8' )),
144
- 'image/encoded' : dataset_util .bytes_feature (encoded_jpg ),
145
- 'image/format' : dataset_util .bytes_feature ('jpeg' .encode ('utf8' )),
146
- 'image/object/bbox/xmin' : dataset_util .float_list_feature (xmins ),
147
- 'image/object/bbox/xmax' : dataset_util .float_list_feature (xmaxs ),
148
- 'image/object/bbox/ymin' : dataset_util .float_list_feature (ymins ),
149
- 'image/object/bbox/ymax' : dataset_util .float_list_feature (ymaxs ),
150
- 'image/object/class/text' : dataset_util .bytes_list_feature (classes_text ),
134
+ 'image/height' : dataset_util .int64_feature (height ),
135
+ 'image/width' : dataset_util .int64_feature (width ),
136
+ 'image/filename' : dataset_util .bytes_feature (filename .encode ('utf8' )),
137
+ 'image/source_id' : dataset_util .bytes_feature (filename .encode ('utf8' )),
138
+ 'image/key/sha256' : dataset_util .bytes_feature (key .encode ('utf8' )),
139
+ 'image/encoded' : dataset_util .bytes_feature (encoded_jpg ),
140
+ 'image/format' : dataset_util .bytes_feature ('jpeg' .encode ('utf8' )),
141
+ 'image/object/bbox/xmin' : dataset_util .float_list_feature (xmins ),
142
+ 'image/object/bbox/xmax' : dataset_util .float_list_feature (xmaxs ),
143
+ 'image/object/bbox/ymin' : dataset_util .float_list_feature (ymins ),
144
+ 'image/object/bbox/ymax' : dataset_util .float_list_feature (ymaxs ),
145
+ 'image/object/class/text' : dataset_util .bytes_list_feature (classes_text ),
151
146
'image/object/class/label' : dataset_util .int64_list_feature (classes ),
152
- 'image/object/difficult' : dataset_util .int64_list_feature (difficult_obj ),
153
- 'image/object/truncated' : dataset_util .int64_list_feature (truncated ),
154
- 'image/object/view' : dataset_util .bytes_list_feature (poses ),
155
- }
147
+ 'image/object/mask' : (dataset_util .bytes_list_feature (encoded_mask_png_list ))}
148
+ tf_data = tf .train .Example (features = tf .train .Features (feature = feature_dict ))
149
+ return tf_data
156
150
157
- encoded_mask_png_list = []
158
- for mask in masks :
159
- img = PIL .Image .fromarray (mask )
160
- output = io .BytesIO ()
161
- img .save (output , format = 'PNG' )
162
- encoded_mask_png_list .append (output .getvalue ())
163
- feature_dict ['image/object/mask' ] = (dataset_util .bytes_list_feature (encoded_mask_png_list ))
164
-
165
- example = tf .train .Example (features = tf .train .Features (feature = feature_dict ))
166
- return example
167
151
168
-
169
- def create_tf_record ( output_filename ,
170
- num_shards ,
152
+ def create_tf_record ( image_dir_path ,
153
+ mask_dir_path ,
154
+ output_dir_path ,
171
155
label_map_dict ,
172
- annotations_dir ,
173
- image_dir ,
174
- examples ):
175
- """Creates a TFRecord file from examples.
156
+ images_filename ,
157
+ num_shards ):
158
+ """Creates a TFRecord file from data.
176
159
177
160
Args:
178
- output_filename: Path to where output file is saved.
179
- num_shards: Number of shards for output file.
161
+ image_dir_path: Directory where image files are stored.
162
+ mask_dir_path: Directory where annotation files are stored.
163
+ output_dir_path: Path to where output file is saved.
180
164
label_map_dict: The label map dictionary.
181
- annotations_dir: Directory where annotation files are stored.
182
- image_dir: Directory where image files are stored.
183
- examples: Examples to parse and save to tf record.
165
+ images_filename: Examples to parse and save to tf record.
166
+ num_shards: Number of shards for output file.
184
167
"""
185
168
with contextlib2 .ExitStack () as tf_record_close_stack :
169
+ output_filename = os .path .join (output_dir_path , 'data' )
186
170
output_tfrecords = tf_record_creation_util .open_sharded_output_tfrecords (
187
171
tf_record_close_stack , output_filename , num_shards )
188
- for idx , example in enumerate (examples ):
172
+ for idx , filename in enumerate (images_filename ):
189
173
if idx % 100 == 0 :
190
- logging .info ('On image %d of %d' , idx , len (examples ))
191
- mask_path = os .path .join (annotations_dir , example + '.png' )
192
- image_path = os .path .join (image_dir , example + '.jpg' )
193
-
174
+ logging .info ('On image %d of %d' , idx , len (images_filename ))
175
+ mask_path = os .path .join (mask_dir_path , filename + '.png' )
176
+ image_path = os .path .join (image_dir_path , filename + '.jpg' )
194
177
try :
195
- tf_example = dict_to_tf_example ( example ,
178
+ tf_example = image_to_tf_example ( image_path ,
196
179
mask_path ,
197
180
label_map_dict ,
198
- image_path )
181
+ filename )
199
182
if tf_example :
200
183
shard_idx = idx % num_shards
201
184
output_tfrecords [shard_idx ].write (tf_example .SerializeToString ())
202
- print ( " done" )
185
+ logging . info ( ' done' )
203
186
except ValueError :
204
- logging .warning ('Invalid example: %s, ignoring.' , xml_path )
187
+ logging .warning ('Invalid example: %s, ignoring.' , image_path )
205
188
206
189
def main (_ ):
207
- data_dir = FLAGS .data_dir
208
- train_output_path = FLAGS .output_dir
209
- image_dir = os .path .join (data_dir , FLAGS .image_dir )
210
- annotations_dir = os .path .join (data_dir , FLAGS .annotations_dir )
211
- label_map_dict = label_map_util .get_label_map_dict (FLAGS .label_map_path )
190
+ data_dir_path = FLAGS .data_dir
191
+ images_dir_path = os .path .join (data_dir_path , FLAGS .images_dir )
192
+ masks_dir_path = os .path .join (data_dir_path , FLAGS .masks_dir )
193
+ tfrecord_dir_path = FLAGS .output_dir
194
+ label_map_dict = label_map_util .get_label_map_dict (FLAGS .label_map_path )
195
+ for key in label_map_dict .keys ():
196
+ label_map_dict [key ] = [label_map_dict [key ], key [- 3 :]]
197
+ label_map_dict [key [:- 3 ]] = label_map_dict .pop (key )
212
198
213
199
logging .info ('Reading from dataset.' )
214
- examples_list = os .listdir (image_dir )
215
- for el in examples_list :
216
- if el [- 3 :] != 'jpg' :
217
- del examples_list [examples_list .index (el )]
218
- for el in examples_list :
219
- examples_list [examples_list .index (el )] = el [0 :- 4 ]
220
-
221
- create_tf_record (train_output_path ,
222
- FLAGS .num_shards ,
223
- label_map_dict ,
224
- annotations_dir ,
225
- image_dir ,
226
- examples_list )
227
-
200
+ images_filename = os .listdir (images_dir_path )
201
+ for filename in images_filename :
202
+ if filename [- 3 :] != 'jpg' :
203
+ del images_filename [images_filename .index (filename )]
204
+ for filename in images_filename :
205
+ images_filename [images_filename .index (filename )] = filename [0 :- 4 ]
206
+
207
+ create_tf_record (images_dir_path ,
208
+ masks_dir_path ,
209
+ tfrecord_dir_path ,
210
+ label_map_dict ,
211
+ images_filename ,
212
+ FLAGS .num_shards )
228
213
229
214
if __name__ == '__main__' :
230
215
tf .app .run ()
0 commit comments