| 
 | 1 | +"""  | 
 | 2 | +2018-11-24  | 
 | 3 | +"""  | 
 | 4 | + | 
 | 5 | +from collections import namedtuple  | 
 | 6 | +import copy  | 
 | 7 | + | 
 | 8 | +import tensorflow as tf  | 
 | 9 | + | 
 | 10 | +slim = tf.contrib.slim  | 
 | 11 | + | 
 | 12 | +def _make_divisible(v, divisor, min_value=None):  | 
 | 13 | +    """make `v` is divided exactly by `divisor`, but keep the min_value"""  | 
 | 14 | +    if min_value is None:  | 
 | 15 | +        min_value = divisor  | 
 | 16 | +    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)  | 
 | 17 | +    # Make sure that round down does not go down by more than 10%.  | 
 | 18 | +    if new_v < 0.9 * v:  | 
 | 19 | +        new_v += divisor  | 
 | 20 | +    return new_v  | 
 | 21 | + | 
 | 22 | + | 
 | 23 | +@slim.add_arg_scope  | 
 | 24 | +def _depth_multiplier_func(params,  | 
 | 25 | +                           multiplier,  | 
 | 26 | +                           divisible_by=8,  | 
 | 27 | +                           min_depth=8):  | 
 | 28 | +    """get the new channles"""  | 
 | 29 | +    if 'num_outputs' not in params:  | 
 | 30 | +        return  | 
 | 31 | +    d = params['num_outputs']  | 
 | 32 | +    params['num_outputs'] = _make_divisible(d * multiplier, divisible_by,  | 
 | 33 | +                                                   min_depth)  | 
 | 34 | + | 
 | 35 | +def _fixed_padding(inputs, kernel_size, rate=1):  | 
 | 36 | +    """Pads the input along the spatial dimensions independently of input size.  | 
 | 37 | +      Pads the input such that if it was used in a convolution with 'VALID' padding,  | 
 | 38 | +      the output would have the same dimensions as if the unpadded input was used  | 
 | 39 | +      in a convolution with 'SAME' padding.  | 
 | 40 | +      Args:  | 
 | 41 | +        inputs: A tensor of size [batch, height_in, width_in, channels].  | 
 | 42 | +        kernel_size: The kernel to be used in the conv2d or max_pool2d operation.  | 
 | 43 | +        rate: An integer, rate for atrous convolution.  | 
 | 44 | +      Returns:  | 
 | 45 | +        output: A tensor of size [batch, height_out, width_out, channels] with the  | 
 | 46 | +        input, either intact (if kernel_size == 1) or padded (if kernel_size > 1).  | 
 | 47 | +    """  | 
 | 48 | +    kernel_size_effective = [kernel_size[0] + (kernel_size[0] - 1) * (rate - 1),  | 
 | 49 | +                               kernel_size[0] + (kernel_size[0] - 1) * (rate - 1)]  | 
 | 50 | +    pad_total = [kernel_size_effective[0] - 1, kernel_size_effective[1] - 1]  | 
 | 51 | +    pad_beg = [pad_total[0] // 2, pad_total[1] // 2]  | 
 | 52 | +    pad_end = [pad_total[0] - pad_beg[0], pad_total[1] - pad_beg[1]]  | 
 | 53 | +    padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg[0], pad_end[0]],  | 
 | 54 | +                                      [pad_beg[1], pad_end[1]], [0, 0]])  | 
 | 55 | +    return padded_inputs  | 
 | 56 | + | 
 | 57 | + | 
 | 58 | +@slim.add_arg_scope  | 
 | 59 | +def expanded_conv(x,  | 
 | 60 | +                  num_outputs,  | 
 | 61 | +                  expansion=6,  | 
 | 62 | +                  stride=1,  | 
 | 63 | +                  rate=1,  | 
 | 64 | +                  normalizer_fn=slim.batch_norm,  | 
 | 65 | +                  project_activation_fn=tf.identity,  | 
 | 66 | +                  padding="SAME",  | 
 | 67 | +                  scope=None):  | 
 | 68 | +    """The expand conv op in MobileNetv2  | 
 | 69 | +        1x1 conv -> depthwise 3x3 conv -> 1x1 linear conv  | 
 | 70 | +    """  | 
 | 71 | +    with tf.variable_scope(scope, default_name="expanded_conv") as s, \  | 
 | 72 | +       tf.name_scope(s.original_name_scope):  | 
 | 73 | +        prev_depth = x.get_shape().as_list()[3]  | 
 | 74 | +        # the filters of expanded conv  | 
 | 75 | +        inner_size = prev_depth * expansion  | 
 | 76 | +        net = x  | 
 | 77 | +        # only inner_size > prev_depth, use expanded conv  | 
 | 78 | +        if inner_size > prev_depth:  | 
 | 79 | +            net = slim.conv2d(net, inner_size, 1, normalizer_fn=normalizer_fn,  | 
 | 80 | +                              scope="expand")  | 
 | 81 | +        # depthwise conv  | 
 | 82 | +        net = slim.separable_conv2d(net, num_outputs=None, kernel_size=3,  | 
 | 83 | +                                    depth_multiplier=1, stride=stride,  | 
 | 84 | +                                    rate=rate, normalizer_fn=normalizer_fn,  | 
 | 85 | +                                    padding=padding, scope="depthwise")  | 
 | 86 | +        # projection  | 
 | 87 | +        net = slim.conv2d(net, num_outputs, 1, normalizer_fn=normalizer_fn,  | 
 | 88 | +                          activation_fn=project_activation_fn, scope="project")  | 
 | 89 | + | 
 | 90 | +        # residual connection  | 
 | 91 | +        if stride == 1 and net.get_shape().as_list()[-1] == prev_depth:  | 
 | 92 | +            net += x  | 
 | 93 | + | 
 | 94 | +        return net  | 
 | 95 | + | 
 | 96 | +def global_pool(x, pool_op=tf.nn.avg_pool):  | 
 | 97 | +    """Applies avg pool to produce 1x1 output.  | 
 | 98 | +    NOTE: This function is funcitonally equivalenet to reduce_mean, but it has  | 
 | 99 | +        baked in average pool which has better support across hardware.  | 
 | 100 | +    Args:  | 
 | 101 | +        input_tensor: input tensor  | 
 | 102 | +        pool_op: pooling op (avg pool is default)  | 
 | 103 | +    Returns:  | 
 | 104 | +        a tensor batch_size x 1 x 1 x depth.  | 
 | 105 | +    """  | 
 | 106 | +    shape = x.get_shape().as_list()  | 
 | 107 | +    if shape[1] is None or shape[2] is None:  | 
 | 108 | +        kernel_size = tf.convert_to_tensor(  | 
 | 109 | +            [1, tf.shape(x)[1], tf.shape(x)[2], 1])  | 
 | 110 | +    else:  | 
 | 111 | +        kernel_size = [1, shape[1], shape[2], 1]  | 
 | 112 | +    output = pool_op(x, ksize=kernel_size, strides=[1, 1, 1, 1], padding='VALID')  | 
 | 113 | +    # Recover output shape, for unknown shape.  | 
 | 114 | +    output.set_shape([None, 1, 1, None])  | 
 | 115 | +    return output  | 
 | 116 | + | 
 | 117 | + | 
 | 118 | +_Op = namedtuple("Op", ['op', 'params', 'multiplier_func'])  | 
 | 119 | + | 
 | 120 | +def op(op_func, **params):  | 
 | 121 | +    return _Op(op=op_func, params=params,  | 
 | 122 | +               multiplier_func=_depth_multiplier_func)  | 
 | 123 | + | 
 | 124 | + | 
 | 125 | +CONV_DEF = [op(slim.conv2d, num_outputs=32, stride=2, kernel_size=3),  | 
 | 126 | +            op(expanded_conv, num_outputs=16, expansion=1),  | 
 | 127 | +            op(expanded_conv, num_outputs=24, stride=2),  | 
 | 128 | +            op(expanded_conv, num_outputs=24, stride=1),  | 
 | 129 | +            op(expanded_conv, num_outputs=32, stride=2),  | 
 | 130 | +            op(expanded_conv, num_outputs=32, stride=1),  | 
 | 131 | +            op(expanded_conv, num_outputs=32, stride=1),  | 
 | 132 | +            op(expanded_conv, num_outputs=64, stride=2),  | 
 | 133 | +            op(expanded_conv, num_outputs=64, stride=1),  | 
 | 134 | +            op(expanded_conv, num_outputs=64, stride=1),  | 
 | 135 | +            op(expanded_conv, num_outputs=64, stride=1),  | 
 | 136 | +            op(expanded_conv, num_outputs=96, stride=1),  | 
 | 137 | +            op(expanded_conv, num_outputs=96, stride=1),  | 
 | 138 | +            op(expanded_conv, num_outputs=96, stride=1),  | 
 | 139 | +            op(expanded_conv, num_outputs=160, stride=2),  | 
 | 140 | +            op(expanded_conv, num_outputs=160, stride=1),  | 
 | 141 | +            op(expanded_conv, num_outputs=160, stride=1),  | 
 | 142 | +            op(expanded_conv, num_outputs=320, stride=1),  | 
 | 143 | +            op(slim.conv2d, num_outputs=1280, stride=1, kernel_size=1),  | 
 | 144 | +            ]  | 
 | 145 | + | 
 | 146 | + | 
 | 147 | +def mobilenet_arg_scope(is_training=True,  | 
 | 148 | +                        weight_decay=0.00004,  | 
 | 149 | +                        stddev=0.09,  | 
 | 150 | +                        dropout_keep_prob=0.8,  | 
 | 151 | +                        bn_decay=0.997):  | 
 | 152 | +    """Defines Mobilenet default arg scope.  | 
 | 153 | +    Usage:  | 
 | 154 | +     with tf.contrib.slim.arg_scope(mobilenet.training_scope()):  | 
 | 155 | +       logits, endpoints = mobilenet_v2.mobilenet(input_tensor)  | 
 | 156 | +     # the network created will be trainble with dropout/batch norm  | 
 | 157 | +     # initialized appropriately.  | 
 | 158 | +    Args:  | 
 | 159 | +        is_training: if set to False this will ensure that all customizations are  | 
 | 160 | +            set to non-training mode. This might be helpful for code that is reused  | 
 | 161 | +        across both training/evaluation, but most of the time training_scope with  | 
 | 162 | +        value False is not needed. If this is set to None, the parameters is not  | 
 | 163 | +        added to the batch_norm arg_scope.  | 
 | 164 | +        weight_decay: The weight decay to use for regularizing the model.  | 
 | 165 | +        stddev: Standard deviation for initialization, if negative uses xavier.  | 
 | 166 | +        dropout_keep_prob: dropout keep probability (not set if equals to None).  | 
 | 167 | +        bn_decay: decay for the batch norm moving averages (not set if equals to  | 
 | 168 | +            None).  | 
 | 169 | +    Returns:  | 
 | 170 | +        An argument scope to use via arg_scope.  | 
 | 171 | +    """  | 
 | 172 | +    # Note: do not introduce parameters that would change the inference  | 
 | 173 | +    # model here (for example whether to use bias), modify conv_def instead.  | 
 | 174 | +    batch_norm_params = {  | 
 | 175 | +        'center': True,  | 
 | 176 | +        'scale': True,  | 
 | 177 | +        'decay': bn_decay,  | 
 | 178 | +        'is_training': is_training  | 
 | 179 | +    }  | 
 | 180 | +    if stddev < 0:  | 
 | 181 | +        weight_intitializer = slim.initializers.xavier_initializer()  | 
 | 182 | +    else:  | 
 | 183 | +        weight_intitializer = tf.truncated_normal_initializer(stddev=stddev)  | 
 | 184 | + | 
 | 185 | +    # Set weight_decay for weights in Conv and FC layers.  | 
 | 186 | +    with slim.arg_scope(  | 
 | 187 | +        [slim.conv2d, slim.fully_connected, slim.separable_conv2d],  | 
 | 188 | +        weights_initializer=weight_intitializer,  | 
 | 189 | +        normalizer_fn=slim.batch_norm,  | 
 | 190 | +        activation_fn=tf.nn.relu6), \  | 
 | 191 | +        slim.arg_scope([slim.batch_norm], **batch_norm_params), \  | 
 | 192 | +        slim.arg_scope([slim.dropout], is_training=is_training,  | 
 | 193 | +                     keep_prob=dropout_keep_prob), \  | 
 | 194 | +        slim.arg_scope([slim.conv2d, slim.separable_conv2d],  | 
 | 195 | +                       biases_initializer=None,  | 
 | 196 | +                       padding="SAME"), \  | 
 | 197 | +        slim.arg_scope([slim.conv2d],  | 
 | 198 | +                     weights_regularizer=slim.l2_regularizer(weight_decay)), \  | 
 | 199 | +        slim.arg_scope([slim.separable_conv2d], weights_regularizer=None) as s:  | 
 | 200 | +        return s  | 
 | 201 | + | 
 | 202 | + | 
 | 203 | +def mobilenetv2(x,  | 
 | 204 | +                num_classes=1001,  | 
 | 205 | +                depth_multiplier=1.0,  | 
 | 206 | +                scope='MobilenetV2',  | 
 | 207 | +                finegrain_classification_mode=False,  | 
 | 208 | +                min_depth=8,  | 
 | 209 | +                divisible_by=8,  | 
 | 210 | +                output_stride=None,  | 
 | 211 | +                ):  | 
 | 212 | +    """Mobilenet v2  | 
 | 213 | +    Args:  | 
 | 214 | +        x: The input tensor  | 
 | 215 | +        num_classes: number of classes  | 
 | 216 | +        depth_multiplier: The multiplier applied to scale number of  | 
 | 217 | +            channels in each layer. Note: this is called depth multiplier in the  | 
 | 218 | +            paper but the name is kept for consistency with slim's model builder.  | 
 | 219 | +        scope: Scope of the operator  | 
 | 220 | +        finegrain_classification_mode: When set to True, the model  | 
 | 221 | +            will keep the last layer large even for small multipliers.  | 
 | 222 | +            The paper suggests that it improves performance for ImageNet-type of problems.  | 
 | 223 | +        min_depth: If provided, will ensure that all layers will have that  | 
 | 224 | +          many channels after application of depth multiplier.  | 
 | 225 | +       divisible_by: If provided will ensure that all layers # channels  | 
 | 226 | +          will be divisible by this number.  | 
 | 227 | +    """  | 
 | 228 | +    conv_defs = CONV_DEF  | 
 | 229 | + | 
 | 230 | +    # keep the last conv layer very larger channel  | 
 | 231 | +    if finegrain_classification_mode:  | 
 | 232 | +        conv_defs = copy.deepcopy(conv_defs)  | 
 | 233 | +        if depth_multiplier < 1:  | 
 | 234 | +            conv_defs[-1].params['num_outputs'] /= depth_multiplier  | 
 | 235 | + | 
 | 236 | +    depth_args = {}  | 
 | 237 | +    # NB: do not set depth_args unless they are provided to avoid overriding  | 
 | 238 | +    # whatever default depth_multiplier might have thanks to arg_scope.  | 
 | 239 | +    if min_depth is not None:  | 
 | 240 | +        depth_args['min_depth'] = min_depth  | 
 | 241 | +    if divisible_by is not None:  | 
 | 242 | +        depth_args['divisible_by'] = divisible_by  | 
 | 243 | + | 
 | 244 | +    with slim.arg_scope([_depth_multiplier_func], **depth_args):  | 
 | 245 | +        with tf.variable_scope(scope, default_name='Mobilenet'):  | 
 | 246 | +            # The current_stride variable keeps track of the output stride of the  | 
 | 247 | +            # activations, i.e., the running product of convolution strides up to the  | 
 | 248 | +            # current network layer. This allows us to invoke atrous convolution  | 
 | 249 | +            # whenever applying the next convolution would result in the activations  | 
 | 250 | +            # having output stride larger than the target output_stride.  | 
 | 251 | +            current_stride = 1  | 
 | 252 | + | 
 | 253 | +            # The atrous convolution rate parameter.  | 
 | 254 | +            rate = 1  | 
 | 255 | + | 
 | 256 | +            net = x  | 
 | 257 | +            # Insert default parameters before the base scope which includes  | 
 | 258 | +            # any custom overrides set in mobilenet.  | 
 | 259 | +            end_points = {}  | 
 | 260 | +            scopes = {}  | 
 | 261 | +            for i, opdef in enumerate(conv_defs):  | 
 | 262 | +                params = dict(opdef.params)  | 
 | 263 | +                opdef.multiplier_func(params, depth_multiplier)  | 
 | 264 | +                stride = params.get('stride', 1)  | 
 | 265 | +                if output_stride is not None and current_stride == output_stride:  | 
 | 266 | +                    # If we have reached the target output_stride, then we need to employ  | 
 | 267 | +                    # atrous convolution with stride=1 and multiply the atrous rate by the  | 
 | 268 | +                    # current unit's stride for use in subsequent layers.  | 
 | 269 | +                    layer_stride = 1  | 
 | 270 | +                    layer_rate = rate  | 
 | 271 | +                    rate *= stride  | 
 | 272 | +                else:  | 
 | 273 | +                    layer_stride = stride  | 
 | 274 | +                    layer_rate = 1  | 
 | 275 | +                    current_stride *= stride  | 
 | 276 | +                # Update params.  | 
 | 277 | +                params['stride'] = layer_stride  | 
 | 278 | +                # Only insert rate to params if rate > 1.  | 
 | 279 | +                if layer_rate > 1:  | 
 | 280 | +                    params['rate'] = layer_rate  | 
 | 281 | + | 
 | 282 | +                try:  | 
 | 283 | +                    net = opdef.op(net, **params)  | 
 | 284 | +                except Exception:  | 
 | 285 | +                    raise ValueError('Failed to create op %i: %r params: %r' % (i, opdef, params))  | 
 | 286 | + | 
 | 287 | +            with tf.variable_scope('Logits'):  | 
 | 288 | +                net = global_pool(net)  | 
 | 289 | +                end_points['global_pool'] = net  | 
 | 290 | +                if not num_classes:  | 
 | 291 | +                    return net, end_points  | 
 | 292 | +                net = slim.dropout(net, scope='Dropout')  | 
 | 293 | +                # 1 x 1 x num_classes  | 
 | 294 | +                # Note: legacy scope name.  | 
 | 295 | +                logits = slim.conv2d(  | 
 | 296 | +                    net,  | 
 | 297 | +                    num_classes, [1, 1],  | 
 | 298 | +                    activation_fn=None,  | 
 | 299 | +                    normalizer_fn=None,  | 
 | 300 | +                    biases_initializer=tf.zeros_initializer(),  | 
 | 301 | +                    scope='Conv2d_1c_1x1')  | 
 | 302 | + | 
 | 303 | +                logits = tf.squeeze(logits, [1, 2])  | 
 | 304 | + | 
 | 305 | +                return logits  | 
 | 306 | + | 
 | 307 | + | 
 | 308 | +if __name__ == "__main__":  | 
 | 309 | +    import cv2  | 
 | 310 | +    import numpy as np  | 
 | 311 | + | 
 | 312 | +    inputs = tf.placeholder(tf.uint8, [None, None, 3])  | 
 | 313 | +    images = tf.expand_dims(inputs, 0)  | 
 | 314 | +    images = tf.cast(images, tf.float32) / 128. - 1  | 
 | 315 | +    images.set_shape((None, None, None, 3))  | 
 | 316 | +    images = tf.image.resize_images(images, (224, 224))  | 
 | 317 | + | 
 | 318 | +    with slim.arg_scope(mobilenet_arg_scope(is_training=False)):  | 
 | 319 | +        logits = mobilenetv2(images)  | 
 | 320 | + | 
 | 321 | +    # Restore using exponential moving average since it produces (1.5-2%) higher  | 
 | 322 | +    # accuracy  | 
 | 323 | +    ema = tf.train.ExponentialMovingAverage(0.999)  | 
 | 324 | +    vars = ema.variables_to_restore()  | 
 | 325 | + | 
 | 326 | +    saver = tf.train.Saver(vars)  | 
 | 327 | + | 
 | 328 | +    print(len(tf.global_variables()))  | 
 | 329 | +    for var in tf.global_variables():  | 
 | 330 | +        print(var)  | 
 | 331 | +    checkpoint_path = r"C:\Users\xiaoh\Desktop\temp\mobilenet_v2_1.0_224\mobilenet_v2_1.0_224.ckpt"  | 
 | 332 | +    image_file = "C:/Users/xiaoh/Desktop/temp/pandas.jpg"  | 
 | 333 | +    with tf.Session() as sess:  | 
 | 334 | +        saver.restore(sess, checkpoint_path)  | 
 | 335 | + | 
 | 336 | +        img = cv2.imread(image_file)  | 
 | 337 | +        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  | 
 | 338 | + | 
 | 339 | +        print(np.argmax(sess.run(logits, feed_dict={inputs: img})[0]))  | 
 | 340 | + | 
 | 341 | + | 
 | 342 | + | 
 | 343 | + | 
 | 344 | + | 
 | 345 | + | 
 | 346 | + | 
 | 347 | + | 
 | 348 | + | 
 | 349 | + | 
0 commit comments