
Commit b423bc8

Merge pull request huggingface#218 from rwightman/cutmix
CutMix + MixUp overhaul
2 parents: 0f5d9d8 + 8c9814e

File tree: 4 files changed, +259 -47 lines changed


timm/data/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from .transforms import *
 from .loader import create_loader
 from .transforms_factory import create_transform
-from .mixup import mixup_batch, FastCollateMixup
+from .mixup import Mixup, FastCollateMixup
 from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
     rand_augment_transform, auto_augment_transform
 from .real_labels import RealLabelsImagenet
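
This swaps the package-level export: the old `mixup_batch` function is gone and the new `Mixup` class is exposed alongside `FastCollateMixup`. A minimal sketch of the updated import surface (nothing else about the call sites is implied here):

    # after this commit; `mixup_batch` no longer exists in timm.data
    from timm.data import Mixup, FastCollateMixup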

timm/data/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def __getitem__(self, index):
         return img, target

     def __len__(self):
-        return len(self.imgs)
+        return len(self.samples)

     def filenames(self, indices=[], basename=False):
         if indices:

timm/data/mixup.py

Lines changed: 222 additions & 25 deletions
@@ -1,5 +1,12 @@
-""" Mixup
-Paper: `mixup: Beyond Empirical Risk Minimization` - https://arxiv.org/abs/1710.09412
+""" Mixup and Cutmix
+
+Papers:
+mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+Code Reference:
+CutMix: https://github.com/clovaai/CutMix-PyTorch

 Hacked together by / Copyright 2020 Ross Wightman
 """
@@ -17,40 +24,230 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
     on_value = 1. - smoothing + off_value
     y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
     y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
-    return lam*y1 + (1. - lam)*y2
+    return y1 * lam + y2 * (1. - lam)
+
+
+def rand_bbox(img_shape, lam, margin=0., count=None):
+    """ Standard CutMix bounding-box
+    Generates a random square bbox based on lambda value. This impl includes
+    support for enforcing a border margin as percent of bbox dimensions.
+
+    Args:
+        img_shape (tuple): Image shape as tuple
+        lam (float): Cutmix lambda value
+        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+        count (int): Number of bbox to generate
+    """
+    ratio = np.sqrt(1 - lam)
+    img_h, img_w = img_shape[-2:]
+    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+    yl = np.clip(cy - cut_h // 2, 0, img_h)
+    yh = np.clip(cy + cut_h // 2, 0, img_h)
+    xl = np.clip(cx - cut_w // 2, 0, img_w)
+    xh = np.clip(cx + cut_w // 2, 0, img_w)
+    return yl, yh, xl, xh
+
+
+def rand_bbox_minmax(img_shape, minmax, count=None):
+    """ Min-Max CutMix bounding-box
+    Inspired by Darknet cutmix impl, generates a random rectangular bbox
+    based on min/max percent values applied to each dimension of the input image.

+    Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.

-def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
-    lam = 1.
-    if not disable:
-        lam = np.random.beta(alpha, alpha)
-    input = input.mul(lam).add_(1 - lam, input.flip(0))
-    target = mixup_target(target, num_classes, lam, smoothing)
-    return input, target
+    Args:
+        img_shape (tuple): Image shape as tuple
+        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+        count (int): Number of bbox to generate
+    """
+    assert len(minmax) == 2
+    img_h, img_w = img_shape[-2:]
+    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+    yl = np.random.randint(0, img_h - cut_h, size=count)
+    xl = np.random.randint(0, img_w - cut_w, size=count)
+    yu = yl + cut_h
+    xu = xl + cut_w
+    return yl, yu, xl, xu


-class FastCollateMixup:
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+    """ Generate bbox and apply lambda correction.
+    """
+    if ratio_minmax is not None:
+        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+    else:
+        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+    if correct_lam or ratio_minmax is not None:
+        bbox_area = (yu - yl) * (xu - xl)
+        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
+    return (yl, yu, xl, xu), lam
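
The three helpers above are shared by both mixing paths: `rand_bbox` cuts a square whose area follows the sampled lambda, `rand_bbox_minmax` samples a rectangle from explicit min/max ratios, and `cutmix_bbox_and_lam` recomputes lambda from the clipped box area when `correct_lam` is set (or whenever minmax mode is used). A rough sketch of what the correction does (the shape and alpha below are illustrative only):

    import numpy as np
    from timm.data.mixup import cutmix_bbox_and_lam

    img_shape = (3, 224, 224)                      # hypothetical CHW image shape
    lam = float(np.random.beta(1.0, 1.0))          # raw cutmix lambda
    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(img_shape, lam, correct_lam=True)
    # lam is now 1 - box_area / image_area, so the soft targets match the pixels actually pasted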

-    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
+
+class Mixup:
+    """ Mixup/Cutmix that applies different params to each element or whole batch
+
+    Args:
+        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+        prob (float): probability of applying mixup or cutmix per batch or element
+        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+        elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+        label_smoothing (float): apply label smoothing to the mixed target tensor
+        num_classes (int): number of classes for target
+    """
+    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
         self.mixup_alpha = mixup_alpha
+        self.cutmix_alpha = cutmix_alpha
+        self.cutmix_minmax = cutmix_minmax
+        if self.cutmix_minmax is not None:
+            assert len(self.cutmix_minmax) == 2
+            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+            self.cutmix_alpha = 1.0
+        self.mix_prob = prob
+        self.switch_prob = switch_prob
         self.label_smoothing = label_smoothing
         self.num_classes = num_classes
-        self.mixup_enabled = True
+        self.elementwise = elementwise
+        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
+        self.mixup_enabled = True  # set to false to disable mixing (intended tp be set by train loop)

-    def __call__(self, batch):
-        batch_size = len(batch)
-        lam = 1.
+    def _params_per_elem(self, batch_size):
+        lam = np.ones(batch_size, dtype=np.float32)
+        use_cutmix = np.zeros(batch_size, dtype=np.bool)
         if self.mixup_enabled:
-            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand(batch_size) < self.switch_prob
+                lam_mix = np.where(
+                    use_cutmix,
+                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = np.ones(batch_size, dtype=np.bool)
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+        return lam, use_cutmix

-        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+    def _params_per_batch(self):
+        lam = 1.
+        use_cutmix = False
+        if self.mixup_enabled and np.random.rand() < self.mix_prob:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand() < self.switch_prob
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = True
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = float(lam_mix)
+        return lam, use_cutmix

-        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+    def _mix_elem(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
         for i in range(batch_size):
-            mixed = batch[i][0].astype(np.float32) * lam + \
-                batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
-            np.round(mixed, out=mixed)
-            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+    def _mix_batch(self, x):
+        lam, use_cutmix = self._params_per_batch()
+        if lam == 1.:
+            return 1.
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
+        else:
+            x_flipped = x.flip(0).mul_(1. - lam)
+            x.mul_(lam).add_(x_flipped)
+        return lam
+
+    def __call__(self, x, target):
+        assert len(x) % 2 == 0, 'Batch size should be even when using this'
+        lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+        return x, target
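
The `Mixup` class above replaces the old `mixup_batch` helper: it samples lambda (and the mixup-vs-cutmix switch) per batch or per element, mixes the input tensor in place, and builds soft targets through `mixup_target`. A minimal usage sketch on a CUDA batch (batch size, image size, and alpha values are illustrative, not prescribed by the commit):

    import torch
    from timm.data import Mixup

    mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, switch_prob=0.5,
                     label_smoothing=0.1, num_classes=1000)
    images = torch.rand(8, 3, 224, 224).cuda()    # batch size must be even
    labels = torch.randint(0, 1000, (8,)).cuda()  # integer class indices
    images, targets = mixup_fn(images, labels)    # targets: (8, 1000) soft labels

Note that `mixup_target` defaults to `device='cuda'`, so the labels need to be CUDA tensors in this path.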
+
+
+class FastCollateMixup(Mixup):
+    """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
+
+    A Mixup impl that's performed while collating the batches.
+    """
+
+    def _mix_elem_collate(self, output, batch):
+        batch_size = len(batch)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    mixed = mixed.copy()
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    lam_batch[i] = lam
+                    np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
+    def _mix_batch_collate(self, output, batch):
+        batch_size = len(batch)
+        lam, use_cutmix = self._params_per_batch()
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix:
+                    mixed = mixed.copy()  # don't want to modify the original while iterating
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return lam
+
+    def __call__(self, batch, _=None):
+        batch_size = len(batch)
+        assert batch_size % 2 == 0, 'Batch size should be even when using this'
+        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        if self.elementwise:
+            lam = self._mix_elem_collate(output, batch)
+        else:
+            lam = self._mix_batch_collate(output, batch)
+        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        return output, target

-        return tensor, target
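
`FastCollateMixup` now subclasses `Mixup` and does the same mixing inside the collate step, writing into a uint8 batch tensor so the GPU prefetcher can normalize afterwards; train.py passes it to the loader as `collate_fn` when the prefetcher is in use. A rough standalone sketch with a plain PyTorch DataLoader (the toy dataset is hypothetical; it just mimics the (uint8 CHW numpy image, int label) items that timm's prefetcher transform produces):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset
    from timm.data import FastCollateMixup

    class ToyDataset(Dataset):
        # hypothetical stand-in for an ImageFolder-style dataset with a ToNumpy-style transform
        def __len__(self):
            return 32
        def __getitem__(self, idx):
            img = np.random.randint(0, 256, (3, 224, 224), dtype=np.uint8)
            return img, idx % 10

    collate = FastCollateMixup(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=10)
    loader = DataLoader(ToyDataset(), batch_size=8, collate_fn=collate)
    inputs, targets = next(iter(loader))  # inputs: uint8 (8, 3, 224, 224); targets: (8, 10) soft labels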

train.py

Lines changed: 35 additions & 20 deletions
@@ -28,7 +28,7 @@
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False

-from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset
+from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@@ -156,7 +156,17 @@
 parser.add_argument('--resplit', action='store_true', default=False,
                     help='Do not random erase first (clean) augmentation split')
 parser.add_argument('--mixup', type=float, default=0.0,
-                    help='Mixup alpha, mixup enabled if > 0. (default: 0.)')
+                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
+parser.add_argument('--cutmix', type=float, default=0.0,
+                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
+parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
+                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+parser.add_argument('--mixup-prob', type=float, default=1.0,
+                    help='Probability of performing mixup or cutmix when either/both is enabled')
+parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
+                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
+parser.add_argument('--mixup-elem', action='store_true', default=False,
+                    help='Apply mixup/cutmix params uniquely per batch element instead of per batch.')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
@@ -388,9 +398,18 @@ def main():
     dataset_train = Dataset(train_dir)

     collate_fn = None
-    if args.prefetcher and args.mixup > 0:
-        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
-        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
+    mixup_fn = None
+    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+    if mixup_active:
+        mixup_args = dict(
+            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem,
+            label_smoothing=args.smoothing, num_classes=args.num_classes)
+        if args.prefetcher:
+            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
+            collate_fn = FastCollateMixup(**mixup_args)
+        else:
+            mixup_fn = Mixup(**mixup_args)

     if num_aug_splits > 1:
         dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
@@ -452,17 +471,14 @@ def main():
     if args.jsd:
         assert num_aug_splits > 1  # JSD only valid with aug splits set
         train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
-        validate_loss_fn = nn.CrossEntropyLoss().cuda()
-    elif args.mixup > 0.:
-        # smoothing is handled with mixup label transform
+    elif mixup_active:
+        # smoothing is handled with mixup target transform
         train_loss_fn = SoftTargetCrossEntropy().cuda()
-        validate_loss_fn = nn.CrossEntropyLoss().cuda()
     elif args.smoothing:
         train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
-        validate_loss_fn = nn.CrossEntropyLoss().cuda()
     else:
         train_loss_fn = nn.CrossEntropyLoss().cuda()
-        validate_loss_fn = train_loss_fn
+    validate_loss_fn = nn.CrossEntropyLoss().cuda()

     eval_metric = args.eval_metric
     best_metric = None
@@ -490,7 +506,7 @@ def main():
         train_metrics = train_epoch(
             epoch, model, loader_train, optimizer, train_loss_fn, args,
             lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-            use_amp=use_amp, model_ema=model_ema)
+            use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)

         if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
             if args.local_rank == 0:
@@ -530,11 +546,13 @@ def main():

 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):

-    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
-        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
+    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
+        if args.prefetcher and loader.mixup_enabled:
             loader.mixup_enabled = False
+        elif mixup_fn is not None:
+            mixup_fn.mixup_enabled = False

     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
@@ -550,11 +568,8 @@ def train_epoch(
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
             input, target = input.cuda(), target.cuda()
-            if args.mixup > 0.:
-                input, target = mixup_batch(
-                    input, target,
-                    alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
-                    disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
+            if mixup_fn is not None:
+                input, target = mixup_fn(input, target)

         output = model(input)
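
Taken together, the new flags control the whole scheme from the command line: `--mixup` and `--cutmix` set the two alphas, `--cutmix-minmax` switches to the min/max bbox sampler, `--mixup-prob` and `--mixup-switch-prob` control how often each is applied, and `--mixup-elem` enables per-element parameters. Because smoothing is folded into the soft targets, `SoftTargetCrossEntropy` is used whenever mixing is active. An illustrative invocation (the dataset path and model are placeholders, not taken from the commit):

    python train.py /path/to/imagenet --model resnet50 \
        --mixup 0.2 --cutmix 1.0 --mixup-prob 1.0 --mixup-switch-prob 0.5 \
        --smoothing 0.1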
