Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 159213b

Browse files
committed
🔥update
1 parent 61b6904 commit 159213b

File tree

2 files changed

+55
-27
lines changed

2 files changed

+55
-27
lines changed

‎AlexNet/AlexNet.py‎

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,28 @@
44

55
# Food 101 数据集
66

7-
def _parse_function(filename, label):
8-
image_string = tf.read_file(filename)
9-
image_decoded = tf.image.decode_jpeg(image_string)
10-
image_resized = tf.image.resize_images(image_decoded, [227, 227, 3])
11-
return image_resized, label
7+
def _parse_function(serialized_example_test):
8+
features = tf.parse_single_example(
9+
serialized_example_test,
10+
features={
11+
'label': tf.FixedLenFeature([], tf.int64),
12+
'image_raw': tf.FixedLenFeature([], tf.string),
13+
}
14+
)
15+
img_train = features['image_raw']
16+
# image_decoded = tf.decode_raw(img_train, tf.uint8)
17+
image_decoded = tf.image.decode_image(img_train, channels=3)
18+
image_resized = tf.image.resize_images(image_decoded, [227, 227])
19+
labels = tf.cast(features['label'], tf.int64)
20+
labels = tf.one_hot(labels, 101)
21+
shape = tf.cast([227, 227], tf.int32)
22+
return image_resized, labels
1223

1324

1425
def alex_net():
1526
model = tf.keras.Sequential()
1627
model.add(layers.Conv2D(filters=96, kernel_size=(11, 11), strides=(4, 4), padding='valid',
17-
activation='relu'))
28+
activation='relu', input_shape=(227, 227, 3)))
1829
model.add(layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='valid'))
1930
model.add(layers.BatchNormalization())
2031
model.add(layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1, 1), padding='same',
@@ -33,24 +44,34 @@ def alex_net():
3344
model.add(layers.Dropout(0.5))
3445
model.add(layers.Dense(4096, activation=tf.keras.activations.relu))
3546
model.add(layers.Dropout(0.5))
36-
model.add(layers.Dense(1000, activation=tf.keras.activations.softmax))
47+
model.add(layers.Dense(101, activation=tf.keras.activations.softmax))
3748
return model
3849

3950

40-
filenames = ["/media/data/oldcopy/PythonProject/Food101/TFRecord/train.tfrecords"]
51+
filenames = ["/home/heolis/Data/food-101/TFRecord/train0.tfrecords",
52+
"/home/heolis/Data/food-101/TFRecord/train1.tfrecords",
53+
"/home/heolis/Data/food-101/TFRecord/train2.tfrecords",
54+
"/home/heolis/Data/food-101/TFRecord/train3.tfrecords"]
4155
trainSet = tf.data.TFRecordDataset(filenames)
4256
trainSet = trainSet.map(_parse_function)
43-
trainSet = trainSet.repeat()
57+
trainSet = trainSet.repeat(10)
4458
trainSet = trainSet.batch(32)
59+
iterator_train = trainSet.make_one_shot_iterator()
4560

46-
filenames = ["/media/data/oldcopy/PythonProject/Food101/TFRecord/test.tfrecords"]
61+
filenames = ["/home/heolis/Data/food-101/TFRecord/train3.tfrecords"]
4762
testSet = tf.data.TFRecordDataset(filenames)
4863
testSet = testSet.map(_parse_function)
49-
testSet = testSet.repeat()
64+
testSet = testSet.repeat(10)
5065
testSet = testSet.batch(32)
66+
iterator_test = testSet.make_one_shot_iterator()
5167

5268
model = alex_net()
53-
model.fit(trainSet, epochs=10, batch_size=32, validation_data=testSet)
69+
70+
model.compile(optimizer=tf.train.GradientDescentOptimizer(0.03),
71+
loss=tf.keras.losses.categorical_crossentropy,
72+
metrics=[tf.keras.metrics.categorical_accuracy])
73+
74+
model.fit(iterator_train, epochs=10, validation_data=testSet, steps_per_epoch=10000)
5475

5576
loss, accuracy = model.evaluate(testSet)
5677
print("loss:%f, accuracy:%f" % (loss, accuracy))

‎AlexNet/CreateTFRecord.py‎

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tensorflow as tf
44
import random
55
import os
6+
import cv2
67
from tqdm import tqdm
78
import threading
89

@@ -12,17 +13,17 @@
1213

1314

1415
def load_files():
15-
f = open("/media/heolis/967EC257F5104FE6/oldcopy/PythonProject/Food101/food-101/meta/classes.txt")
16+
f = open("/home/heolis/Data/food-101/meta/classes.txt")
1617
lines = f.readlines()
1718
for line in lines:
1819
classes.append(line.strip('\n'))
1920
f.close()
20-
f = open("/media/heolis/967EC257F5104FE6/oldcopy/PythonProject/Food101/food-101/meta/train.txt")
21+
f = open("/home/heolis/Data/food-101/meta/train.txt")
2122
lines = f.readlines()
2223
for line in lines:
2324
trainPaths.append(line.strip('\n'))
2425
f.close()
25-
f = open("/media/heolis/967EC257F5104FE6/oldcopy/PythonProject/Food101/food-101/meta/test.txt")
26+
f = open("/home/heolis/Data/food-101/meta/test.txt")
2627
lines = f.readlines()
2728
for line in lines:
2829
testPaths.append(line.strip('\n'))
@@ -51,22 +52,26 @@ def probuf(label, image_raw):
5152
return example.SerializeToString()
5253

5354

54-
def writerRecord(save_path, HOME_PATH, coord, t_id):
55+
def writerRecord(save_path, HOME_PATH, coord=None, t_id=0):
5556
writerTrain = tf.python_io.TFRecordWriter(os.path.join(save_path, "train" + str(t_id) + ".tfrecords"))
5657
writerTest = tf.python_io.TFRecordWriter(os.path.join(save_path, "test" + str(t_id) + ".tfrecords"))
5758
randIndexTrain = random.sample(range(0, len(trainPaths)), len(trainPaths))
5859

59-
size = int(len(randIndexTrain) / 1)
60+
size = int(len(randIndexTrain) / 4)
6061
randIndexTrain = randIndexTrain[t_id * size: (t_id + 1) * size]
6162

6263
randIndexTest = random.sample(range(0, len(testPaths)), len(testPaths))
63-
size = int(len(randIndexTest) / 1)
64+
size = int(len(randIndexTest) / 4)
6465
randIndexTest = randIndexTest[t_id * size: (t_id + 1) * size]
6566

6667
sess = tf.Session()
6768
for i in tqdm(randIndexTrain, "train:"):
68-
image_string = tf.read_file(os.path.join(HOME_PATH, trainPaths[i] + ".jpg"))
69-
image_string = sess.run(image_string)
69+
# image_string = tf.read_file(os.path.join(HOME_PATH, trainPaths[i] + ".jpg"))
70+
# image_string = sess.run(image_string)
71+
72+
image_string = cv2.imread(os.path.join(HOME_PATH, trainPaths[i] + ".jpg"))
73+
image_string = image_string.tostring()
74+
7075
label = trainPaths[i].split("/")[0]
7176
label = transform_label(label)
7277
writerTrain.write(probuf(label, image_string))
@@ -78,16 +83,18 @@ def writerRecord(save_path, HOME_PATH, coord, t_id):
7883
label = transform_label(label)
7984
writerTest.write(probuf(label, image_string))
8085
writerTest.close()
81-
coord.request_stop()
86+
87+
# coord.request_stop()
8288

8389

8490
if __name__ == '__main__':
8591
# sess = tf.InteractiveSession()
8692
coord = tf.train.Coordinator()
8793
load_files()
88-
threads = [threading.Thread(target=writerRecord, args=("/media/heolis/967EC257F5104FE6/oldcopy/PythonProject/Food101/TFRecord",
89-
"/media/heolis/967EC257F5104FE6/oldcopy/PythonProject/Food101/food-101/images",
90-
coord, i)) for i in range(1)]
91-
for t in threads:
92-
t.start()
93-
coord.join(threads)
94+
# threads = [threading.Thread(target=writerRecord, args=("/home/heolis/Data/food-101/TFRecord",
95+
# "/home/heolis/Data/food-101/images",
96+
# coord, i)) for i in range(4)]
97+
# for t in threads:
98+
# t.start()
99+
# coord.join(threads)
100+
writerRecord("/home/heolis/Data/food-101/TFRecord", "/home/heolis/Data/food-101/images")

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /