Skip to content

Commit 470220b

Browse files
committed
Fix MobileNetV3 crash with global_pool='', output consistent with other models but not equivalent due to efficient head.
1 parent fc8b8af commit 470220b

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

tests/test_models.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,11 @@ def test_model_default_cfgs(model_name, batch_size):
9999
assert outputs.shape[-1] == model.num_features
100100

101101
# test model forward without pooling and classifier
102+
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
103+
outputs = model.forward(input_tensor)
104+
assert len(outputs.shape) == 4
102105
if not isinstance(model, timm.models.MobileNetV3):
103-
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
104-
outputs = model.forward(input_tensor)
105-
assert len(outputs.shape) == 4
106+
# FIXME mobilenetv3 forward_features vs removed pooling differ
106107
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
107108

108109
# check classifier and first convolution names match those in default_cfg

timm/models/mobilenetv3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_f
101101
head_chs = builder.in_chs
102102

103103
# Head + Pooling
104-
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if global_pool else nn.Identity()
104+
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
105105
num_pooled_chs = head_chs * self.global_pool.feat_mult()
106106
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
107107
self.act2 = act_layer(inplace=True)
@@ -122,7 +122,7 @@ def get_classifier(self):
122122
def reset_classifier(self, num_classes, global_pool='avg'):
123123
self.num_classes = num_classes
124124
# cannot meaningfully change pooling of efficient head after creation
125-
assert global_pool == self.global_pool.pool_type
125+
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
126126
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
127127

128128
def forward_features(self, x):
@@ -136,7 +136,9 @@ def forward_features(self, x):
136136
return x
137137

138138
def forward(self, x):
139-
x = self.forward_features(x).flatten(1)
139+
x = self.forward_features(x)
140+
if not self.global_pool.is_identity():
141+
x = x.flatten(1)
140142
if self.drop_rate > 0.:
141143
x = F.dropout(x, p=self.drop_rate, training=self.training)
142144
return self.classifier(x)

0 commit comments

Comments
 (0)