Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 95144d6

Browse files
add shufflenetv2
1 parent e7bc360 commit 95144d6

File tree

1 file changed

+243
-0
lines changed

1 file changed

+243
-0
lines changed

‎CNNs/shufflenet_v2.py‎

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
"""
2+
The implementation of ShuffleNet v2 in Keras
3+
"""
4+
5+
import tensorflow as tf
6+
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D
7+
from tensorflow.keras.layers import MaxPool2D, GlobalAveragePooling2D, Dense
8+
from tensorflow.keras.layers import BatchNormalization, Activation
9+
10+
11+
def channle_shuffle(inputs, group):
    """Shuffle the channels of a channels-last 4D tensor.

    The channel axis is viewed as an ``(in_channel // group, group)`` grid,
    the two grid axes are swapped, and the result is flattened back to
    ``in_channel`` channels.

    Args:
        inputs: 4D Tensor whose last axis is the channel axis. The channel
            dimension must be statically known; the two middle (spatial)
            dimensions may be unknown at graph-construction time.
        group: int, number of groups; must be positive and evenly divide
            the channel count.
    Returns:
        Shuffled 4D Tensor
    Raises:
        ValueError: if `group` is not positive, or the channel count is
            unknown or not divisible by `group`.
    """
    static_shape = inputs.get_shape().as_list()
    height, width, in_channel = static_shape[1:]

    # Validate with real exceptions: `assert` is stripped under `python -O`.
    if group < 1:
        raise ValueError("group must be a positive integer, got %r" % (group,))
    if in_channel is None:
        raise ValueError("the channel dimension of inputs must be statically known")
    if in_channel % group != 0:
        raise ValueError(
            "number of channels (%d) is not divisible by group (%d)"
            % (in_channel, group))

    # Unknown spatial sizes cannot be passed to tf.reshape as None, so read
    # them from the runtime shape instead.
    if height is None or width is None:
        runtime_shape = tf.shape(inputs)
        if height is None:
            height = runtime_shape[1]
        if width is None:
            width = runtime_shape[2]

    grouped = tf.reshape(
        inputs, [-1, height, width, in_channel // group, group])
    swapped = tf.transpose(grouped, [0, 1, 2, 4, 3])
    return tf.reshape(swapped, [-1, height, width, in_channel])
27+
28+
class Conv2D_BN_ReLU(tf.keras.Model):
    """Convolution block: Conv2D, then batch normalization, then ReLU.

    Args:
        channel: int, number of convolution filters
        kernel_size: convolution kernel size (default 1)
        stride: convolution stride (default 1)
    """
    def __init__(self, channel, kernel_size=1, stride=1):
        super(Conv2D_BN_ReLU, self).__init__()

        # Bias is disabled on the convolution; batch normalization follows it.
        self.conv = Conv2D(
            filters=channel,
            kernel_size=kernel_size,
            strides=stride,
            padding="SAME",
            use_bias=False,
        )
        self.bn = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)
        self.relu = Activation("relu")

    def call(self, inputs, training=True):
        """Run the block on `inputs`.

        Args:
            inputs: Tensor fed to the convolution
            training: forwarded to the batch-normalization layer;
                defaults to True
        Returns:
            Tensor after convolution, batch normalization and ReLU
        """
        normalized = self.bn(self.conv(inputs), training=training)
        return self.relu(normalized)
43+
44+
class DepthwiseConv2D_BN(tf.keras.Model):
    """Depthwise block: DepthwiseConv2D followed by batch normalization.

    No activation is applied by this block.

    Args:
        kernel_size: depthwise kernel size (default 3)
        stride: depthwise stride (default 1)
    """
    def __init__(self, kernel_size=3, stride=1):
        super(DepthwiseConv2D_BN, self).__init__()

        # depth_multiplier=1 keeps the number of channels unchanged.
        self.dconv = DepthwiseConv2D(
            kernel_size=kernel_size,
            strides=stride,
            depth_multiplier=1,
            padding="SAME",
            use_bias=False,
        )
        self.bn = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)

    def call(self, inputs, training=True):
        """Run the block on `inputs`.

        Args:
            inputs: Tensor fed to the depthwise convolution
            training: forwarded to the batch-normalization layer;
                defaults to True
        Returns:
            Tensor after depthwise convolution and batch normalization
        """
        return self.bn(self.dconv(inputs), training=training)
