1+ """
2+ ResNet50
3+ 2017/12/06
4+ """
5+
6+ import tensorflow as tf
7+ from tensorflow .python .training import moving_averages
8+
9+ fc_initializer = tf .contrib .layers .xavier_initializer
10+ conv2d_initializer = tf .contrib .layers .xavier_initializer_conv2d
11+
# create a weight variable
def create_var(name, shape, initializer, trainable=True):
    return tf.get_variable(name, shape=shape, dtype=tf.float32,
                           initializer=initializer, trainable=trainable)

# conv2d layer (no bias term: every conv here is followed by batch norm,
# whose beta offset makes a conv bias redundant)
def conv2d(x, num_outputs, kernel_size, stride=1, scope="conv2d"):
    num_inputs = x.get_shape()[-1]
    with tf.variable_scope(scope):
        kernel = create_var("kernel", [kernel_size, kernel_size,
                                       num_inputs, num_outputs],
                            conv2d_initializer())
        return tf.nn.conv2d(x, kernel, strides=[1, stride, stride, 1],
                            padding="SAME")

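# With padding="SAME", the output spatial size is ceil(input_size / stride),
# independent of kernel_size. That is where the shape annotations further
# down come from, e.g. conv1 with stride 2 maps 224x224 inputs to
# ceil(224 / 2) = 112x112.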
# fully connected layer
def fc(x, num_outputs, scope="fc"):
    num_inputs = x.get_shape()[-1]
    with tf.variable_scope(scope):
        weight = create_var("weight", [num_inputs, num_outputs],
                            fc_initializer())
        bias = create_var("bias", [num_outputs,],
                          tf.zeros_initializer())
        # xw_plus_b computes matmul(x, weight) + bias
        return tf.nn.xw_plus_b(x, weight, bias)

# batch norm layer
def batch_norm(x, decay=0.999, epsilon=1e-03, is_training=True,
               scope="bn"):
    x_shape = x.get_shape()
    num_inputs = x_shape[-1]
    # normalize over all axes except the channel axis
    reduce_dims = list(range(len(x_shape) - 1))
    with tf.variable_scope(scope):
        beta = create_var("beta", [num_inputs,],
                          initializer=tf.zeros_initializer())
        gamma = create_var("gamma", [num_inputs,],
                           initializer=tf.ones_initializer())
        # moving statistics, used for inference
        moving_mean = create_var("moving_mean", [num_inputs,],
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
        moving_variance = create_var("moving_variance", [num_inputs,],
                                     initializer=tf.ones_initializer(),
                                     trainable=False)
        if is_training:
            # use batch statistics and queue updates of the moving averages
            mean, variance = tf.nn.moments(x, axes=reduce_dims)
            update_move_mean = moving_averages.assign_moving_average(
                moving_mean, mean, decay=decay)
            update_move_variance = moving_averages.assign_moving_average(
                moving_variance, variance, decay=decay)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_move_mean)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_move_variance)
        else:
            mean, variance = moving_mean, moving_variance
    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)

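# Note: the moving-average updates above are only *added* to the UPDATE_OPS
# collection; a training loop has to run them itself. A minimal sketch,
# assuming some `loss` tensor defined elsewhere (not part of this file):
#
#   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#   with tf.control_dependencies(update_ops):
#       train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)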
# avg pool layer
def avg_pool(x, pool_size, scope):
    with tf.variable_scope(scope):
        return tf.nn.avg_pool(x, [1, pool_size, pool_size, 1],
                              strides=[1, pool_size, pool_size, 1],
                              padding="VALID")

# max pool layer
def max_pool(x, pool_size, stride, scope):
    with tf.variable_scope(scope):
        return tf.nn.max_pool(x, [1, pool_size, pool_size, 1],
                              [1, stride, stride, 1], padding="SAME")

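# avg_pool uses padding="VALID" with stride == pool_size, i.e. non-overlapping
# windows: applied below to a 7x7 feature map with pool_size=7 it reduces to
# 1x1, so it acts as global average pooling.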
class ResNet50(object):
    def __init__(self, inputs, num_classes=1000, is_training=True,
                 scope="resnet50"):
        self.inputs = inputs
        self.is_training = is_training
        self.num_classes = num_classes

        with tf.variable_scope(scope):
            # construct the model
            net = conv2d(inputs, 64, 7, 2, scope="conv1")  # -> [batch, 112, 112, 64]
            net = tf.nn.relu(batch_norm(net, is_training=self.is_training,
                                        scope="bn1"))
            net = max_pool(net, 3, 2, scope="maxpool1")  # -> [batch, 56, 56, 64]
            net = self._block(net, 256, 3, init_stride=1,
                              is_training=self.is_training,
                              scope="block2")  # -> [batch, 56, 56, 256]
            net = self._block(net, 512, 4, is_training=self.is_training,
                              scope="block3")  # -> [batch, 28, 28, 512]
            net = self._block(net, 1024, 6, is_training=self.is_training,
                              scope="block4")  # -> [batch, 14, 14, 1024]
            net = self._block(net, 2048, 3, is_training=self.is_training,
                              scope="block5")  # -> [batch, 7, 7, 2048]
            net = avg_pool(net, 7, scope="avgpool5")  # -> [batch, 1, 1, 2048]
            net = tf.squeeze(net, [1, 2], name="SpatialSqueeze")  # -> [batch, 2048]
            self.logits = fc(net, self.num_classes, "fc6")  # -> [batch, num_classes]
            self.predictions = tf.nn.softmax(self.logits)

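    # Why "50": 1 conv (conv1) + 3 bottlenecks x 3 convs (block2)
    # + 4 x 3 (block3) + 6 x 3 (block4) + 3 x 3 (block5) + 1 fc
    # = 1 + 48 + 1 = 50 weight layers (projection shortcuts not counted).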
    def _block(self, x, n_out, n, init_stride=2, is_training=True,
               scope="block"):
        with tf.variable_scope(scope):
            h_out = n_out // 4
            # the first unit downsamples (init_stride); the rest keep stride 1
            out = self._bottleneck(x, h_out, n_out, stride=init_stride,
                                   is_training=is_training, scope="bottleneck1")
            for i in range(1, n):
                out = self._bottleneck(out, h_out, n_out,
                                       is_training=is_training,
                                       scope=("bottleneck%s" % (i + 1)))
            return out

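    # Channel bookkeeping: each bottleneck is h_out = n_out // 4 wide inside,
    # e.g. block2 (n_out=256) stacks 1x1x64 -> 3x3x64 -> 1x1x256 convolutions.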
    def _bottleneck(self, x, h_out, n_out, stride=None, is_training=True,
                    scope="bottleneck"):
        """A residual bottleneck unit."""
        n_in = x.get_shape()[-1]
        if stride is None:
            # downsample whenever the unit also widens the channels
            stride = 1 if n_in == n_out else 2

        with tf.variable_scope(scope):
            # 1x1 reduce -> 3x3 -> 1x1 expand, each followed by batch norm
            h = conv2d(x, h_out, 1, stride=stride, scope="conv_1")
            h = batch_norm(h, is_training=is_training, scope="bn_1")
            h = tf.nn.relu(h)
            h = conv2d(h, h_out, 3, stride=1, scope="conv_2")
            h = batch_norm(h, is_training=is_training, scope="bn_2")
            h = tf.nn.relu(h)
            h = conv2d(h, n_out, 1, stride=1, scope="conv_3")
            h = batch_norm(h, is_training=is_training, scope="bn_3")

            if n_in != n_out:
                # projection shortcut (1x1 conv) to match the output shape
                shortcut = conv2d(x, n_out, 1, stride=stride, scope="conv_4")
                shortcut = batch_norm(shortcut, is_training=is_training,
                                      scope="bn_4")
            else:
                # identity shortcut
                shortcut = x
            return tf.nn.relu(shortcut + h)

if __name__ == "__main__":
    x = tf.random_normal([32, 224, 224, 3])
    resnet50 = ResNet50(x)
    print(resnet50.logits)
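    # Printing the tensor only shows its static shape. A minimal sketch for
    # actually running the forward pass (random inputs here, so the
    # predictions are meaningless; feed real images for real use):
    #
    #   with tf.Session() as sess:
    #       sess.run(tf.global_variables_initializer())
    #       preds = sess.run(resnet50.predictions)
    #       print(preds.shape)  # (32, 1000)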