Skip to content

Commit f6b5660

Browse files
author
Yusuke Uchida
committed
fix test_model_default_cfgs
1 parent 078a51d commit f6b5660

File tree

9 files changed

+56
-31
lines changed

9 files changed

+56
-31
lines changed

tests/test_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def test_model_backward(model_name, batch_size):
6969

7070
@pytest.mark.timeout(120)
7171
@pytest.mark.parametrize('model_name', list_models())
72+
#@pytest.mark.parametrize('model_name', ["xception41"])
7273
@pytest.mark.parametrize('batch_size', [1])
7374
def test_model_default_cfgs(model_name, batch_size):
7475
"""Run a single forward pass with each model"""
@@ -106,8 +107,8 @@ def test_model_default_cfgs(model_name, batch_size):
106107
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
107108

108109
# check classifier and first convolution names match those in default_cfg
109-
assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
110-
assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'
110+
assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
111+
assert first_conv + ".weight" in state_dict.keys(), f'{first_conv} not in model params'
111112

112113

113114
if 'GITHUB_ACTIONS' not in os.environ:

timm/models/gluon_resnet.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,32 @@ def _cfg(url='', **kwargs):
2828
'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'),
2929
'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'),
3030
'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'),
31-
'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth'),
32-
'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth'),
33-
'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth'),
34-
'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth'),
35-
'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth'),
36-
'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth'),
37-
'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth'),
38-
'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth'),
39-
'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth'),
31+
'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth',
32+
first_conv='conv1.0'),
33+
'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth',
34+
first_conv='conv1.0'),
35+
'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth',
36+
first_conv='conv1.0'),
37+
'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth',
38+
first_conv='conv1.0'),
39+
'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth',
40+
first_conv='conv1.0'),
41+
'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth',
42+
first_conv='conv1.0'),
43+
'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth',
44+
first_conv='conv1.0'),
45+
'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth',
46+
first_conv='conv1.0'),
47+
'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth',
48+
first_conv='conv1.0'),
4049
'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
4150
'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
4251
'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
4352
'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
4453
'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
4554
'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
46-
'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'),
55+
'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
56+
first_conv='conv1.0'),
4757
}
4858

4959

timm/models/inception_v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def _cfg(url='', **kwargs):
1919
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
2020
'crop_pct': 0.875, 'interpolation': 'bicubic',
2121
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
22-
'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc',
22+
'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc',
2323
**kwargs
2424
}
2525

timm/models/resnest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def _cfg(url='', **kwargs):
2222
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
2323
'crop_pct': 0.875, 'interpolation': 'bilinear',
2424
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
25-
'first_conv': 'conv1', 'classifier': 'fc',
25+
'first_conv': 'conv1.0', 'classifier': 'fc',
2626
**kwargs
2727
}
2828

timm/models/resnet.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@ def _cfg(url='', **kwargs):
4242
interpolation='bicubic'),
4343
'resnet26d': _cfg(
4444
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
45-
interpolation='bicubic'),
45+
interpolation='bicubic',
46+
first_conv='conv1.0'),
4647
'resnet50': _cfg(
4748
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
4849
interpolation='bicubic'),
4950
'resnet50d': _cfg(
5051
url='',
51-
interpolation='bicubic'),
52+
interpolation='bicubic',
53+
first_conv='conv1.0'),
5254
'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
5355
'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
5456
'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
@@ -62,7 +64,8 @@ def _cfg(url='', **kwargs):
6264
interpolation='bicubic'),
6365
'resnext50d_32x4d': _cfg(
6466
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth',
65-
interpolation='bicubic'),
67+
interpolation='bicubic',
68+
first_conv='conv1.0'),
6669
'resnext101_32x4d': _cfg(url=''),
6770
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
6871
'resnext101_64x4d': _cfg(url=''),
@@ -118,7 +121,8 @@ def _cfg(url='', **kwargs):
118121
interpolation='bicubic'),
119122
'seresnet50tn': _cfg(
120123
url='',
121-
interpolation='bicubic'),
124+
interpolation='bicubic',
125+
first_conv='conv1.0'),
122126
'seresnet101': _cfg(
123127
url='',
124128
interpolation='bicubic'),
@@ -132,13 +136,16 @@ def _cfg(url='', **kwargs):
132136
interpolation='bicubic'),
133137
'seresnext26d_32x4d': _cfg(
134138
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
135-
interpolation='bicubic'),
139+
interpolation='bicubic',
140+
first_conv='conv1.0'),
136141
'seresnext26t_32x4d': _cfg(
137142
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26t_32x4d-361bc1c4.pth',
138-
interpolation='bicubic'),
143+
interpolation='bicubic',
144+
first_conv='conv1.0'),
139145
'seresnext26tn_32x4d': _cfg(
140146
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
141-
interpolation='bicubic'),
147+
interpolation='bicubic',
148+
first_conv='conv1.0'),
142149
'seresnext50_32x4d': _cfg(
143150
interpolation='bicubic'),
144151
'seresnext101_32x4d': _cfg(
@@ -149,7 +156,8 @@ def _cfg(url='', **kwargs):
149156
interpolation='bicubic'),
150157
'senet154': _cfg(
151158
url='',
152-
interpolation='bicubic'),
159+
interpolation='bicubic',
160+
first_conv='conv1.0'),
153161

