"""
Implement ShuffleNet in PyTorch.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def shuffle_channels(x, groups):
    """Shuffle the channels of a 4-D tensor across its groups."""
    batch_size, channels, height, width = x.size()
    assert channels % groups == 0
    channels_per_group = channels // groups
    # split the channels into groups
    x = x.view(batch_size, groups, channels_per_group,
               height, width)
    # transpose the group and channel axes
    x = x.transpose(1, 2).contiguous()
    # reshape back into the original layout
    x = x.view(batch_size, channels, height, width)
    return x
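
# For example (an illustration, not part of the original code): with
# 6 channels and groups=3, the channel order [0, 1, 2, 3, 4, 5] becomes
# [0, 2, 4, 1, 3, 5], interleaving the groups so that the next group
# convolution sees inputs coming from every group.
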

class ShuffleNetUnitA(nn.Module):
    """ShuffleNet unit for stride=1"""
    def __init__(self, in_channels, out_channels, groups=3):
        super(ShuffleNetUnitA, self).__init__()
        assert in_channels == out_channels
        assert out_channels % 4 == 0
        bottleneck_channels = out_channels // 4
        self.groups = groups
        self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
                                     1, groups=groups, stride=1)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
                                         bottleneck_channels,
                                         3, padding=1, stride=1,
                                         groups=bottleneck_channels)
        self.bn4 = nn.BatchNorm2d(bottleneck_channels)
        self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
                                     1, stride=1, groups=groups)
        self.bn6 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.group_conv1(x)
        out = F.relu(self.bn2(out))
        out = shuffle_channels(out, groups=self.groups)
        out = self.depthwise_conv3(out)
        out = self.bn4(out)
        out = self.group_conv5(out)
        out = self.bn6(out)
        out = F.relu(x + out)
        return out
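
# Quick shape check (illustrative, not part of the original code): Unit A
# keeps both the channel count and the spatial size, e.g.
#   ShuffleNetUnitA(240, 240, groups=3)(torch.randn(1, 240, 28, 28))
# returns a tensor of shape [1, 240, 28, 28].
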

class ShuffleNetUnitB(nn.Module):
    """ShuffleNet unit for stride=2"""
    def __init__(self, in_channels, out_channels, groups=3):
        super(ShuffleNetUnitB, self).__init__()
        out_channels -= in_channels
        assert out_channels % 4 == 0
        bottleneck_channels = out_channels // 4
        self.groups = groups
        self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
                                     1, groups=groups, stride=1)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
                                         bottleneck_channels,
                                         3, padding=1, stride=2,
                                         groups=bottleneck_channels)
        self.bn4 = nn.BatchNorm2d(bottleneck_channels)
        self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
                                     1, stride=1, groups=groups)
        self.bn6 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.group_conv1(x)
        out = F.relu(self.bn2(out))
        out = shuffle_channels(out, groups=self.groups)
        out = self.depthwise_conv3(out)
        out = self.bn4(out)
        out = self.group_conv5(out)
        out = self.bn6(out)
        x = F.avg_pool2d(x, 3, stride=2, padding=1)
        out = F.relu(torch.cat([x, out], dim=1))
        return out
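
# Quick shape check (illustrative, not part of the original code): Unit B
# halves the spatial size and concatenates the pooled shortcut with the
# convolutional branch, e.g.
#   ShuffleNetUnitB(24, 240, groups=3)(torch.randn(1, 24, 56, 56))
# returns a tensor of shape [1, 240, 28, 28].
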

class ShuffleNet(nn.Module):
    """ShuffleNet for groups=3"""
    def __init__(self, groups=3, in_channels=3, num_classes=1000):
        super(ShuffleNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 24, 3, stride=2, padding=1)
        stage2_seq = [ShuffleNetUnitB(24, 240, groups=3)] + \
            [ShuffleNetUnitA(240, 240, groups=3) for _ in range(3)]
        self.stage2 = nn.Sequential(*stage2_seq)
        stage3_seq = [ShuffleNetUnitB(240, 480, groups=3)] + \
            [ShuffleNetUnitA(480, 480, groups=3) for _ in range(7)]
        self.stage3 = nn.Sequential(*stage3_seq)
        stage4_seq = [ShuffleNetUnitB(480, 960, groups=3)] + \
            [ShuffleNetUnitA(960, 960, groups=3) for _ in range(3)]
        self.stage4 = nn.Sequential(*stage4_seq)
        self.fc = nn.Linear(960, num_classes)

    def forward(self, x):
        net = self.conv1(x)
        net = F.max_pool2d(net, 3, stride=2, padding=1)
        net = self.stage2(net)
        net = self.stage3(net)
        net = self.stage4(net)
        net = F.avg_pool2d(net, 7)
        net = net.view(net.size(0), -1)
        logits = self.fc(net)
        # softmax over the class dimension; for training with
        # nn.CrossEntropyLoss, use the raw logits instead
        probs = F.softmax(logits, dim=1)
        return probs
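
# Feature-map sizes for a 224x224 input (for reference): conv1 (stride 2)
# gives 112x112, max pooling gives 56x56, stage2 gives 28x28, stage3 gives
# 14x14, stage4 gives 7x7, and the 7x7 average pooling yields a 960-dim
# feature vector fed to the fully connected layer.
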

if __name__ == "__main__":
    x = torch.randn(32, 3, 224, 224)
    shuffleNet = ShuffleNet()
    out = shuffleNet(x)
    print(out.size())
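    # expected output: torch.Size([32, 1000])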