58+
59+
60+
class ShufflenetUnit1(tf.keras.Model):
    """The unit of shufflenetv2 for stride=1.

    The input is split into two halves along the channel axis. One half is
    passed through unchanged; the other goes through 1x1 conv -> 3x3
    depthwise conv -> 1x1 conv. The halves are concatenated and the
    channels shuffled with two groups.

    Args:
        out_channel: int, number of channels; must be even
    Raises:
        ValueError: if `out_channel` is odd.
    """
    def __init__(self, out_channel):
        super(ShufflenetUnit1, self).__init__()

        # Raise instead of assert so the check survives `python -O`.
        if out_channel % 2 != 0:
            raise ValueError(
                "out_channel must be even, got %r" % (out_channel,))
        self.out_channel = out_channel

        branch_channel = out_channel // 2
        self.conv1_bn_relu = Conv2D_BN_ReLU(branch_channel, 1, 1)
        self.dconv_bn = DepthwiseConv2D_BN(3, 1)
        self.conv2_bn_relu = Conv2D_BN_ReLU(branch_channel, 1, 1)

    def call(self, inputs, training=False):
        """Apply the unit.

        Args:
            inputs: 4D Tensor with the channel axis last
            training: forwarded to the batch-normalization layers
        Returns:
            4D Tensor after concatenation and channel shuffle
        """
        # split the channel
        shortcut, x = tf.split(inputs, 2, axis=3)

        x = self.conv1_bn_relu(x, training=training)
        x = self.dconv_bn(x, training=training)
        x = self.conv2_bn_relu(x, training=training)

        x = tf.concat([shortcut, x], axis=3)
        x = channle_shuffle(x, 2)
        return x
