|  | 
|  | 1 | +""" | 
|  | 2 | +implement a shuffleNet by pytorch | 
|  | 3 | +""" | 
|  | 4 | +import torch | 
|  | 5 | +import torch.nn as nn | 
|  | 6 | +import torch.nn.functional as F | 
|  | 7 | +from torch.autograd import Variable | 
|  | 8 | + | 
|  | 9 | +dtype = torch.FloatTensor | 
|  | 10 | + | 
|  | 11 | +def shuffle_channels(x, groups): | 
|  | 12 | +    """shuffle channels of a 4-D Tensor""" | 
|  | 13 | +    batch_size, channels, height, width = x.size() | 
|  | 14 | +    assert channels % groups == 0 | 
|  | 15 | +    channels_per_group = channels // groups | 
|  | 16 | +    # split into groups | 
|  | 17 | +    x = x.view(batch_size, groups, channels_per_group, | 
|  | 18 | +               height, width) | 
|  | 19 | +    # transpose 1, 2 axis | 
|  | 20 | +    x = x.transpose(1, 2).contiguous() | 
|  | 21 | +    # reshape into orignal | 
|  | 22 | +    x = x.view(batch_size, channels, height, width) | 
|  | 23 | +    return x | 
|  | 24 | + | 
|  | 25 | +class ShuffleNetUnitA(nn.Module): | 
|  | 26 | +    """ShuffleNet unit for stride=1""" | 
|  | 27 | +    def __init__(self, in_channels, out_channels, groups=3): | 
|  | 28 | +        super(ShuffleNetUnitA, self).__init__() | 
|  | 29 | +        assert in_channels == out_channels | 
|  | 30 | +        assert out_channels % 4 == 0 | 
|  | 31 | +        bottleneck_channels = out_channels // 4 | 
|  | 32 | +        self.groups = groups | 
|  | 33 | +        self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels, | 
|  | 34 | +                                        1, groups=groups, stride=1) | 
|  | 35 | +        self.bn2 = nn.BatchNorm2d(bottleneck_channels) | 
|  | 36 | +        self.depthwise_conv3 = nn.Conv2d(bottleneck_channels, | 
|  | 37 | +                                         bottleneck_channels, | 
|  | 38 | +                                         3, padding=1, stride=1, | 
|  | 39 | +                                         groups=bottleneck_channels) | 
|  | 40 | +        self.bn4 = nn.BatchNorm2d(bottleneck_channels) | 
|  | 41 | +        self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels, | 
|  | 42 | +                                     1, stride=1, groups=groups) | 
|  | 43 | +        self.bn6 = nn.BatchNorm2d(out_channels) | 
|  | 44 | + | 
|  | 45 | +    def forward(self, x): | 
|  | 46 | +        out = self.group_conv1(x) | 
|  | 47 | +        out = F.relu(self.bn2(out)) | 
|  | 48 | +        out = shuffle_channels(out, groups=self.groups) | 
|  | 49 | +        out = self.depthwise_conv3(out) | 
|  | 50 | +        out = self.bn4(out) | 
|  | 51 | +        out = self.group_conv5(out) | 
|  | 52 | +        out = self.bn6(out) | 
|  | 53 | +        out = F.relu(x + out) | 
|  | 54 | +        return out | 
|  | 55 | + | 
|  | 56 | +class ShuffleNetUnitB(nn.Module): | 
|  | 57 | +    """ShuffleNet unit for stride=2""" | 
|  | 58 | +    def __init__(self, in_channels, out_channels, groups=3): | 
|  | 59 | +        super(ShuffleNetUnitB, self).__init__() | 
|  | 60 | +        out_channels -= in_channels | 
|  | 61 | +        assert out_channels % 4 == 0 | 
|  | 62 | +        bottleneck_channels = out_channels // 4 | 
|  | 63 | +        self.groups = groups | 
|  | 64 | +        self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels, | 
|  | 65 | +                                     1, groups=groups, stride=1) | 
|  | 66 | +        self.bn2 = nn.BatchNorm2d(bottleneck_channels) | 
|  | 67 | +        self.depthwise_conv3 = nn.Conv2d(bottleneck_channels, | 
|  | 68 | +                                         bottleneck_channels, | 
|  | 69 | +                                         3, padding=1, stride=2, | 
|  | 70 | +                                         groups=bottleneck_channels) | 
|  | 71 | +        self.bn4 = nn.BatchNorm2d(bottleneck_channels) | 
|  | 72 | +        self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels, | 
|  | 73 | +                                     1, stride=1, groups=groups) | 
|  | 74 | +        self.bn6 = nn.BatchNorm2d(out_channels) | 
|  | 75 | + | 
|  | 76 | +    def forward(self, x): | 
|  | 77 | +        out = self.group_conv1(x) | 
|  | 78 | +        out = F.relu(self.bn2(out)) | 
|  | 79 | +        out = shuffle_channels(out, groups=self.groups) | 
|  | 80 | +        out = self.depthwise_conv3(out) | 
|  | 81 | +        out = self.bn4(out) | 
|  | 82 | +        out = self.group_conv5(out) | 
|  | 83 | +        out = self.bn6(out) | 
|  | 84 | +        x = F.avg_pool2d(x, 3, stride=2, padding=1) | 
|  | 85 | +        out = F.relu(torch.cat([x, out], dim=1)) | 
|  | 86 | +        return out | 
|  | 87 | + | 
|  | 88 | +class ShuffleNet(nn.Module): | 
|  | 89 | +    """ShuffleNet for groups=3""" | 
|  | 90 | +    def __init__(self, groups=3, in_channels=3, num_classes=1000): | 
|  | 91 | +        super(ShuffleNet, self).__init__() | 
|  | 92 | + | 
|  | 93 | +        self.conv1 = nn.Conv2d(in_channels, 24, 3, stride=2, padding=1) | 
|  | 94 | +        stage2_seq = [ShuffleNetUnitB(24, 240, groups=3)] + \ | 
|  | 95 | +            [ShuffleNetUnitA(240, 240, groups=3) for i in range(3)] | 
|  | 96 | +        self.stage2 = nn.Sequential(*stage2_seq) | 
|  | 97 | +        stage3_seq = [ShuffleNetUnitB(240, 480, groups=3)] + \ | 
|  | 98 | +            [ShuffleNetUnitA(480, 480, groups=3) for i in range(7)] | 
|  | 99 | +        self.stage3 = nn.Sequential(*stage3_seq) | 
|  | 100 | +        stage4_seq = [ShuffleNetUnitB(480, 960, groups=3)] + \ | 
|  | 101 | +                     [ShuffleNetUnitA(960, 960, groups=3) for i in range(3)] | 
|  | 102 | +        self.stage4 = nn.Sequential(*stage4_seq) | 
|  | 103 | +        self.fc = nn.Linear(960, num_classes) | 
|  | 104 | + | 
|  | 105 | +    def forward(self, x): | 
|  | 106 | +        net = self.conv1(x) | 
|  | 107 | +        net = F.max_pool2d(net, 3, stride=2, padding=1) | 
|  | 108 | +        net = self.stage2(net) | 
|  | 109 | +        net = self.stage3(net) | 
|  | 110 | +        net = self.stage4(net) | 
|  | 111 | +        net = F.avg_pool2d(net, 7) | 
|  | 112 | +        net = net.view(net.size(0), -1) | 
|  | 113 | +        net = self.fc(net) | 
|  | 114 | +        logits = F.softmax(net) | 
|  | 115 | +        return logits | 
|  | 116 | + | 
|  | 117 | +if __name__ == "__main__": | 
|  | 118 | +    x = Variable(torch.randn([32, 3, 224, 224]).type(dtype), | 
|  | 119 | +                 requires_grad=False) | 
|  | 120 | +    shuffleNet = ShuffleNet() | 
|  | 121 | +    out = shuffleNet(x) | 
|  | 122 | +    print(out.size()) | 
0 commit comments