
Commit b1f1a54

More uniform treatment of classifiers across all models, reduce code duplication.
1 parent 9806f3e

22 files changed: +173 −207 lines changed
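The refactor replaces each model's hand-rolled SelectAdaptivePool2d + nn.Linear/nn.Conv2d head construction with a shared create_classifier helper that returns a (global_pool, classifier) pair. The helper itself is added elsewhere in this commit (under timm/models/layers) and is not shown in the diffs below, so the following is only a rough sketch of the interface its call sites imply, with pool handling simplified to 'avg' / '' (pass-through):

import torch
import torch.nn as nn


def create_classifier_sketch(num_features, num_classes, pool_type='avg', use_conv=False):
    # pooled output is flattened for Linear heads, left 4D (N, C, 1, 1) for 1x1-conv heads
    if pool_type:
        pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Identity() if use_conv else nn.Flatten(1))
    else:
        pool = nn.Identity()  # pass-through pooling, e.g. after reset_classifier(0, '')
    if num_classes <= 0:
        classifier = nn.Identity()  # classifier removed, features pass straight through
    elif use_conv:
        classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
    else:
        classifier = nn.Linear(num_features, num_classes, bias=True)
    return pool, classifier


pool, head = create_classifier_sketch(2048, 1000)
print(head(pool(torch.randn(2, 2048, 7, 7))).shape)  # torch.Size([2, 1000])

The real SelectAdaptivePool2d-based helper additionally supports max/avgmax/catavgmax pooling and exposes is_identity(), which the DLA and DPN forward passes below rely on.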

tests/test_models.py

Lines changed: 26 additions & 2 deletions
@@ -4,8 +4,14 @@
 import os
 import fnmatch

+import timm
 from timm import list_models, create_model, set_scriptable

+if hasattr(torch._C, '_jit_set_profiling_executor'):
+    # legacy executor is too slow to compile large models for unit tests
+    # no need for the fusion performance here
+    torch._C._jit_set_profiling_executor(True)
+    torch._C._jit_set_profiling_mode(False)

 if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
@@ -78,10 +84,28 @@ def test_model_default_cfgs(model_name, batch_size):

     if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
             not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
-        # pool size only checked if default res <= 448 * 448 to keep resource down
+        # output sizes only checked if default res <= 448 * 448 to keep resource down
         input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
-        outputs = model.forward_features(torch.randn((batch_size, *input_size)))
+        input_tensor = torch.randn((batch_size, *input_size))
+
+        # test forward_features (always unpooled)
+        outputs = model.forward_features(input_tensor)
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+
+        # test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
+        model.reset_classifier(0)
+        outputs = model.forward(input_tensor)
+        assert len(outputs.shape) == 2
+        assert outputs.shape[-1] == model.num_features
+
+        # test model forward without pooling and classifier
+        if not isinstance(model, timm.models.MobileNetV3):
+            model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
+            outputs = model.forward(input_tensor)
+            assert len(outputs.shape) == 4
+            assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+
+    # check classifier and first convolution names match those in default_cfg
     assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
     assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'
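For reference, the three output modes exercised by the new assertions look like this in use (a minimal sketch using resnet18 as an arbitrary example model; shapes assume its default 224x224 config):

import torch
import timm

model = timm.create_model('resnet18', pretrained=False)
x = torch.randn(1, 3, 224, 224)

print(model(x).shape)                   # torch.Size([1, 1000]) - pooled features + classifier
print(model.forward_features(x).shape)  # torch.Size([1, 512, 7, 7]) - unpooled feature map

model.reset_classifier(0)               # drop the classifier, keep global pooling
print(model(x).shape)                   # torch.Size([1, 512]) - size(-1) == model.num_features

model.reset_classifier(0, '')           # drop the classifier and set pooling to pass-through
print(model(x).shape)                   # torch.Size([1, 512, 7, 7]) - raw feature map again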

timm/models/densenet.py

Lines changed: 6 additions & 10 deletions
@@ -14,7 +14,7 @@

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
+from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier
 from .registry import register_model

 __all__ = ['DenseNet']
@@ -236,8 +236,8 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem
         self.num_features = num_features

         # Linear layer
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)

         # Official init from torch repo.
         for m in self.modules():
@@ -254,19 +254,15 @@ def get_classifier(self):

     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.classifier = nn.Linear(num_features, num_classes)
-        else:
-            self.classifier = nn.Identity()
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)

     def forward_features(self, x):
         return self.features(x)

     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         # both classifier and block drop?
         # if self.drop_rate > 0.:
         #     x = F.dropout(x, p=self.drop_rate, training=self.training)

timm/models/dla.py

Lines changed: 8 additions & 11 deletions
@@ -13,7 +13,7 @@

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model

 __all__ = ['DLA']
@@ -286,9 +286,8 @@ def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chan
         ]

         self.num_features = channels[-1]
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True)
-
+        self.global_pool, self.fc = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -313,12 +312,8 @@ def get_classifier(self):

     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.fc = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
-        else:
-            self.fc = nn.Identity()
+        self.global_pool, self.fc = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)

     def forward_features(self, x):
         x = self.base_layer(x)
@@ -336,7 +331,9 @@ def forward(self, x):
         if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
-        return x.flatten(1)
+        if not self.global_pool.is_identity():
+            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
+        return x


 def _create_dla(variant, pretrained=False, **kwargs):
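The conditional flatten above (mirrored in dpn.py below) exists because these models use a 1x1-conv classifier: with global pooling active the logits come out as (N, num_classes, 1, 1) and should be flattened, while with pooling disabled the 4D logit map should be returned untouched. A standalone shape illustration (plain torch, not timm code):

import torch
import torch.nn as nn

fc = nn.Conv2d(1024, 1000, kernel_size=1, bias=True)  # stand-in for the DLA/DPN conv classifier
feat = torch.randn(2, 1024, 7, 7)

pooled = nn.AdaptiveAvgPool2d(1)(feat)  # (2, 1024, 1, 1) when global pooling is enabled
print(fc(pooled).flatten(1).shape)      # torch.Size([2, 1000]) - flatten after the conv head

print(fc(feat).shape)                   # torch.Size([2, 1000, 7, 7]) - pooling disabled, no flatten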

timm/models/dpn.py

Lines changed: 9 additions & 12 deletions
@@ -19,7 +19,7 @@

 from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_conv2d, ConvBnAct
+from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier
 from .registry import register_model

 __all__ = ['DPN']
@@ -237,21 +237,16 @@ def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplac
         self.features = nn.Sequential(blocks)

         # Using 1x1 conv for the FC layer to allow the extra pooling scheme
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        num_features = self.num_features * self.global_pool.feat_mult()
-        self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)

     def get_classifier(self):
         return self.classifier

     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
-        else:
-            self.classifier = nn.Identity()
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)

     def forward_features(self, x):
         return self.features(x)
@@ -261,8 +256,10 @@ def forward(self, x):
         x = self.global_pool(x)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
-        out = self.classifier(x)
-        return out.flatten(1)
+        x = self.classifier(x)
+        if not self.global_pool.is_identity():
+            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
+        return x


 def _create_dpn(variant, pretrained=False, **kwargs):

timm/models/efficientnet.py

Lines changed: 12 additions & 24 deletions
@@ -35,7 +35,7 @@
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, create_conv2d
+from .layers import create_conv2d, create_classifier
 from .registry import register_model

 __all__ = ['EfficientNet']
@@ -336,53 +336,45 @@ def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3,
         self.num_classes = num_classes
         self.num_features = num_features
         self.drop_rate = drop_rate
-        self._in_chs = in_chans

         # Stem
         if not fix_stem:
             stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
-        self._in_chs = stem_size

         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
-        self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
+        self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
-        self._in_chs = builder.in_chs
+        head_chs = builder.in_chs

         # Head + Pooling
-        self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
+        self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
         self.bn2 = norm_layer(self.num_features, **norm_kwargs)
         self.act2 = act_layer(inplace=True)
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-
-        # Classifier
-        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)

         efficientnet_init_weights(self)

     def as_sequential(self):
         layers = [self.conv_stem, self.bn1, self.act1]
         layers.extend(self.blocks)
         layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool])
-        layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
+        layers.extend([nn.Dropout(self.drop_rate), self.classifier])
         return nn.Sequential(*layers)

     def get_classifier(self):
         return self.classifier

     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.classifier = nn.Linear(num_features, num_classes)
-        else:
-            self.classifier = nn.Identity()
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)

     def forward_features(self, x):
         x = self.conv_stem(x)
@@ -397,7 +389,6 @@ def forward_features(self, x):
     def forward(self, x):
         x = self.forward_features(x)
         x = self.global_pool(x)
-        x = x.flatten(1)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         return self.classifier(x)
@@ -417,24 +408,21 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bo
         super(EfficientNetFeatures, self).__init__()
         norm_kwargs = norm_kwargs or {}
         self.drop_rate = drop_rate
-        self._in_chs = in_chans

         # Stem
         if not fix_stem:
             stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
-        self._in_chs = stem_size

         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
-        self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
+        self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
-        self._in_chs = builder.in_chs

         efficientnet_init_weights(self)

timm/models/gluon_xception.py

Lines changed: 4 additions & 6 deletions
@@ -13,7 +13,7 @@

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, get_padding
+from .layers import create_classifier, get_padding
 from .registry import register_model

 __all__ = ['Xception65']
@@ -192,16 +192,14 @@ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn
             dict(num_chs=2048, reduction=32, module='act5'),
         ]

-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

     def get_classifier(self):
         return self.fc

     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

     def forward_features(self, x):
         # Entry flow
@@ -242,7 +240,7 @@ def forward_features(self, x):

     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate:
             F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)

timm/models/helpers.py

Lines changed: 5 additions & 2 deletions
@@ -187,10 +187,13 @@ def adapt_model_from_string(parent_module, model_string):
                 affine=old_module.affine, track_running_stats=True)
             set_layer(new_module, n, new_bn)
         if isinstance(old_module, nn.Linear):
+            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
+            num_features = state_dict[n + '.weight'][1]
             new_fc = nn.Linear(
-                in_features=state_dict[n + '.weight'][1], out_features=old_module.out_features,
-                bias=old_module.bias is not None)
+                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
             set_layer(new_module, n, new_fc)
+            if hasattr(new_module, 'num_features'):
+                new_module.num_features = num_features
     new_module.eval()
     parent_module.eval()
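The num_features sync matters because reset_classifier() on the models above rebuilds the head from self.num_features, so a model whose final Linear was narrowed by adapt_model_from_string would otherwise get a mismatched head on reset. A toy illustration of the failure mode this avoids (not timm code; the 1536 width is made up):

import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_features = 2048
        self.classifier = nn.Linear(self.num_features, 1000)

    def reset_classifier(self, num_classes):
        # mirrors the pattern used by the timm models in this commit
        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()

m = ToyModel()
m.classifier = nn.Linear(1536, 1000)  # pretend adaptation pruned the head input to 1536 features
m.num_features = 1536                 # the added sync keeps the attribute matching the real width
m.reset_classifier(10)                # -> Linear(1536, 10), matching the pruned feature width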

timm/models/hrnet.py

Lines changed: 6 additions & 10 deletions
@@ -18,7 +18,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .features import FeatureInfo
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model
 from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE

@@ -553,8 +553,8 @@ def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_ra
             # Classification Head
             self.num_features = 2048
             self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels)
-            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+            self.global_pool, self.classifier = create_classifier(
+                self.num_features, self.num_classes, pool_type=global_pool)
         elif head == 'incre':
             self.num_features = 2048
             self.incre_modules, _, _ = self._make_head(pre_stage_channels, True)
@@ -685,12 +685,8 @@ def get_classifier(self):

     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        num_features = self.num_features * self.global_pool.feat_mult()
-        if num_classes:
-            self.classifier = nn.Linear(num_features, num_classes)
-        else:
-            self.classifier = nn.Identity()
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)

     def stages(self, x) -> List[torch.Tensor]:
         x = self.layer1(x)
@@ -726,7 +722,7 @@ def forward_features(self, x):

     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.classifier(x)
