@@ -101,7 +101,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_f
101
101
head_chs = builder .in_chs
102
102
103
103
# Head + Pooling
104
- self .global_pool = SelectAdaptivePool2d (pool_type = global_pool ) if global_pool else nn . Identity ()
104
+ self .global_pool = SelectAdaptivePool2d (pool_type = global_pool )
105
105
num_pooled_chs = head_chs * self .global_pool .feat_mult ()
106
106
self .conv_head = create_conv2d (num_pooled_chs , self .num_features , 1 , padding = pad_type , bias = head_bias )
107
107
self .act2 = act_layer (inplace = True )
@@ -122,7 +122,7 @@ def get_classifier(self):
122
122
def reset_classifier (self , num_classes , global_pool = 'avg' ):
123
123
self .num_classes = num_classes
124
124
# cannot meaningfully change pooling of efficient head after creation
125
- assert global_pool == self . global_pool . pool_type
125
+ self . global_pool = SelectAdaptivePool2d ( pool_type = global_pool )
126
126
self .classifier = nn .Linear (self .num_features , num_classes ) if num_classes > 0 else nn .Identity ()
127
127
128
128
def forward_features (self , x ):
@@ -136,7 +136,9 @@ def forward_features(self, x):
136
136
return x
137
137
138
138
def forward (self , x ):
139
- x = self .forward_features (x ).flatten (1 )
139
+ x = self .forward_features (x )
140
+ if not self .global_pool .is_identity ():
141
+ x = x .flatten (1 )
140
142
if self .drop_rate > 0. :
141
143
x = F .dropout (x , p = self .drop_rate , training = self .training )
142
144
return self .classifier (x )
0 commit comments