86+
87+
class ShufflenetUnit2(tf.keras.Model):
    """The unit of shufflenetv2 for stride=2.

    Both branches receive the full input. The main branch is 1x1 conv ->
    3x3 stride-2 depthwise conv -> 1x1 conv; the shortcut branch is 3x3
    stride-2 depthwise conv -> 1x1 conv. The branches are concatenated
    (shortcut first) and the channels shuffled with two groups.

    Args:
        in_channel: int, number of filters of the shortcut branch's 1x1
            convolution (presumably the channel count of the input;
            confirm against callers)
        out_channel: int, number of channels after concatenation; must be
            even and greater than `in_channel`
    Raises:
        ValueError: if `out_channel` is odd or not greater than
            `in_channel`.
    """
    def __init__(self, in_channel, out_channel):
        super(ShufflenetUnit2, self).__init__()

        # Raise instead of assert so the checks survive `python -O`.
        if out_channel % 2 != 0:
            raise ValueError(
                "out_channel must be even, got %r" % (out_channel,))
        # The main branch ends with `out_channel - in_channel` filters,
        # which must be a positive count.
        if out_channel <= in_channel:
            raise ValueError(
                "out_channel (%r) must be greater than in_channel (%r)"
                % (out_channel, in_channel))
        self.in_channel = in_channel
        self.out_channel = out_channel

        self.conv1_bn_relu = Conv2D_BN_ReLU(out_channel // 2, 1, 1)
        self.dconv_bn = DepthwiseConv2D_BN(3, 2)
        self.conv2_bn_relu = Conv2D_BN_ReLU(out_channel - in_channel, 1, 1)

        # for shortcut
        self.shortcut_dconv_bn = DepthwiseConv2D_BN(3, 2)
        self.shortcut_conv_bn_relu = Conv2D_BN_ReLU(in_channel, 1, 1)

    def call(self, inputs, training=False):
        """Apply the unit.

        Args:
            inputs: 4D Tensor with the channel axis last
            training: forwarded to the batch-normalization layers
        Returns:
            4D Tensor after concatenation and channel shuffle
        """
        shortcut, x = inputs, inputs

        x = self.conv1_bn_relu(x, training=training)
        x = self.dconv_bn(x, training=training)
        x = self.conv2_bn_relu(x, training=training)

        shortcut = self.shortcut_dconv_bn(shortcut, training=training)
        shortcut = self.shortcut_conv_bn_relu(shortcut, training=training)

        x = tf.concat([shortcut, x], axis=3)
        x = channle_shuffle(x, 2)
        return x
117+
118+
class ShufflenetStage(tf.keras.Model):
    """One stage of shufflenet: a stride-2 unit followed by stride-1 units.

    Args:
        in_channel: int, passed to the leading ShufflenetUnit2
        out_channel: int, passed to every unit of the stage
        num_blocks: int, total number of units in the stage
    """
    def __init__(self, in_channel, out_channel, num_blocks):
        super(ShufflenetStage, self).__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel

        # Only the first unit is the stride-2 variant; the rest are stride-1.
        self.ops = []
        for index in range(num_blocks):
            unit = (ShufflenetUnit2(in_channel, out_channel) if index == 0
                    else ShufflenetUnit1(out_channel))
            self.ops.append(unit)

    def call(self, inputs, training=False):
        """Feed `inputs` through every unit of the stage in order.

        Args:
            inputs: Tensor given to the first unit
            training: forwarded to every unit
        Returns:
            Output Tensor of the last unit
        """
        features = inputs
        for unit in self.ops:
            features = unit(features, training=training)
        return features
139+
140+
141+
class ShuffleNetv2(tf.keras.Model):
    """Shufflenetv2

    Layout: 3x3 stride-2 conv -> 3x3 stride-2 max pool -> stage2 -> stage3
    -> stage4 -> 1x1 conv -> global average pooling -> dense layer. The
    dense layer has no activation, so the model returns unnormalized
    class scores.

    Args:
        num_classes: int, number of units of the final dense layer
        first_channel: int, number of filters of the first convolution
        channels_per_stage: sequence of three ints, the output channels of
            stage2, stage3 and stage4
        blocks_per_stage: sequence of exactly three ints, the number of
            units in stage2, stage3 and stage4. Defaults to (4, 8, 4), the
            values that were previously hard-coded.
        last_channel: int, number of filters of the final 1x1 convolution.
            Defaults to 1024, the value that was previously hard-coded.
    Raises:
        ValueError: if `blocks_per_stage` does not hold exactly three
            entries.
    """
    def __init__(self, num_classes, first_channel=24,
                 channels_per_stage=(116, 232, 464),
                 blocks_per_stage=(4, 8, 4), last_channel=1024):
        super(ShuffleNetv2, self).__init__()

        # Tuple unpacking raises ValueError unless three counts are given.
        stage2_blocks, stage3_blocks, stage4_blocks = blocks_per_stage

        self.num_classes = num_classes

        # Layers are created in the same order as before so automatically
        # generated layer/variable names are unchanged.
        self.conv1_bn_relu = Conv2D_BN_ReLU(first_channel, 3, 2)
        self.pool1 = MaxPool2D(3, strides=2, padding="SAME")
        self.stage2 = ShufflenetStage(
            first_channel, channels_per_stage[0], stage2_blocks)
        self.stage3 = ShufflenetStage(
            channels_per_stage[0], channels_per_stage[1], stage3_blocks)
        self.stage4 = ShufflenetStage(
            channels_per_stage[1], channels_per_stage[2], stage4_blocks)
        self.conv5_bn_relu = Conv2D_BN_ReLU(last_channel, 1, 1)
        self.gap = GlobalAveragePooling2D()
        self.linear = Dense(num_classes)

    def call(self, inputs, training=False):
        """Compute class scores for a batch of images.

        Args:
            inputs: 4D Tensor with the channel axis last
            training: forwarded to every block that contains batch
                normalization
        Returns:
            2D Tensor of unnormalized scores, one column per class
        """
        x = self.conv1_bn_relu(inputs, training=training)
        x = self.pool1(x)
        x = self.stage2(x, training=training)
        x = self.stage3(x, training=training)
        x = self.stage4(x, training=training)
        x = self.conv5_bn_relu(x, training=training)
        x = self.gap(x)
        x = self.linear(x)
        return x
167+
168+
169+
if __name__ == "__main__":
    # The string literal below is disabled code, kept for reference: it
    # copies weights from an .npz file into the model's variables and writes
    # the checkpoint that the live code further down restores.
    """
    inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])

    model = ShuffleNetv2(1000)
    outputs = model(inputs)

    print(model.summary())

    with tf.Session() as sess:
        pass


    vars = []
    for v in tf.global_variables():

        vars.append((v.name, v))
        print(v.name)
    print(len(vars))


    import numpy as np

    path = "C:/models/ShuffleNetV2-1x.npz"
    weights = np.load(path)
    np_vars = []
    for k in weights:
        k_ = k.replace("beta", "gbeta")
        k_ = k_.replace("/dconv", "/conv10_dconv")
        k_ = k_.replace("shortcut_dconv", "shortcut_a_dconv")
        k_ = k_.replace("conv5", "su_conv5")
        k_ = k_.replace("linear", "t_linear")
        np_vars.append((k_, weights[k]))
    np_vars.sort(key=lambda x: x[0])

    for k, _ in np_vars:
        print(k)

    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        assign_ops = []
        for id in range(len(vars)):
            print(vars[id][0], np_vars[id][0])
            assign_ops.append(tf.assign(vars[id][1], np_vars[id][1]))

        sess.run(assign_ops)
        saver.save(sess, "./models/shufflene_v2_1.0.ckpt")

    model.save("./models/shufflenet_v2_1.0.hdf5")

    """

    import numpy as np
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.applications.densenet import preprocess_input, decode_predictions

    # Load the sample picture and turn it into a preprocessed batch of one.
    picture = image.load_img('./images/cat.jpg', target_size=(224, 224))
    batch = np.expand_dims(image.img_to_array(picture), axis=0)
    batch = preprocess_input(batch)

    # Build the inference graph: network scores followed by softmax.
    image_input = tf.placeholder(tf.float32, [None, 224, 224, 3])
    network = ShuffleNetv2(1000)
    probabilities = tf.nn.softmax(network(image_input, training=False))

    # Restore the saved weights, classify the picture, print the top three.
    restorer = tf.train.Saver()
    with tf.Session() as sess:
        restorer.restore(sess, "./models/shufflene_v2_1.0.ckpt")
        scores = sess.run(probabilities, feed_dict={image_input: batch})
        print(decode_predictions(scores, top=3)[0])
243+

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /