
Commit e62758c

More documentation updates, fix a typo
1 parent 5e333b8 commit e62758c

5 files changed: 228 additions, 51 deletions

README.md

Lines changed: 8 additions & 7 deletions
@@ -19,7 +19,6 @@ Universal feature extraction, new models, new weights, new test sets.
 * Train script and loader/transform tweaks to punch through more aug arguments
 * README and documentation overhaul. See initial (WIP) documentation at https://rwightman.github.io/pytorch-image-models/

-
 ### June 11, 2020
 Bunch of changes:
 * DenseNet models updated with memory efficient addition from torchvision (fixed a bug), blur pooling and deep stem additions
@@ -65,7 +64,9 @@ The work of many others is present here. I've tried to make sure all source mate

 ## Models

-Most included models have pretrained weights. The weights are either from their original sources, ported by myself from their original framework (e.g. Tensorflow models), or trained from scratch using the included training script. A full version of the list below with source links and references can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).
+All model architecture families include variants with pretrained weights. There are variants without any weights. Help training new or better weights is always appreciated. Here are some example [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) to get you started.
+
+A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).

 * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
 * DenseNet - https://arxiv.org/abs/1608.06993
@@ -102,7 +103,7 @@ Most included models have pretrained weights. The weights are either from their
 * SelecSLS - https://arxiv.org/abs/1907.00837
 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
 * TResNet - https://arxiv.org/abs/2003.13630
-* VovNet V2 (with V1 support) - https://arxiv.org/abs/1911.06667
+* VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
 * Xception - https://arxiv.org/abs/1610.02357
 * Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611
 * Xception (Modified Aligned, TF) - https://arxiv.org/abs/1802.02611
@@ -113,12 +114,12 @@ Several (less common) features that I often utilize in my projects are included.

 * All models have a common default configuration interface and API for
   * accessing/changing the classifier - `get_classifier` and `reset_classifier`
-  * doing a forward pass on just the features - `forward_features`
+  * doing a forward pass on just the features - `forward_features` (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
   * these make it easy to write consistent network wrappers that work with any of the models
-* All models support multi-scale feature map extraction (feature pyramids) via create_model
+* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
   * `create_model(name, features_only=True, out_indices=..., output_stride=...)`
-  * `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level. Most models start with stride 2 features (`C1`) at index 0 and end with `C5` at index 4. Some models start with stride 1 or 4 and end with 6 (stride 64).
-  * `output_stride` creation arg controls output stride of the network, most networks are stride 32 by default. Dilated convs are used to limit the output stride. Not all networks support this.
+  * `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level.
+  * `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
   * feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member
 * All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired
 * High performance [reference training, validation, and inference scripts](https://rwightman.github.io/pytorch-image-models/scripts/) that work in several process/GPU modes:

docs/feature_extraction.md

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
# Feature Extraction

All of the models in `timm` have consistent mechanisms for obtaining various types of features from the model for tasks besides classification.

## Penultimate Layer Features (Pre-Classifier Features)

The features from the penultimate model layer can be obtained in several ways without requiring model surgery (although feel free to do surgery). One must first decide if they want pooled or un-pooled features.

### Unpooled

There are three ways to obtain unpooled features.

Without modifying the network, one can call `model.forward_features(input)` on any model instead of the usual `model(input)`. This will bypass the head classifier and global pooling for all networks.

If one wants to explicitly modify the network to return unpooled features, they can either create the model without a classifier and pooling, or remove it later. Both paths remove the parameters associated with the classifier from the network.

#### forward_features()
```python hl_lines="3 6"
import torch
import timm
m = timm.create_model('xception41', pretrained=True)
o = m(torch.randn(2, 3, 299, 299))
print(f'Original shape: {o.shape}')
o = m.forward_features(torch.randn(2, 3, 299, 299))
print(f'Unpooled shape: {o.shape}')
```
Output:
```text
Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 2048, 10, 10])
```

#### Create with no classifier and pooling
```python hl_lines="3"
import torch
import timm
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
o = m(torch.randn(2, 3, 224, 224))
print(f'Unpooled shape: {o.shape}')
```
Output:
```text
Unpooled shape: torch.Size([2, 2048, 7, 7])
```

#### Remove it later
```python hl_lines="3 6"
import torch
import timm
m = timm.create_model('densenet121', pretrained=True)
o = m(torch.randn(2, 3, 224, 224))
print(f'Original shape: {o.shape}')
m.reset_classifier(0, '')
o = m(torch.randn(2, 3, 224, 224))
print(f'Unpooled shape: {o.shape}')
```
Output:
```text
Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 1024, 7, 7])
```

### Pooled

To modify the network to return pooled features, one can use `forward_features()` and pool/flatten the result themselves, or modify the network like above but keep pooling intact.
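
#### Pool the `forward_features()` output yourself

One option is to take the unpooled output from `forward_features()` and apply the pooling/flatten step manually. This is a minimal sketch (it assumes global average pooling is the desired reduction; any pooling could be substituted):
```python
import torch
import torch.nn.functional as F
import timm

m = timm.create_model('resnet50', pretrained=True)
feats = m.forward_features(torch.randn(2, 3, 224, 224))  # unpooled features, e.g. [2, 2048, 7, 7]
pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)       # global average pool + flatten -> [2, 2048]
print(f'Pooled shape: {pooled.shape}')
```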
#### Create with no classifier
```python hl_lines="3"
import torch
import timm
m = timm.create_model('resnet50', pretrained=True, num_classes=0)
o = m(torch.randn(2, 3, 224, 224))
print(f'Pooled shape: {o.shape}')
```
Output:
```text
Pooled shape: torch.Size([2, 2048])
```

#### Remove it later
```python hl_lines="3 6"
import torch
import timm
m = timm.create_model('ese_vovnet19b_dw', pretrained=True)
o = m(torch.randn(2, 3, 224, 224))
print(f'Original shape: {o.shape}')
m.reset_classifier(0)
o = m(torch.randn(2, 3, 224, 224))
print(f'Pooled shape: {o.shape}')
```
Output:
```text
Pooled shape: torch.Size([2, 1024])
```

## Multi-scale Feature Maps (Feature Pyramid)

Object detection, segmentation, keypoint, and a variety of dense pixel tasks require access to feature maps from the backbone network at multiple scales. This is often done by modifying the original classification network. Since each network varies quite a bit in structure, it's not uncommon to see only a few backbones supported in any given object detection or segmentation library.

`timm` allows a consistent interface for creating any of the included models as feature backbones that output feature maps for selected levels.

A feature backbone can be created by adding the argument `features_only=True` to any `create_model` call. By default 5 strides will be output from most models (not all have that many), with the first starting at 2 (some start at 1 or 4).

### Create a feature map extraction model
```python hl_lines="3"
import torch
import timm
m = timm.create_model('resnest26d', features_only=True, pretrained=True)
o = m(torch.randn(2, 3, 224, 224))
for x in o:
    print(x.shape)
```
Output:
```text
torch.Size([2, 64, 112, 112])
torch.Size([2, 256, 56, 56])
torch.Size([2, 512, 28, 28])
torch.Size([2, 1024, 14, 14])
torch.Size([2, 2048, 7, 7])
```

### Query the feature information

After a feature backbone has been created, it can be queried to provide channel or resolution reduction information to the downstream heads without requiring static config or hardcoded constants. The `.feature_info` attribute is a class encapsulating the information about the feature extraction points.

```python hl_lines="3 4"
import torch
import timm
m = timm.create_model('regnety_032', features_only=True, pretrained=True)
print(f'Feature channels: {m.feature_info.channels()}')
o = m(torch.randn(2, 3, 224, 224))
for x in o:
    print(x.shape)
```
Output:
```text
Feature channels: [32, 72, 216, 576, 1512]
torch.Size([2, 32, 112, 112])
torch.Size([2, 72, 56, 56])
torch.Size([2, 216, 28, 28])
torch.Size([2, 576, 14, 14])
torch.Size([2, 1512, 7, 7])
```

### Select specific feature levels or limit the stride

There are two additional creation arguments impacting the output features.

* `out_indices` selects which indices to output
* `output_stride` limits the feature output stride of the network (also works in classification mode BTW)

`out_indices` is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check `feature_info` to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most models, index 0 is the stride 2 features, and index 4 is stride 32.

`output_stride` is achieved by converting layers to use dilated convolutions. Doing so is not always straightforward; some networks only support `output_stride=32`.

```python hl_lines="3 4 5"
import torch
import timm
m = timm.create_model('ecaresnet101d', features_only=True, output_stride=8, out_indices=(2, 4), pretrained=True)
print(f'Feature channels: {m.feature_info.channels()}')
print(f'Feature reduction: {m.feature_info.reduction()}')
o = m(torch.randn(2, 3, 320, 320))
for x in o:
    print(x.shape)
```
Output:
```text
Feature channels: [512, 2048]
Feature reduction: [8, 8]
torch.Size([2, 512, 40, 40])
torch.Size([2, 2048, 40, 40])
```
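
As noted above, `output_stride` also works in classification mode, not just with `features_only=True`. A minimal sketch (assuming `resnet50`, which supports dilation; not all models do):
```python
import torch
import timm

# output_stride=16 converts later stride-2 stages to dilated convolutions,
# so the classifier output is unchanged but the spatial features stay larger
m = timm.create_model('resnet50', pretrained=True, output_stride=16)
o = m(torch.randn(2, 3, 224, 224))
print(f'Logits shape: {o.shape}')      # torch.Size([2, 1000])
o = m.forward_features(torch.randn(2, 3, 224, 224))
print(f'Unpooled shape: {o.shape}')    # stride 16 on a 224 input -> torch.Size([2, 2048, 14, 14])
```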

docs/results.md

Lines changed: 43 additions & 43 deletions
@@ -5,54 +5,54 @@ CSV files containing an ImageNet-1K validation and OOD test set validation resul
 ## Self-trained Weights
 I've leveraged the training scripts in this repository to train a few of the models to good levels of performance.

-|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size |
+|Model | Acc@1 (Err) | Acc@5 (Err) | Param # (M) | Interpolation | Image Size |
 |---|---|---|---|---|---|
-| efficientnet_b3a | 81.874 (18.126) | 95.840 (4.160) | 12.23M | bicubic | 320 (1.0 crop) |
-| efficientnet_b3 | 81.498 (18.502) | 95.718 (4.282) | 12.23M | bicubic | 300 |
-| skresnext50d_32x4d | 81.278 (18.722) | 95.366 (4.634) | 27.5M | bicubic | 288 (1.0 crop) |
-| efficientnet_b2a | 80.608 (19.392) | 95.310 (4.690) | 9.11M | bicubic | 288 (1.0 crop) |
-| mixnet_xl | 80.478 (19.522) | 94.932 (5.068) | 11.90M | bicubic | 224 |
-| efficientnet_b2 | 80.402 (19.598) | 95.076 (4.924) | 9.11M | bicubic | 260 |
-| skresnext50d_32x4d | 80.156 (19.844) | 94.642 (5.358) | 27.5M | bicubic | 224 |
-| resnext50_32x4d | 79.762 (20.238) | 94.600 (5.400) | 25M | bicubic | 224 |
-| resnext50d_32x4d | 79.674 (20.326) | 94.868 (5.132) | 25.1M | bicubic | 224 |
-| ese_vovnet39b | 79.320 (20.680) | 94.710 (5.290) | 24.6M | bicubic | 224 |
-| resnetblur50 | 79.290 (20.710) | 94.632 (5.368) | 25.6M | bicubic | 224 |
-| resnet50 | 79.038 (20.962) | 94.390 (5.610) | 25.6M | bicubic | 224 |
-| mixnet_l | 78.976 (21.024) | 94.184 (5.816) | 7.33M | bicubic | 224 |
-| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.79M | bicubic | 240 |
-| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44M | bicubic | 224 |
-| seresnext26t_32x4d | 77.998 (22.002) | 93.708 (6.292) | 16.8M | bicubic | 224 |
-| seresnext26tn_32x4d | 77.986 (22.014) | 93.746 (6.254) | 16.8M | bicubic | 224 |
-| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.29M | bicubic | 224 |
-| seresnext26d_32x4d | 77.602 (22.398) | 93.608 (6.392) | 16.8M | bicubic | 224 |
-| mobilenetv2_120d | 77.294 (22.706) | 93.502 (6.498) | 5.8M | bicubic | 224 |
-| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01M | bicubic | 224 |
-| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | 224 |
-| skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
-| ese_vovnet19b_dw | 76.798 (23.202) | 93.268 (6.732) | 6.5M | bicubic | 224 |
-| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
-| densenetblur121d | 76.576 (23.424) | 93.190 (6.810) | 8.0M | bicubic | 224 |
-| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1M | bicubic | 224 |
-| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
-| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
-| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
-| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic | 224 |
-| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | 224 |
-| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | 224 |
-| resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22M | bilinear | 224 |
-| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5M | bicubic | 224 |
-| seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22M | bilinear | 224 |
-| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.38M | bicubic | 224 |
-| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.42M | bilinear | 224 |
-| skresnet18 | 73.038 (26.962) | 91.168 (8.832) | 11.9M | bicubic | 224 |
-| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5M | bicubic | 224 |
-| seresnet18 | 71.742 (28.258) | 90.334 (9.666) | 11.8M | bicubic | 224 |
+| efficientnet_b3a | 81.874 (18.126) | 95.840 (4.160) | 12.23 | bicubic | 320 (1.0 crop) |
+| efficientnet_b3 | 81.498 (18.502) | 95.718 (4.282) | 12.23 | bicubic | 300 |
+| skresnext50d_32x4d | 81.278 (18.722) | 95.366 (4.634) | 27.5 | bicubic | 288 (1.0 crop) |
+| efficientnet_b2a | 80.608 (19.392) | 95.310 (4.690) | 9.11 | bicubic | 288 (1.0 crop) |
+| mixnet_xl | 80.478 (19.522) | 94.932 (5.068) | 11.90 | bicubic | 224 |
+| efficientnet_b2 | 80.402 (19.598) | 95.076 (4.924) | 9.11 | bicubic | 260 |
+| skresnext50d_32x4d | 80.156 (19.844) | 94.642 (5.358) | 27.5 | bicubic | 224 |
+| resnext50_32x4d | 79.762 (20.238) | 94.600 (5.400) | 25 | bicubic | 224 |
+| resnext50d_32x4d | 79.674 (20.326) | 94.868 (5.132) | 25.1 | bicubic | 224 |
+| ese_vovnet39b | 79.320 (20.680) | 94.710 (5.290) | 24.6 | bicubic | 224 |
+| resnetblur50 | 79.290 (20.710) | 94.632 (5.368) | 25.6 | bicubic | 224 |
+| resnet50 | 79.038 (20.962) | 94.390 (5.610) | 25.6 | bicubic | 224 |
+| mixnet_l | 78.976 (21.024) | 94.184 (5.816) | 7.33 | bicubic | 224 |
+| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.79 | bicubic | 240 |
+| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | bicubic | 224 |
+| seresnext26t_32x4d | 77.998 (22.002) | 93.708 (6.292) | 16.8 | bicubic | 224 |
+| seresnext26tn_32x4d | 77.986 (22.014) | 93.746 (6.254) | 16.8 | bicubic | 224 |
+| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.29 | bicubic | 224 |
+| seresnext26d_32x4d | 77.602 (22.398) | 93.608 (6.392) | 16.8 | bicubic | 224 |
+| mobilenetv2_120d | 77.294 (22.706) | 93.502 (6.498) | 5.8 | bicubic | 224 |
+| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | bicubic | 224 |
+| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8 | bicubic | 224 |
+| skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2 | bicubic | 224 |
+| ese_vovnet19b_dw | 76.798 (23.202) | 93.268 (6.732) | 6.5 | bicubic | 224 |
+| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16 | bicubic | 224 |
+| densenetblur121d | 76.576 (23.424) | 93.190 (6.810) | 8.0 | bicubic | 224 |
+| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | bicubic | 224 |
+| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | bicubic | 224 |
+| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | bicubic | 224 |
+| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | bicubic | 224 |
+| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89 | bicubic | 224 |
+| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16 | bicubic | 224 |
+| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | bilinear | 224 |
+| resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22 | bilinear | 224 |
+| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | bicubic | 224 |
+| seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22 | bilinear | 224 |
+| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.38 | bicubic | 224 |
+| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.42 | bilinear | 224 |
+| skresnet18 | 73.038 (26.962) | 91.168 (8.832) | 11.9 | bicubic | 224 |
+| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | bicubic | 224 |
+| seresnet18 | 71.742 (28.258) | 90.334 (9.666) | 11.8 | bicubic | 224 |

 ## Ported Weights
 For the models below, the model code and weight porting from Tensorflow or MXNet Gluon to Pytorch was done by myself. There are weights/models ported by others included in this repository; they are not listed below.

-| Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size |
+| Model | Acc@1 (Err) | Acc@5 (Err) | Param # (M) | Interpolation | Image Size |
 |---|---|---|---|---|---|
 | tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 |
 | tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 |

mkdocs.yml

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@ nav:
   - results.md
   - scripts.md
   - training_hparam_examples.md
+  - feature_extraction.md
   - changes.md
   - archived_changes.md
 theme:
@@ -16,6 +17,8 @@ theme:
   tabs: false
 extra_javascript:
   - 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML'
+  - https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js
+  - javascripts/tables.js
 markdown_extensions:
   - codehilite:
       linenums: true

timm/data/transforms_factory.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def transforms_noaug_train(
         std=IMAGENET_DEFAULT_STD,
 ):
     if interpolation == 'random':
-        # random interpolation no supported with no-aug
+        # random interpolation not supported with no-aug
         interpolation = 'bilinear'
     tfl = [
         transforms.Resize(img_size, _pil_interp(interpolation)),
