Skip to content

Commit cd23f55

Browse files
committed
Fix mixed prec issues with new mixup code
1 parent f471c17 commit cd23f55

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

timm/data/mixup.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou
7272
yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
7373
if correct_lam or ratio_minmax is not None:
7474
bbox_area = (yu - yl) * (xu - xl)
75-
lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1])
75+
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
7676
return (yl, yu, xl, xu), lam
7777

7878

@@ -84,7 +84,7 @@ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disa
8484
yl, yh, xl, xh = rand_bbox(input.size(), lam)
8585
input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
8686
if correct_lam:
87-
lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
87+
lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1])
8888
target = mixup_target(target, num_classes, lam, smoothing)
8989
return input, target
9090

@@ -139,7 +139,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0
139139

140140
def _mix_elem(self, output, batch):
141141
batch_size = len(batch)
142-
lam_out = np.ones(batch_size)
142+
lam_out = np.ones(batch_size, dtype=np.float32)
143143
use_cutmix = np.zeros(batch_size).astype(np.bool)
144144
if self.mixup_enabled:
145145
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
@@ -155,22 +155,23 @@ def _mix_elem(self, output, batch):
155155
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
156156
else:
157157
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
158-
lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out)
158+
lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out)
159159

160160
for i in range(batch_size):
161161
j = batch_size - i - 1
162162
lam = lam_out[i]
163-
mixed = batch[i][0].astype(np.float32)
163+
mixed = batch[i][0]
164164
if lam != 1.:
165165
if use_cutmix[i]:
166+
mixed = mixed.copy()
166167
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
167168
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
168-
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
169+
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
169170
lam_out[i] = lam
170171
else:
171-
mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
172+
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
172173
lam_out[i] = lam
173-
np.round(mixed, out=mixed)
174+
np.round(mixed, out=mixed)
174175
output[i] += torch.from_numpy(mixed.astype(np.uint8))
175176
return torch.tensor(lam_out).unsqueeze(1)
176177

@@ -190,21 +191,22 @@ def _mix_batch(self, output, batch):
190191
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
191192
else:
192193
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
193-
lam = lam_mix
194+
lam = float(lam_mix)
194195

195196
if use_cutmix:
196197
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
197198
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
198199

199200
for i in range(batch_size):
200201
j = batch_size - i - 1
201-
mixed = batch[i][0].astype(np.float32)
202+
mixed = batch[i][0]
202203
if lam != 1.:
203204
if use_cutmix:
204-
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
205+
mixed = mixed.copy()
206+
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
205207
else:
206-
mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
207-
np.round(mixed, out=mixed)
208+
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
209+
np.round(mixed, out=mixed)
208210
output[i] += torch.from_numpy(mixed.astype(np.uint8))
209211
return lam
210212

0 commit comments

Comments
 (0)