Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c579811

Browse files
Create ShuffleNet.py
1 parent b5edf6e commit c579811

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed

‎CNNs/ShuffleNet.py‎

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
implement a shuffleNet by pytorch
3+
"""
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
from torch.autograd import Variable
8+
9+
dtype = torch.FloatTensor
10+
11+
def shuffle_channels(x, groups):
12+
"""shuffle channels of a 4-D Tensor"""
13+
batch_size, channels, height, width = x.size()
14+
assert channels % groups == 0
15+
channels_per_group = channels // groups
16+
# split into groups
17+
x = x.view(batch_size, groups, channels_per_group,
18+
height, width)
19+
# transpose 1, 2 axis
20+
x = x.transpose(1, 2).contiguous()
21+
# reshape into orignal
22+
x = x.view(batch_size, channels, height, width)
23+
return x
24+
25+
class ShuffleNetUnitA(nn.Module):
26+
"""ShuffleNet unit for stride=1"""
27+
def __init__(self, in_channels, out_channels, groups=3):
28+
super(ShuffleNetUnitA, self).__init__()
29+
assert in_channels == out_channels
30+
assert out_channels % 4 == 0
31+
bottleneck_channels = out_channels // 4
32+
self.groups = groups
33+
self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
34+
1, groups=groups, stride=1)
35+
self.bn2 = nn.BatchNorm2d(bottleneck_channels)
36+
self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
37+
bottleneck_channels,
38+
3, padding=1, stride=1,
39+
groups=bottleneck_channels)
40+
self.bn4 = nn.BatchNorm2d(bottleneck_channels)
41+
self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
42+
1, stride=1, groups=groups)
43+
self.bn6 = nn.BatchNorm2d(out_channels)
44+
45+
def forward(self, x):
46+
out = self.group_conv1(x)
47+
out = F.relu(self.bn2(out))
48+
out = shuffle_channels(out, groups=self.groups)
49+
out = self.depthwise_conv3(out)
50+
out = self.bn4(out)
51+
out = self.group_conv5(out)
52+
out = self.bn6(out)
53+
out = F.relu(x + out)
54+
return out
55+
56+
class ShuffleNetUnitB(nn.Module):
57+
"""ShuffleNet unit for stride=2"""
58+
def __init__(self, in_channels, out_channels, groups=3):
59+
super(ShuffleNetUnitB, self).__init__()
60+
out_channels -= in_channels
61+
assert out_channels % 4 == 0
62+
bottleneck_channels = out_channels // 4
63+
self.groups = groups
64+
self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
65+
1, groups=groups, stride=1)
66+
self.bn2 = nn.BatchNorm2d(bottleneck_channels)
67+
self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
68+
bottleneck_channels,
69+
3, padding=1, stride=2,
70+
groups=bottleneck_channels)
71+
self.bn4 = nn.BatchNorm2d(bottleneck_channels)
72+
self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
73+
1, stride=1, groups=groups)
74+
self.bn6 = nn.BatchNorm2d(out_channels)
75+
76+
def forward(self, x):
77+
out = self.group_conv1(x)
78+
out = F.relu(self.bn2(out))
79+
out = shuffle_channels(out, groups=self.groups)
80+
out = self.depthwise_conv3(out)
81+
out = self.bn4(out)
82+
out = self.group_conv5(out)
83+
out = self.bn6(out)
84+
x = F.avg_pool2d(x, 3, stride=2, padding=1)
85+
out = F.relu(torch.cat([x, out], dim=1))
86+
return out
87+
88+
class ShuffleNet(nn.Module):
89+
"""ShuffleNet for groups=3"""
90+
def __init__(self, groups=3, in_channels=3, num_classes=1000):
91+
super(ShuffleNet, self).__init__()
92+
93+
self.conv1 = nn.Conv2d(in_channels, 24, 3, stride=2, padding=1)
94+
stage2_seq = [ShuffleNetUnitB(24, 240, groups=3)] + \
95+
[ShuffleNetUnitA(240, 240, groups=3) for i in range(3)]
96+
self.stage2 = nn.Sequential(*stage2_seq)
97+
stage3_seq = [ShuffleNetUnitB(240, 480, groups=3)] + \
98+
[ShuffleNetUnitA(480, 480, groups=3) for i in range(7)]
99+
self.stage3 = nn.Sequential(*stage3_seq)
100+
stage4_seq = [ShuffleNetUnitB(480, 960, groups=3)] + \
101+
[ShuffleNetUnitA(960, 960, groups=3) for i in range(3)]
102+
self.stage4 = nn.Sequential(*stage4_seq)
103+
self.fc = nn.Linear(960, num_classes)
104+
105+
def forward(self, x):
106+
net = self.conv1(x)
107+
net = F.max_pool2d(net, 3, stride=2, padding=1)
108+
net = self.stage2(net)
109+
net = self.stage3(net)
110+
net = self.stage4(net)
111+
net = F.avg_pool2d(net, 7)
112+
net = net.view(net.size(0), -1)
113+
net = self.fc(net)
114+
logits = F.softmax(net)
115+
return logits
116+
117+
if __name__ == "__main__":
118+
x = Variable(torch.randn([32, 3, 224, 224]).type(dtype),
119+
requires_grad=False)
120+
shuffleNet = ShuffleNet()
121+
out = shuffleNet(x)
122+
print(out.size())

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /