Skip to content

Commit fa28067

Browse files
committed
Add more augmentation arguments, including a no_aug disable flag. Fix huggingface#209
1 parent e3f58fc commit fa28067

File tree

3 files changed

+108
-22
lines changed

3 files changed

+108
-22
lines changed

timm/data/loader.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,15 @@ def create_loader(
131131
batch_size,
132132
is_training=False,
133133
use_prefetcher=True,
134+
no_aug=False,
134135
re_prob=0.,
135136
re_mode='const',
136137
re_count=1,
137138
re_split=False,
139+
scale=None,
140+
ratio=None,
141+
hflip=0.5,
142+
vflip=0.,
138143
color_jitter=0.4,
139144
auto_augment=None,
140145
num_aug_splits=0,
@@ -158,6 +163,11 @@ def create_loader(
158163
input_size,
159164
is_training=is_training,
160165
use_prefetcher=use_prefetcher,
166+
no_aug=no_aug,
167+
scale=scale,
168+
ratio=ratio,
169+
hflip=hflip,
170+
vflip=vflip,
161171
color_jitter=color_jitter,
162172
auto_augment=auto_augment,
163173
interpolation=interpolation,
@@ -200,12 +210,13 @@ def create_loader(
200210
drop_last=is_training,
201211
)
202212
if use_prefetcher:
213+
prefetch_re_prob = re_prob if is_training and not no_aug else 0.
203214
loader = PrefetchLoader(
204215
loader,
205216
mean=mean,
206217
std=std,
207218
fp16=fp16,
208-
re_prob=re_prob if is_training else 0.,
219+
re_prob=prefetch_re_prob,
209220
re_mode=re_mode,
210221
re_count=re_count,
211222
re_num_splits=re_num_splits

timm/data/transforms_factory.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,39 @@
1414
from timm.data.random_erasing import RandomErasing
1515

1616

17+
def transforms_noaug_train(
18+
img_size=224,
19+
interpolation='bilinear',
20+
use_prefetcher=False,
21+
mean=IMAGENET_DEFAULT_MEAN,
22+
std=IMAGENET_DEFAULT_STD,
23+
):
24+
if interpolation == 'random':
25+
# random interpolation not supported with no-aug
26+
interpolation = 'bilinear'
27+
tfl = [
28+
transforms.Resize(img_size, _pil_interp(interpolation)),
29+
transforms.CenterCrop(img_size)
30+
]
31+
if use_prefetcher:
32+
# prefetcher and collate will handle tensor conversion and norm
33+
tfl += [ToNumpy()]
34+
else:
35+
tfl += [
36+
transforms.ToTensor(),
37+
transforms.Normalize(
38+
mean=torch.tensor(mean),
39+
std=torch.tensor(std))
40+
]
41+
return transforms.Compose(tfl)
42+
43+
1744
def transforms_imagenet_train(
1845
img_size=224,
19-
scale=(0.08, 1.0),
46+
scale=None,
47+
ratio=None,
48+
hflip=0.5,
49+
vflip=0.,
2050
color_jitter=0.4,
2151
auto_augment=None,
2252
interpolation='random',
@@ -36,11 +66,14 @@ def transforms_imagenet_train(
3666
* a portion of the data through the secondary transform
3767
* normalizes and converts the branches above with the third, final transform
3868
"""
69+
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
70+
ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
3971
primary_tfl = [
40-
RandomResizedCropAndInterpolation(
41-
img_size, scale=scale, interpolation=interpolation),
42-
transforms.RandomHorizontalFlip()
43-
]
72+
RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
73+
if hflip > 0.:
74+
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
75+
if vflip > 0.:
76+
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
4477

4578
secondary_tfl = []
4679
if auto_augment:
@@ -135,6 +168,11 @@ def create_transform(
135168
input_size,
136169
is_training=False,
137170
use_prefetcher=False,
171+
no_aug=False,
172+
scale=None,
173+
ratio=None,
174+
hflip=0.5,
175+
vflip=0.,
138176
color_jitter=0.4,
139177
auto_augment=None,
140178
interpolation='bilinear',
@@ -159,9 +197,21 @@ def create_transform(
159197
transform = TfPreprocessTransform(
160198
is_training=is_training, size=img_size, interpolation=interpolation)
161199
else:
162-
if is_training:
200+
if is_training and no_aug:
201+
assert not separate, "Cannot perform split augmentation with no_aug"
202+
transform = transforms_noaug_train(
203+
img_size,
204+
interpolation=interpolation,
205+
use_prefetcher=use_prefetcher,
206+
mean=mean,
207+
std=std)
208+
elif is_training:
163209
transform = transforms_imagenet_train(
164210
img_size,
211+
scale=scale,
212+
ratio=ratio,
213+
hflip=hflip,
214+
vflip=vflip,
165215
color_jitter=color_jitter,
166216
auto_augment=auto_augment,
167217
interpolation=interpolation,

train.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252

5353
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
54+
5455
# Dataset / Model parameters
5556
parser.add_argument('data', metavar='DIR',
5657
help='path to dataset')
@@ -82,16 +83,7 @@
8283
help='input batch size for training (default: 32)')
8384
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
8485
help='ratio of validation batch size to training batch size (default: 1)')
85-
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
86-
help='Dropout rate (default: 0.)')
87-
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
88-
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
89-
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
90-
help='Drop path rate (default: None)')
91-
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
92-
help='Drop block rate (default: None)')
93-
parser.add_argument('--jsd', action='store_true', default=False,
94-
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
86+
9587
# Optimizer parameters
9688
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
9789
help='Optimizer (default: "sgd"')
@@ -101,6 +93,7 @@
10193
help='SGD momentum (default: 0.9)')
10294
parser.add_argument('--weight-decay', type=float, default=0.0001,
10395
help='weight decay (default: 0.0001)')
96+
10497
# Learning rate schedule parameters
10598
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
10699
help='LR scheduler (default: "step"')
@@ -134,13 +127,26 @@
134127
help='patience epochs for Plateau LR scheduler (default: 10')
135128
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
136129
help='LR decay rate (default: 0.1)')
137-
# Augmentation parameters
130+
131+
# Augmentation & regularization parameters
132+
parser.add_argument('--no-aug', action='store_true', default=False,
133+
help='Disable all training augmentation, override other train aug args')
134+
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
135+
help='Random resize scale (default: 0.08 1.0)')
136+
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
137+
help='Random resize aspect ratio (default: 0.75 1.33)')
138+
parser.add_argument('--hflip', type=float, default=0.5,
139+
help='Horizontal flip training aug probability')
140+
parser.add_argument('--vflip', type=float, default=0.,
141+
help='Vertical flip training aug probability')
138142
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
139143
help='Color jitter factor (default: 0.4)')
140144
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
141145
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
142146
parser.add_argument('--aug-splits', type=int, default=0,
143147
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
148+
parser.add_argument('--jsd', action='store_true', default=False,
149+
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
144150
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
145151
help='Random erase prob (default: 0.)')
146152
parser.add_argument('--remode', type=str, default='const',
@@ -150,13 +156,22 @@
150156
parser.add_argument('--resplit', action='store_true', default=False,
151157
help='Do not random erase first (clean) augmentation split')
152158
parser.add_argument('--mixup', type=float, default=0.0,
153-
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
159+
help='Mixup alpha, mixup enabled if > 0. (default: 0.)')
154160
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
155-
help='turn off mixup after this epoch, disabled if 0 (default: 0)')
161+
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
156162
parser.add_argument('--smoothing', type=float, default=0.1,
157-
help='label smoothing (default: 0.1)')
163+
help='Label smoothing (default: 0.1)')
158164
parser.add_argument('--train-interpolation', type=str, default='random',
159165
help='Training interpolation (random, bilinear, bicubic default: "random")')
166+
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
167+
help='Dropout rate (default: 0.)')
168+
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
169+
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
170+
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
171+
help='Drop path rate (default: None)')
172+
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
173+
help='Drop block rate (default: None)')
174+
160175
# Batch norm parameters (only works with gen_efficientnet based models currently)
161176
parser.add_argument('--bn-tf', action='store_true', default=False,
162177
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
@@ -170,13 +185,15 @@
170185
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
171186
parser.add_argument('--split-bn', action='store_true',
172187
help='Enable separate BN layers per augmentation split.')
188+
173189
# Model Exponential Moving Average
174190
parser.add_argument('--model-ema', action='store_true', default=False,
175191
help='Enable tracking moving average of model weights')
176192
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
177193
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
178194
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
179195
help='decay factor for model weights moving average (default: 0.9998)')
196+
180197
# Misc
181198
parser.add_argument('--seed', type=int, default=42, metavar='S',
182199
help='random seed (default: 42)')
@@ -378,20 +395,28 @@ def main():
378395
if num_aug_splits > 1:
379396
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
380397

398+
train_interpolation = args.train_interpolation
399+
if args.no_aug or not train_interpolation:
400+
train_interpolation = data_config['interpolation']
381401
loader_train = create_loader(
382402
dataset_train,
383403
input_size=data_config['input_size'],
384404
batch_size=args.batch_size,
385405
is_training=True,
386406
use_prefetcher=args.prefetcher,
407+
no_aug=args.no_aug,
387408
re_prob=args.reprob,
388409
re_mode=args.remode,
389410
re_count=args.recount,
390411
re_split=args.resplit,
412+
scale=args.scale,
413+
ratio=args.ratio,
414+
hflip=args.hflip,
415+
vflip=args.vflip,
391416
color_jitter=args.color_jitter,
392417
auto_augment=args.aa,
393418
num_aug_splits=num_aug_splits,
394-
interpolation=args.train_interpolation,
419+
interpolation=train_interpolation,
395420
mean=data_config['mean'],
396421
std=data_config['std'],
397422
num_workers=args.workers,

0 commit comments

Comments (0)