Skip to content

Commit c579811

Browse files
authored
Create ShuffleNet.py
1 parent b5edf6e commit c579811

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed

CNNs/ShuffleNet.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
implement a shuffleNet by pytorch
3+
"""
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
from torch.autograd import Variable
8+
9+
dtype = torch.FloatTensor
10+
11+
def shuffle_channels(x, groups):
12+
"""shuffle channels of a 4-D Tensor"""
13+
batch_size, channels, height, width = x.size()
14+
assert channels % groups == 0
15+
channels_per_group = channels // groups
16+
# split into groups
17+
x = x.view(batch_size, groups, channels_per_group,
18+
height, width)
19+
# transpose 1, 2 axis
20+
x = x.transpose(1, 2).contiguous()
21+
# reshape into orignal
22+
x = x.view(batch_size, channels, height, width)
23+
return x
24+
25+
class ShuffleNetUnitA(nn.Module):
26+
"""ShuffleNet unit for stride=1"""
27+
def __init__(self, in_channels, out_channels, groups=3):
28+
super(ShuffleNetUnitA, self).__init__()
29+
assert in_channels == out_channels
30+
assert out_channels % 4 == 0
31+
bottleneck_channels = out_channels // 4
32+
self.groups = groups
33+
self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
34+
1, groups=groups, stride=1)
35+
self.bn2 = nn.BatchNorm2d(bottleneck_channels)
36+
self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
37+
bottleneck_channels,
38+
3, padding=1, stride=1,
39+
groups=bottleneck_channels)
40+
self.bn4 = nn.BatchNorm2d(bottleneck_channels)
41+
self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
42+
1, stride=1, groups=groups)
43+
self.bn6 = nn.BatchNorm2d(out_channels)
44+
45+
def forward(self, x):
46+
out = self.group_conv1(x)
47+
out = F.relu(self.bn2(out))
48+
out = shuffle_channels(out, groups=self.groups)
49+
out = self.depthwise_conv3(out)
50+
out = self.bn4(out)
51+
out = self.group_conv5(out)
52+
out = self.bn6(out)
53+
out = F.relu(x + out)
54+
return out
55+
56+
class ShuffleNetUnitB(nn.Module):
57+
"""ShuffleNet unit for stride=2"""
58+
def __init__(self, in_channels, out_channels, groups=3):
59+
super(ShuffleNetUnitB, self).__init__()
60+
out_channels -= in_channels
61+
assert out_channels % 4 == 0
62+
bottleneck_channels = out_channels // 4
63+
self.groups = groups
64+
self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
65+
1, groups=groups, stride=1)
66+
self.bn2 = nn.BatchNorm2d(bottleneck_channels)
67+
self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
68+
bottleneck_channels,
69+
3, padding=1, stride=2,
70+
groups=bottleneck_channels)
71+
self.bn4 = nn.BatchNorm2d(bottleneck_channels)
72+
self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
73+
1, stride=1, groups=groups)
74+
self.bn6 = nn.BatchNorm2d(out_channels)
75+
76+
def forward(self, x):
77+
out = self.group_conv1(x)
78+
out = F.relu(self.bn2(out))
79+
out = shuffle_channels(out, groups=self.groups)
80+
out = self.depthwise_conv3(out)
81+
out = self.bn4(out)
82+
out = self.group_conv5(out)
83+
out = self.bn6(out)
84+
x = F.avg_pool2d(x, 3, stride=2, padding=1)
85+
out = F.relu(torch.cat([x, out], dim=1))
86+
return out
87+
88+
class ShuffleNet(nn.Module):
89+
"""ShuffleNet for groups=3"""
90+
def __init__(self, groups=3, in_channels=3, num_classes=1000):
91+
super(ShuffleNet, self).__init__()
92+
93+
self.conv1 = nn.Conv2d(in_channels, 24, 3, stride=2, padding=1)
94+
stage2_seq = [ShuffleNetUnitB(24, 240, groups=3)] + \
95+
[ShuffleNetUnitA(240, 240, groups=3) for i in range(3)]
96+
self.stage2 = nn.Sequential(*stage2_seq)
97+
stage3_seq = [ShuffleNetUnitB(240, 480, groups=3)] + \
98+
[ShuffleNetUnitA(480, 480, groups=3) for i in range(7)]
99+
self.stage3 = nn.Sequential(*stage3_seq)
100+
stage4_seq = [ShuffleNetUnitB(480, 960, groups=3)] + \
101+
[ShuffleNetUnitA(960, 960, groups=3) for i in range(3)]
102+
self.stage4 = nn.Sequential(*stage4_seq)
103+
self.fc = nn.Linear(960, num_classes)
104+
105+
def forward(self, x):
106+
net = self.conv1(x)
107+
net = F.max_pool2d(net, 3, stride=2, padding=1)
108+
net = self.stage2(net)
109+
net = self.stage3(net)
110+
net = self.stage4(net)
111+
net = F.avg_pool2d(net, 7)
112+
net = net.view(net.size(0), -1)
113+
net = self.fc(net)
114+
logits = F.softmax(net)
115+
return logits
116+
117+
if __name__ == "__main__":
118+
x = Variable(torch.randn([32, 3, 224, 224]).type(dtype),
119+
requires_grad=False)
120+
shuffleNet = ShuffleNet()
121+
out = shuffleNet(x)
122+
print(out.size())

0 commit comments

Comments
 (0)