44
55# Food 101 数据集
66
7- def _parse_function (filename , label ):
8- image_string = tf .read_file (filename )
9- image_decoded = tf .image .decode_jpeg (image_string )
10- image_resized = tf .image .resize_images (image_decoded , [227 , 227 , 3 ])
11- return image_resized , label
7+ def _parse_function (serialized_example_test ):
8+ features = tf .parse_single_example (
9+ serialized_example_test ,
10+ features = {
11+ 'label' : tf .FixedLenFeature ([], tf .int64 ),
12+ 'image_raw' : tf .FixedLenFeature ([], tf .string ),
13+ }
14+ )
15+ img_train = features ['image_raw' ]
16+ # image_decoded = tf.decode_raw(img_train, tf.uint8)
17+ image_decoded = tf .image .decode_image (img_train , channels = 3 )
18+ image_resized = tf .image .resize_images (image_decoded , [227 , 227 ])
19+ labels = tf .cast (features ['label' ], tf .int64 )
20+ labels = tf .one_hot (labels , 101 )
21+ shape = tf .cast ([227 , 227 ], tf .int32 )
22+ return image_resized , labels
1223
1324
1425def alex_net ():
1526 model = tf .keras .Sequential ()
1627 model .add (layers .Conv2D (filters = 96 , kernel_size = (11 , 11 ), strides = (4 , 4 ), padding = 'valid' ,
17- activation = 'relu' ))
28+ activation = 'relu' , input_shape = ( 227 , 227 , 3 ) ))
1829 model .add (layers .MaxPool2D (pool_size = (3 , 3 ), strides = (2 , 2 ), padding = 'valid' ))
1930 model .add (layers .BatchNormalization ())
2031 model .add (layers .Conv2D (filters = 256 , kernel_size = (5 , 5 ), strides = (1 , 1 ), padding = 'same' ,
@@ -33,24 +44,34 @@ def alex_net():
3344 model .add (layers .Dropout (0.5 ))
3445 model .add (layers .Dense (4096 , activation = tf .keras .activations .relu ))
3546 model .add (layers .Dropout (0.5 ))
36- model .add (layers .Dense (1000 , activation = tf .keras .activations .softmax ))
47+ model .add (layers .Dense (101 , activation = tf .keras .activations .softmax ))
3748 return model
3849
3950
40- filenames = ["/media/data/oldcopy/PythonProject/Food101/TFRecord/train.tfrecords" ]
51+ filenames = ["/home/heolis/Data/food-101/TFRecord/train0.tfrecords" ,
52+ "/home/heolis/Data/food-101/TFRecord/train1.tfrecords" ,
53+ "/home/heolis/Data/food-101/TFRecord/train2.tfrecords" ,
54+ "/home/heolis/Data/food-101/TFRecord/train3.tfrecords" ]
4155trainSet = tf .data .TFRecordDataset (filenames )
4256trainSet = trainSet .map (_parse_function )
43- trainSet = trainSet .repeat ()
57+ trainSet = trainSet .repeat (10 )
4458trainSet = trainSet .batch (32 )
59+ iterator_train = trainSet .make_one_shot_iterator ()
4560
46- filenames = ["/media/data/oldcopy/PythonProject/Food101/ TFRecord/test .tfrecords" ]
61+ filenames = ["/home/heolis/Data/food-101/ TFRecord/train3 .tfrecords" ]
4762testSet = tf .data .TFRecordDataset (filenames )
4863testSet = testSet .map (_parse_function )
49- testSet = testSet .repeat ()
64+ testSet = testSet .repeat (10 )
5065testSet = testSet .batch (32 )
66+ iterator_test = testSet .make_one_shot_iterator ()
5167
5268model = alex_net ()
53- model .fit (trainSet , epochs = 10 , batch_size = 32 , validation_data = testSet )
69+ 70+ model .compile (optimizer = tf .train .GradientDescentOptimizer (0.03 ),
71+ loss = tf .keras .losses .categorical_crossentropy ,
72+ metrics = [tf .keras .metrics .categorical_accuracy ])
73+ 74+ model .fit (iterator_train , epochs = 10 , validation_data = testSet , steps_per_epoch = 10000 )
5475
5576loss , accuracy = model .evaluate (testSet )
5677print ("loss:%f, accuracy:%f" % (loss , accuracy ))
0 commit comments