|  | 
|  | 1 | +""" | 
|  | 2 | +YOLOv2 implemented by Tensorflow, only for predicting | 
|  | 3 | +""" | 
|  | 4 | +import os | 
|  | 5 | + | 
|  | 6 | +import numpy as np | 
|  | 7 | +import tensorflow as tf | 
|  | 8 | + | 
|  | 9 | + | 
|  | 10 | + | 
|  | 11 | +######## basic layers ####### | 
|  | 12 | + | 
|  | 13 | +def leaky_relu(x): | 
|  | 14 | +    return tf.nn.leaky_relu(x, alpha=0.1, name="leaky_relu") | 
|  | 15 | + | 
|  | 16 | +# Conv2d | 
|  | 17 | +def conv2d(x, filters, size, pad=0, stride=1, batch_normalize=1, | 
|  | 18 | +           activation=leaky_relu, use_bias=False, name="conv2d"): | 
|  | 19 | +    if pad > 0: | 
|  | 20 | +        x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]]) | 
|  | 21 | +    out = tf.layers.conv2d(x, filters, size, strides=stride, padding="VALID", | 
|  | 22 | +                           activation=None, use_bias=use_bias, name=name) | 
|  | 23 | +    if batch_normalize == 1: | 
|  | 24 | +        out = tf.layers.batch_normalization(out, axis=-1, momentum=0.9, | 
|  | 25 | +                                            training=False, name=name+"_bn") | 
|  | 26 | +    if activation: | 
|  | 27 | +        out = activation(out) | 
|  | 28 | +    return out | 
|  | 29 | + | 
|  | 30 | +# maxpool2d | 
|  | 31 | +def maxpool(x, size=2, stride=2, name="maxpool"): | 
|  | 32 | +    return tf.layers.max_pooling2d(x, size, stride) | 
|  | 33 | + | 
|  | 34 | +# reorg layer | 
|  | 35 | +def reorg(x, stride): | 
|  | 36 | +    return tf.extract_image_patches(x, [1, stride, stride, 1], | 
|  | 37 | +                        [1, stride, stride, 1], [1,1,1,1], padding="VALID") | 
|  | 38 | + | 
|  | 39 | + | 
|  | 40 | +def darknet(images, n_last_channels=425): | 
|  | 41 | +    """Darknet19 for YOLOv2""" | 
|  | 42 | +    net = conv2d(images, 32, 3, 1, name="conv1") | 
|  | 43 | +    net = maxpool(net, name="pool1") | 
|  | 44 | +    net = conv2d(net, 64, 3, 1, name="conv2") | 
|  | 45 | +    net = maxpool(net, name="pool2") | 
|  | 46 | +    net = conv2d(net, 128, 3, 1, name="conv3_1") | 
|  | 47 | +    net = conv2d(net, 64, 1, name="conv3_2") | 
|  | 48 | +    net = conv2d(net, 128, 3, 1, name="conv3_3") | 
|  | 49 | +    net = maxpool(net, name="pool3") | 
|  | 50 | +    net = conv2d(net, 256, 3, 1, name="conv4_1") | 
|  | 51 | +    net = conv2d(net, 128, 1, name="conv4_2") | 
|  | 52 | +    net = conv2d(net, 256, 3, 1, name="conv4_3") | 
|  | 53 | +    net = maxpool(net, name="pool4") | 
|  | 54 | +    net = conv2d(net, 512, 3, 1, name="conv5_1") | 
|  | 55 | +    net = conv2d(net, 256, 1, name="conv5_2") | 
|  | 56 | +    net = conv2d(net, 512, 3, 1, name="conv5_3") | 
|  | 57 | +    net = conv2d(net, 256, 1, name="conv5_4") | 
|  | 58 | +    net = conv2d(net, 512, 3, 1, name="conv5_5") | 
|  | 59 | +    shortcut = net | 
|  | 60 | +    net = maxpool(net, name="pool5") | 
|  | 61 | +    net = conv2d(net, 1024, 3, 1, name="conv6_1") | 
|  | 62 | +    net = conv2d(net, 512, 1, name="conv6_2") | 
|  | 63 | +    net = conv2d(net, 1024, 3, 1, name="conv6_3") | 
|  | 64 | +    net = conv2d(net, 512, 1, name="conv6_4") | 
|  | 65 | +    net = conv2d(net, 1024, 3, 1, name="conv6_5") | 
|  | 66 | +    # --------- | 
|  | 67 | +    net = conv2d(net, 1024, 3, 1, name="conv7_1") | 
|  | 68 | +    net = conv2d(net, 1024, 3, 1, name="conv7_2") | 
|  | 69 | +    # shortcut | 
|  | 70 | +    shortcut = conv2d(shortcut, 64, 1, name="conv_shortcut") | 
|  | 71 | +    shortcut = reorg(shortcut, 2) | 
|  | 72 | +    net = tf.concat([shortcut, net], axis=-1) | 
|  | 73 | +    net = conv2d(net, 1024, 3, 1, name="conv8") | 
|  | 74 | +    # detection layer | 
|  | 75 | +    net = conv2d(net, n_last_channels, 1, batch_normalize=0, | 
|  | 76 | +                 activation=None, use_bias=True, name="conv_dec") | 
|  | 77 | +    return net | 
|  | 78 | + | 
|  | 79 | + | 
|  | 80 | + | 
|  | 81 | +if __name__ == "__main__": | 
|  | 82 | +    x = tf.random_normal([1, 416, 416, 3]) | 
|  | 83 | +    model = darknet(x) | 
|  | 84 | + | 
|  | 85 | +    saver = tf.train.Saver() | 
|  | 86 | +    with tf.Session() as sess: | 
|  | 87 | +        saver.restore(sess, "./checkpoint_dir/yolo2_coco.ckpt") | 
|  | 88 | +        print(sess.run(model).shape) | 
|  | 89 | + | 
0 commit comments