| 
 | 1 | +"""  | 
 | 2 | +YOLOv2 implemented by Tensorflow, only for predicting  | 
 | 3 | +"""  | 
 | 4 | +import os  | 
 | 5 | + | 
 | 6 | +import numpy as np  | 
 | 7 | +import tensorflow as tf  | 
 | 8 | + | 
 | 9 | + | 
 | 10 | + | 
 | 11 | +######## basic layers #######  | 
 | 12 | + | 
 | 13 | +def leaky_relu(x):  | 
 | 14 | +    return tf.nn.leaky_relu(x, alpha=0.1, name="leaky_relu")  | 
 | 15 | + | 
 | 16 | +# Conv2d  | 
 | 17 | +def conv2d(x, filters, size, pad=0, stride=1, batch_normalize=1,  | 
 | 18 | +           activation=leaky_relu, use_bias=False, name="conv2d"):  | 
 | 19 | +    if pad > 0:  | 
 | 20 | +        x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])  | 
 | 21 | +    out = tf.layers.conv2d(x, filters, size, strides=stride, padding="VALID",  | 
 | 22 | +                           activation=None, use_bias=use_bias, name=name)  | 
 | 23 | +    if batch_normalize == 1:  | 
 | 24 | +        out = tf.layers.batch_normalization(out, axis=-1, momentum=0.9,  | 
 | 25 | +                                            training=False, name=name+"_bn")  | 
 | 26 | +    if activation:  | 
 | 27 | +        out = activation(out)  | 
 | 28 | +    return out  | 
 | 29 | + | 
 | 30 | +# maxpool2d  | 
 | 31 | +def maxpool(x, size=2, stride=2, name="maxpool"):  | 
 | 32 | +    return tf.layers.max_pooling2d(x, size, stride)  | 
 | 33 | + | 
 | 34 | +# reorg layer  | 
 | 35 | +def reorg(x, stride):  | 
 | 36 | +    return tf.extract_image_patches(x, [1, stride, stride, 1],  | 
 | 37 | +                        [1, stride, stride, 1], [1,1,1,1], padding="VALID")  | 
 | 38 | + | 
 | 39 | + | 
 | 40 | +def darknet(images, n_last_channels=425):  | 
 | 41 | +    """Darknet19 for YOLOv2"""  | 
 | 42 | +    net = conv2d(images, 32, 3, 1, name="conv1")  | 
 | 43 | +    net = maxpool(net, name="pool1")  | 
 | 44 | +    net = conv2d(net, 64, 3, 1, name="conv2")  | 
 | 45 | +    net = maxpool(net, name="pool2")  | 
 | 46 | +    net = conv2d(net, 128, 3, 1, name="conv3_1")  | 
 | 47 | +    net = conv2d(net, 64, 1, name="conv3_2")  | 
 | 48 | +    net = conv2d(net, 128, 3, 1, name="conv3_3")  | 
 | 49 | +    net = maxpool(net, name="pool3")  | 
 | 50 | +    net = conv2d(net, 256, 3, 1, name="conv4_1")  | 
 | 51 | +    net = conv2d(net, 128, 1, name="conv4_2")  | 
 | 52 | +    net = conv2d(net, 256, 3, 1, name="conv4_3")  | 
 | 53 | +    net = maxpool(net, name="pool4")  | 
 | 54 | +    net = conv2d(net, 512, 3, 1, name="conv5_1")  | 
 | 55 | +    net = conv2d(net, 256, 1, name="conv5_2")  | 
 | 56 | +    net = conv2d(net, 512, 3, 1, name="conv5_3")  | 
 | 57 | +    net = conv2d(net, 256, 1, name="conv5_4")  | 
 | 58 | +    net = conv2d(net, 512, 3, 1, name="conv5_5")  | 
 | 59 | +    shortcut = net  | 
 | 60 | +    net = maxpool(net, name="pool5")  | 
 | 61 | +    net = conv2d(net, 1024, 3, 1, name="conv6_1")  | 
 | 62 | +    net = conv2d(net, 512, 1, name="conv6_2")  | 
 | 63 | +    net = conv2d(net, 1024, 3, 1, name="conv6_3")  | 
 | 64 | +    net = conv2d(net, 512, 1, name="conv6_4")  | 
 | 65 | +    net = conv2d(net, 1024, 3, 1, name="conv6_5")  | 
 | 66 | +    # ---------  | 
 | 67 | +    net = conv2d(net, 1024, 3, 1, name="conv7_1")  | 
 | 68 | +    net = conv2d(net, 1024, 3, 1, name="conv7_2")  | 
 | 69 | +    # shortcut  | 
 | 70 | +    shortcut = conv2d(shortcut, 64, 1, name="conv_shortcut")  | 
 | 71 | +    shortcut = reorg(shortcut, 2)  | 
 | 72 | +    net = tf.concat([shortcut, net], axis=-1)  | 
 | 73 | +    net = conv2d(net, 1024, 3, 1, name="conv8")  | 
 | 74 | +    # detection layer  | 
 | 75 | +    net = conv2d(net, n_last_channels, 1, batch_normalize=0,  | 
 | 76 | +                 activation=None, use_bias=True, name="conv_dec")  | 
 | 77 | +    return net  | 
 | 78 | + | 
 | 79 | + | 
 | 80 | + | 
 | 81 | +if __name__ == "__main__":  | 
 | 82 | +    x = tf.random_normal([1, 416, 416, 3])  | 
 | 83 | +    model = darknet(x)  | 
 | 84 | + | 
 | 85 | +    saver = tf.train.Saver()  | 
 | 86 | +    with tf.Session() as sess:  | 
 | 87 | +        saver.restore(sess, "./checkpoint_dir/yolo2_coco.ckpt")  | 
 | 88 | +        print(sess.run(model).shape)  | 
 | 89 | + | 
0 commit comments