@@ -72,7 +72,7 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou
         yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
     if correct_lam or ratio_minmax is not None:
         bbox_area = (yu - yl) * (xu - xl)
-        lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1])
+        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
     return (yl, yu, xl, xu), lam


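The `float()` added above matters because `rand_bbox` clips the sampled box at the image borders, so the corrected lambda is recomputed as one minus the actual box area over the image area; the cast keeps that division from truncating when the shape entries are plain integers (Python 2 semantics, or integer numpy scalars). A minimal standalone sketch, with a hypothetical clipped box standing in for `rand_bbox` output:

    img_shape = (3, 224, 224)
    yl, yu, xl, xu = 0, 60, 0, 60                        # hypothetical clipped box
    bbox_area = (yu - yl) * (xu - xl)                    # 3600 pixels pasted from the other image
    lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
    print(lam)                                           # ~0.928: label weight of the base image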
@@ -84,7 +84,7 @@ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disa
     yl, yh, xl, xh = rand_bbox(input.size(), lam)
     input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
     if correct_lam:
-        lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
+        lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1])
     target = mixup_target(target, num_classes, lam, smoothing)
     return input, target

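For context, `cutmix_batch` mixes a batch against its own reversal: `input.flip(0)` pairs sample `i` with sample `batch_size - 1 - i`, one box is pasted across the whole batch, and `mixup_target` builds the smoothed mixed targets from the (optionally corrected) `lam`. A hedged usage sketch; the tensor shapes and class count here are assumptions, not part of the patch:

    import torch
    images = torch.rand(8, 3, 224, 224)                  # assumed batch shape
    labels = torch.randint(0, 1000, (8,))                # assumed label range
    images, targets = cutmix_batch(images, labels, alpha=0.2, num_classes=1000)
    # images[i] now carries a box cut from images[7 - i]; targets blends the
    # two smoothed one-hot labels by the corrected lam.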
@@ -139,7 +139,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0

     def _mix_elem(self, output, batch):
         batch_size = len(batch)
-        lam_out = np.ones(batch_size)
+        lam_out = np.ones(batch_size, dtype=np.float32)
         use_cutmix = np.zeros(batch_size).astype(np.bool)
         if self.mixup_enabled:
             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
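Giving `lam_out` an explicit `float32` dtype (together with the matching `lam_mix.astype(np.float32)` in the next hunk) keeps the per-sample lambdas `float32` end to end, so the column of lambdas returned at the bottom of `_mix_elem` is `torch.float32` rather than numpy's default `float64`. A standalone sketch of the effect:

    import numpy as np
    import torch
    lam_out = np.ones(4, dtype=np.float32)
    lams = torch.tensor(lam_out).unsqueeze(1)
    print(lams.dtype, tuple(lams.shape))                 # torch.float32 (4, 1)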
@@ -155,22 +155,23 @@ def _mix_elem(self, output, batch):
                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
             else:
                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
-            lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out)
+            lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out)

         for i in range(batch_size):
             j = batch_size - i - 1
             lam = lam_out[i]
-            mixed = batch[i][0].astype(np.float32)
+            mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix[i]:
+                    mixed = mixed.copy()
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
-                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_out[i] = lam
                 else:
-                    mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                     lam_out[i] = lam
-                np.round(mixed, out=mixed)
+                    np.round(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
         return torch.tensor(lam_out).unsqueeze(1)

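The structural change in `_mix_elem` is that `mixed` now stays the loader's uint8 array instead of being cast to `float32` up front. That makes the `mixed.copy()` essential on the cutmix path: `batch[i][0]` is the original sample, and pasting a region in place would otherwise corrupt it. Only the mixup branch converts to `float32`, blends, and rounds, which is also why `np.round` moves inside that branch; the cutmix result is already integral. A toy demonstration of the aliasing hazard, on hypothetical arrays:

    import numpy as np
    orig = np.zeros((3, 4, 4), dtype=np.uint8)           # stands in for batch[i][0]
    alias = orig                                         # what `mixed = batch[i][0]` binds
    alias[:, 0:2, 0:2] = 255                             # in-place paste writes through
    assert orig.max() == 255                             # the source sample was modified
    safe = orig.copy()                                   # the fixed path pastes into a copy
    safe[:, 2:4, 2:4] = 128
    assert orig[:, 2:4, 2:4].max() == 0                  # original left intact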
@@ -190,21 +191,22 @@ def _mix_batch(self, output, batch):
                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
             else:
                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
-            lam = lam_mix
+            lam = float(lam_mix)

         if use_cutmix:
             (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                 output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)

         for i in range(batch_size):
             j = batch_size - i - 1
-            mixed = batch[i][0].astype(np.float32)
+            mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix:
-                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                    mixed = mixed.copy()
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                 else:
-                    mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                np.round(mixed, out=mixed)
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    np.round(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
         return lam

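`_mix_batch` gets the same treatment with a single `lam` shared by the whole batch, and `float(lam_mix)` normalizes the sampled value to a plain Python float before it is returned. In the mixup branch both operands are now cast explicitly; without the casts, uint8-times-float arithmetic would promote to `float64` instead of `float32`. The branch on toy values:

    import numpy as np
    a = np.array([200, 10], dtype=np.uint8)              # stands in for batch[i][0]
    b = np.array([100, 30], dtype=np.uint8)              # stands in for batch[j][0]
    lam = 0.7
    mixed = a.astype(np.float32) * lam + b.astype(np.float32) * (1 - lam)
    np.round(mixed, out=mixed)                           # in-place round on the float buffer
    print(mixed.astype(np.uint8))                        # [170  16]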