154162
# Efficient Channel Attention ResNets
155163
'ecaresnet18': _cfg(),
@@ -159,21 +167,26 @@ def _cfg(url='', **kwargs):
159167
interpolation='bicubic'),
160168
'ecaresnet50d': _cfg(
161169
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth',
162-
interpolation='bicubic'),
170+
interpolation='bicubic',
171+
first_conv='conv1.0'),
163172
'ecaresnet50d_pruned': _cfg(
164173
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
165-
interpolation='bicubic'),
174+
interpolation='bicubic',
175+
first_conv='conv1.0'),
166176
'ecaresnet101d': _cfg(
167177
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
168-
interpolation='bicubic'),
178+
interpolation='bicubic',
179+
first_conv='conv1.0'),
169180
'ecaresnet101d_pruned': _cfg(
170181
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
171-
interpolation='bicubic'),
182+
interpolation='bicubic',
183+
first_conv='conv1.0'),
172184

173185
# Efficient Channel Attention ResNeXts
174186
'ecaresnext26tn_32x4d': _cfg(
175187
url='',
176-
interpolation='bicubic'),
188+
interpolation='bicubic',
189+
first_conv='conv1.0'),
177190
'ecaresnext50_32x4d': _cfg(
178191
url='',
179192
interpolation='bicubic'),

timm/models/selecsls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _cfg(url='', **kwargs):
2929
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
3030
'crop_pct': 0.875, 'interpolation': 'bilinear',
3131
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
32-
'first_conv': 'stem', 'classifier': 'fc',
32+
'first_conv': 'stem.0', 'classifier': 'fc',
3333
**kwargs
3434
}
3535

timm/models/sknet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def _cfg(url='', **kwargs):
3636
'skresnet34': _cfg(
3737
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'),
3838
'skresnet50': _cfg(),
39-
'skresnet50d': _cfg(),
39+
'skresnet50d': _cfg(
40+
first_conv='conv1.0'),
4041
'skresnext50_32x4d': _cfg(
4142
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'),
4243
}

timm/models/tresnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _cfg(url='', **kwargs):
2525
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
2626
'crop_pct': 0.875, 'interpolation': 'bilinear',
2727
'mean': (0, 0, 0), 'std': (1, 1, 1),
28-
'first_conv': 'body.conv1', 'classifier': 'head.fc',
28+
'first_conv': 'body.conv1.0', 'classifier': 'head.fc',
2929
**kwargs
3030
}
3131

timm/models/xception_aligned.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _cfg(url='', **kwargs):
2525
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
2626
'crop_pct': 0.903, 'interpolation': 'bicubic',
2727
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
28-
'first_conv': 'stem.0', 'classifier': 'head.fc',
28+
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
2929
**kwargs
3030
}
3131

0 commit comments

Comments
 (0)