Skip to content

Commit 8c9814e

Browse files
committed
Final cleanup of mixup/cutmix. Element/batch modes working with both collate (prefetcher active) and without prefetcher.
1 parent cd23f55 commit 8c9814e

File tree

3 files changed

+143
-114
lines changed

3 files changed

+143
-114
lines changed

timm/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .transforms import *
55
from .loader import create_loader
66
from .transforms_factory import create_transform
7-
from .mixup import mix_batch, FastCollateMixup
7+
from .mixup import Mixup, FastCollateMixup
88
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
99
rand_augment_transform, auto_augment_transform
1010
from .real_labels import RealLabelsImagenet

timm/data/mixup.py

Lines changed: 121 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,8 @@
1010
1111
Hacked together by / Copyright 2020 Ross Wightman
1212
"""
13-
1413
import numpy as np
1514
import torch
16-
import math
17-
import numbers
1815

1916

2017
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
@@ -30,20 +27,21 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
3027
return y1 * lam + y2 * (1. - lam)
3128

3229

33-
def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
34-
lam = 1.
35-
if not disable:
36-
lam = np.random.beta(alpha, alpha)
37-
input = input.mul(lam).add_(1 - lam, input.flip(0))
38-
target = mixup_target(target, num_classes, lam, smoothing)
39-
return input, target
40-
30+
def rand_bbox(img_shape, lam, margin=0., count=None):
31+
""" Standard CutMix bounding-box
32+
Generates a random square bbox based on lambda value. This impl includes
33+
support for enforcing a border margin as percent of bbox dimensions.
4134
42-
def rand_bbox(size, lam, border=0., count=None):
43-
ratio = math.sqrt(1 - lam)
44-
img_h, img_w = size[-2:]
35+
Args:
36+
img_shape (tuple): Image shape as tuple
37+
lam (float): Cutmix lambda value
38+
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39+
count (int): Number of bbox to generate
40+
"""
41+
ratio = np.sqrt(1 - lam)
42+
img_h, img_w = img_shape[-2:]
4543
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
46-
margin_y, margin_x = int(border * cut_h), int(border * cut_w)
44+
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
4745
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
4846
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
4947
yl = np.clip(cy - cut_h // 2, 0, img_h)
@@ -53,9 +51,20 @@ def rand_bbox(size, lam, border=0., count=None):
5351
return yl, yh, xl, xh
5452

5553

56-
def rand_bbox_minmax(size, minmax, count=None):
54+
def rand_bbox_minmax(img_shape, minmax, count=None):
55+
""" Min-Max CutMix bounding-box
56+
Inspired by Darknet cutmix impl, generates a random rectangular bbox
57+
based on min/max percent values applied to each dimension of the input image.
58+
59+
Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
60+
61+
Args:
62+
img_shape (tuple): Image shape as tuple
63+
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64+
count (int): Number of bbox to generate
65+
"""
5766
assert len(minmax) == 2
58-
img_h, img_w = size[-2:]
67+
img_h, img_w = img_shape[-2:]
5968
cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
6069
cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
6170
yl = np.random.randint(0, img_h - cut_h, size=count)
@@ -66,6 +75,8 @@ def rand_bbox_minmax(size, minmax, count=None):
6675

6776

6877
def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78+
""" Generate bbox and apply lambda correction.
79+
"""
6980
if ratio_minmax is not None:
7081
yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
7182
else:
@@ -76,71 +87,40 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou
7687
return (yl, yu, xl, xu), lam
7788

7889

79-
def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
80-
lam = 1.
81-
if not disable:
82-
lam = np.random.beta(alpha, alpha)
83-
if lam != 1:
84-
yl, yh, xl, xh = rand_bbox(input.size(), lam)
85-
input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
86-
if correct_lam:
87-
lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1])
88-
target = mixup_target(target, num_classes, lam, smoothing)
89-
return input, target
90-
91-
92-
def mix_batch(
93-
input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=.5,
94-
num_classes=1000, smoothing=0.1, disable=False):
95-
# FIXME test this version
96-
if np.random.rand() > prob:
97-
return input, target
98-
use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob
99-
if use_cutmix:
100-
return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable)
101-
else:
102-
return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable)
103-
104-
105-
class FastCollateMixup:
106-
"""Fast Collate Mixup/Cutmix that applies different params to each element or whole batch
107-
108-
NOTE once experiments are done, one of the three variants will remain with this class name
90+
class Mixup:
91+
""" Mixup/Cutmix that applies different params to each element or whole batch
10992
93+
Args:
94+
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95+
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96+
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97+
prob (float): probability of applying mixup or cutmix per batch or element
98+
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99+
elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
100+
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101+
label_smoothing (float): apply label smoothing to the mixed target tensor
102+
num_classes (int): number of classes for target
110103
"""
111104
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
112105
elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
113-
"""
114-
115-
Args:
116-
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
117-
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
118-
cutmix_minmax (float): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None
119-
prob (float): probability of applying mixup or cutmix per batch or element
120-
switch_prob (float): probability of using cutmix instead of mixup when both active
121-
elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
122-
label_smoothing (float):
123-
num_classes (int):
124-
"""
125106
self.mixup_alpha = mixup_alpha
126107
self.cutmix_alpha = cutmix_alpha
127108
self.cutmix_minmax = cutmix_minmax
128109
if self.cutmix_minmax is not None:
129110
assert len(self.cutmix_minmax) == 2
130111
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
131112
self.cutmix_alpha = 1.0
132-
self.prob = prob
113+
self.mix_prob = prob
133114
self.switch_prob = switch_prob
134115
self.label_smoothing = label_smoothing
135116
self.num_classes = num_classes
136117
self.elementwise = elementwise
137118
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
138119
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
139120

140-
def _mix_elem(self, output, batch):
141-
batch_size = len(batch)
142-
lam_out = np.ones(batch_size, dtype=np.float32)
143-
use_cutmix = np.zeros(batch_size).astype(np.bool)
121+
def _params_per_elem(self, batch_size):
122+
lam = np.ones(batch_size, dtype=np.float32)
123+
use_cutmix = np.zeros(batch_size, dtype=np.bool)
144124
if self.mixup_enabled:
145125
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146126
use_cutmix = np.random.rand(batch_size) < self.switch_prob
@@ -151,35 +131,17 @@ def _mix_elem(self, output, batch):
151131
elif self.mixup_alpha > 0.:
152132
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
153133
elif self.cutmix_alpha > 0.:
154-
use_cutmix = np.ones(batch_size).astype(np.bool)
134+
use_cutmix = np.ones(batch_size, dtype=np.bool)
155135
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
156136
else:
157137
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
158-
lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out)
159-
160-
for i in range(batch_size):
161-
j = batch_size - i - 1
162-
lam = lam_out[i]
163-
mixed = batch[i][0]
164-
if lam != 1.:
165-
if use_cutmix[i]:
166-
mixed = mixed.copy()
167-
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
168-
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
169-
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
170-
lam_out[i] = lam
171-
else:
172-
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
173-
lam_out[i] = lam
174-
np.round(mixed, out=mixed)
175-
output[i] += torch.from_numpy(mixed.astype(np.uint8))
176-
return torch.tensor(lam_out).unsqueeze(1)
138+
lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139+
return lam, use_cutmix
177140

178-
def _mix_batch(self, output, batch):
179-
batch_size = len(batch)
141+
def _params_per_batch(self):
180142
lam = 1.
181143
use_cutmix = False
182-
if self.mixup_enabled and np.random.rand() < self.prob:
144+
if self.mixup_enabled and np.random.rand() < self.mix_prob:
183145
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
184146
use_cutmix = np.random.rand() < self.switch_prob
185147
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
@@ -192,34 +154,100 @@ def _mix_batch(self, output, batch):
192154
else:
193155
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
194156
lam = float(lam_mix)
157+
return lam, use_cutmix
195158

159+
def _mix_elem(self, x):
160+
batch_size = len(x)
161+
lam_batch, use_cutmix = self._params_per_elem(batch_size)
162+
x_orig = x.clone() # need to keep an unmodified original for mixing source
163+
for i in range(batch_size):
164+
j = batch_size - i - 1
165+
lam = lam_batch[i]
166+
if lam != 1.:
167+
if use_cutmix[i]:
168+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169+
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170+
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
171+
lam_batch[i] = lam
172+
else:
173+
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174+
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175+
176+
def _mix_batch(self, x):
177+
lam, use_cutmix = self._params_per_batch()
178+
if lam == 1.:
179+
return 1.
196180
if use_cutmix:
197181
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
198-
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
182+
x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
183+
x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
184+
else:
185+
x_flipped = x.flip(0).mul_(1. - lam)
186+
x.mul_(lam).add_(x_flipped)
187+
return lam
188+
189+
def __call__(self, x, target):
190+
assert len(x) % 2 == 0, 'Batch size should be even when using this'
191+
lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
192+
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
193+
return x, target
194+
195+
196+
class FastCollateMixup(Mixup):
197+
""" Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
199198
199+
A Mixup impl that's performed while collating the batches.
200+
"""
201+
202+
def _mix_elem_collate(self, output, batch):
203+
batch_size = len(batch)
204+
lam_batch, use_cutmix = self._params_per_elem(batch_size)
200205
for i in range(batch_size):
201206
j = batch_size - i - 1
207+
lam = lam_batch[i]
202208
mixed = batch[i][0]
203209
if lam != 1.:
204-
if use_cutmix:
210+
if use_cutmix[i]:
205211
mixed = mixed.copy()
212+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
213+
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
214+
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
215+
lam_batch[i] = lam
216+
else:
217+
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
218+
lam_batch[i] = lam
219+
np.round(mixed, out=mixed)
220+
output[i] += torch.from_numpy(mixed.astype(np.uint8))
221+
return torch.tensor(lam_batch).unsqueeze(1)
222+
223+
def _mix_batch_collate(self, output, batch):
224+
batch_size = len(batch)
225+
lam, use_cutmix = self._params_per_batch()
226+
if use_cutmix:
227+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
228+
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
229+
for i in range(batch_size):
230+
j = batch_size - i - 1
231+
mixed = batch[i][0]
232+
if lam != 1.:
233+
if use_cutmix:
234+
mixed = mixed.copy() # don't want to modify the original while iterating
206235
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
207236
else:
208237
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
209238
np.round(mixed, out=mixed)
210239
output[i] += torch.from_numpy(mixed.astype(np.uint8))
211240
return lam
212241

213-
def __call__(self, batch):
242+
def __call__(self, batch, _=None):
214243
batch_size = len(batch)
215244
assert batch_size % 2 == 0, 'Batch size should be even when using this'
216245
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
217246
if self.elementwise:
218-
lam = self._mix_elem(output, batch)
247+
lam = self._mix_elem_collate(output, batch)
219248
else:
220-
lam = self._mix_batch(output, batch)
249+
lam = self._mix_batch_collate(output, batch)
221250
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
222251
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
223-
224252
return output, target
225253

0 commit comments

Comments
 (0)