
Hacked together by / Copyright 2020 Ross Wightman
"""
-
import numpy as np
import torch
- import math
- import numbers


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
@@ -30,20 +27,21 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    return y1 * lam + y2 * (1. - lam)


- def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
-     lam = 1.
-     if not disable:
-         lam = np.random.beta(alpha, alpha)
-     input = input.mul(lam).add_(1 - lam, input.flip(0))
-     target = mixup_target(target, num_classes, lam, smoothing)
-     return input, target
-
+ def rand_bbox(img_shape, lam, margin=0., count=None):
+     """ Standard CutMix bounding-box
+     Generates a random square bbox based on lambda value. This impl includes
+     support for enforcing a border margin as percent of bbox dimensions.

- def rand_bbox(size, lam, border=0., count=None):
-     ratio = math.sqrt(1 - lam)
-     img_h, img_w = size[-2:]
+     Args:
+         img_shape (tuple): Image shape as tuple
+         lam (float): Cutmix lambda value
+         margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+         count (int): Number of bbox to generate
+     """
+     ratio = np.sqrt(1 - lam)
+     img_h, img_w = img_shape[-2:]
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
-     margin_y, margin_x = int(border * cut_h), int(border * cut_w)
+     margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
    yl = np.clip(cy - cut_h // 2, 0, img_h)
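
Worth a note on the geometry: lam is the fraction of the image that is kept, so the square patch must cover a 1 - lam fraction of the area, which is why each side is scaled by sqrt(1 - lam). A quick sanity check (values are illustrative, not from this commit):

import numpy as np

lam = 0.75                                             # keep 75% of the image
ratio = np.sqrt(1 - lam)                               # 0.5: patch side is half of each image dim
img_h = img_w = 224
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)  # 112 x 112
assert cut_h * cut_w / (img_h * img_w) == 0.25         # patch covers 1 - lam of the area
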
@@ -53,9 +51,20 @@ def rand_bbox(size, lam, border=0., count=None):
    return yl, yh, xl, xh


- def rand_bbox_minmax(size, minmax, count=None):
+ def rand_bbox_minmax(img_shape, minmax, count=None):
+     """ Min-Max CutMix bounding-box
+     Inspired by Darknet cutmix impl, generates a random rectangular bbox
+     based on min/max percent values applied to each dimension of the input image.
+
+     Typical defaults for minmax are usually in the .2-.3 range for min and .8-.9 range for max.
+
+     Args:
+         img_shape (tuple): Image shape as tuple
+         minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+         count (int): Number of bbox to generate
+     """
    assert len(minmax) == 2
-     img_h, img_w = size[-2:]
+     img_h, img_w = img_shape[-2:]
    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
    yl = np.random.randint(0, img_h - cut_h, size=count)
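
For illustration, a standalone call under assumed shapes (the yl, yu, xl, xu return order is taken from the cutmix_bbox_and_lam call site below); here the box drives lam rather than the other way around:

yl, yu, xl, xu = rand_bbox_minmax((3, 224, 224), (0.2, 0.8))
lam = 1. - (yu - yl) * (xu - xl) / float(224 * 224)  # implied keep fraction
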
@@ -66,6 +75,8 @@ def rand_bbox_minmax(size, minmax, count=None):


def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+     """ Generate bbox and apply lambda correction.
+     """
    if ratio_minmax is not None:
        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
    else:
@@ -76,71 +87,40 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
    return (yl, yu, xl, xu), lam
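
When correct_lam is set (and whenever the min/max variant is used, where lam is only defined by the realized box), the returned lam is recomputed from the clipped bbox so the target weights match the pixels actually replaced; the folded body presumably uses the same formula as the removed cutmix_batch below. A minimal sketch of that correction:

bbox_area = (yu - yl) * (xu - xl)
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
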


- def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
-     lam = 1.
-     if not disable:
-         lam = np.random.beta(alpha, alpha)
-     if lam != 1:
-         yl, yh, xl, xh = rand_bbox(input.size(), lam)
-         input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
-         if correct_lam:
-             lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1])
-     target = mixup_target(target, num_classes, lam, smoothing)
-     return input, target
-
-
- def mix_batch(
-         input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=.5,
-         num_classes=1000, smoothing=0.1, disable=False):
-     # FIXME test this version
-     if np.random.rand() > prob:
-         return input, target
-     use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob
-     if use_cutmix:
-         return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable)
-     else:
-         return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable)
-
-
- class FastCollateMixup:
-     """Fast Collate Mixup/Cutmix that applies different params to each element or whole batch
-
-     NOTE once experiments are done, one of the three variants will remain with this class name
+ class Mixup:
+     """ Mixup/Cutmix that applies different params to each element or whole batch

+     Args:
+         mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+         cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+         cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+         prob (float): probability of applying mixup or cutmix per batch or element
+         switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+         elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+         correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+         label_smoothing (float): apply label smoothing to the mixed target tensor
+         num_classes (int): number of classes for target
    """
    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
-         """
-
-         Args:
-             mixup_alpha (float): mixup alpha value, mixup is active if > 0.
-             cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
-             cutmix_minmax (float): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None
-             prob (float): probability of applying mixup or cutmix per batch or element
-             switch_prob (float): probability of using cutmix instead of mixup when both active
-             elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
-             label_smoothing (float):
-             num_classes (int):
-         """
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
            self.cutmix_alpha = 1.0
-         self.prob = prob
+         self.mix_prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.elementwise = elementwise
        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)

-     def _mix_elem(self, output, batch):
-         batch_size = len(batch)
-         lam_out = np.ones(batch_size, dtype=np.float32)
-         use_cutmix = np.zeros(batch_size).astype(np.bool)
+     def _params_per_elem(self, batch_size):
+         lam = np.ones(batch_size, dtype=np.float32)
+         use_cutmix = np.zeros(batch_size, dtype=np.bool)
        if self.mixup_enabled:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand(batch_size) < self.switch_prob
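
One detail on the per-element path: the switch between mixup and cutmix is now a vectorized Bernoulli draw per element, and the folded continuation presumably selects the matching Beta sample with np.where, mirroring the scalar branch in _params_per_batch. A hedged sketch of that selection (names taken from the surrounding code):

lam_mix = np.where(
    use_cutmix,
    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
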
@@ -151,35 +131,17 @@ def _mix_elem(self, output, batch):
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
            elif self.cutmix_alpha > 0.:
-                 use_cutmix = np.ones(batch_size).astype(np.bool)
+                 use_cutmix = np.ones(batch_size, dtype=np.bool)
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
-             lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out)
-
-         for i in range(batch_size):
-             j = batch_size - i - 1
-             lam = lam_out[i]
-             mixed = batch[i][0]
-             if lam != 1.:
-                 if use_cutmix[i]:
-                     mixed = mixed.copy()
-                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
-                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
-                     lam_out[i] = lam
-                 else:
-                     mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                     lam_out[i] = lam
-             np.round(mixed, out=mixed)
-             output[i] += torch.from_numpy(mixed.astype(np.uint8))
-         return torch.tensor(lam_out).unsqueeze(1)
+             lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+         return lam, use_cutmix

-     def _mix_batch(self, output, batch):
-         batch_size = len(batch)
+     def _params_per_batch(self):
        lam = 1.
        use_cutmix = False
-         if self.mixup_enabled and np.random.rand() < self.prob:
+         if self.mixup_enabled and np.random.rand() < self.mix_prob:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand() < self.switch_prob
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
@@ -192,34 +154,100 @@ def _mix_batch(self, output, batch):
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam = float(lam_mix)
+         return lam, use_cutmix

+     def _mix_elem(self, x):
+         batch_size = len(x)
+         lam_batch, use_cutmix = self._params_per_elem(batch_size)
+         x_orig = x.clone()  # need to keep an unmodified original for mixing source
+         for i in range(batch_size):
+             j = batch_size - i - 1
+             lam = lam_batch[i]
+             if lam != 1.:
+                 if use_cutmix[i]:
+                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                         x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                     x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                     lam_batch[i] = lam
+                 else:
+                     x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+         return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+     def _mix_batch(self, x):
+         lam, use_cutmix = self._params_per_batch()
+         if lam == 1.:
+             return 1.
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                 output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                 x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+             x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
+         else:
+             x_flipped = x.flip(0).mul_(1. - lam)
+             x.mul_(lam).add_(x_flipped)
+         return lam
+
+     def __call__(self, x, target):
+         assert len(x) % 2 == 0, 'Batch size should be even when using this'
+         lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+         target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+         return x, target
+
+
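
Usage of the new class is a one-liner per step; a hedged training-loop sketch (loader and model are placeholders, alpha values illustrative). Note the target comes back as a soft (B, num_classes) tensor, so a soft-target cross-entropy is needed:

import torch

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=1000)
for x, target in loader:             # x: (B, C, H, W) float tensor, B even; target: (B,) int64
    x, target = mixup_fn(x, target)  # target becomes soft (B, num_classes)
    logits = model(x)
    loss = -(target * torch.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
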
+ class FastCollateMixup(Mixup):
+     """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch

+     A Mixup impl that's performed while collating the batches.
+     """
+
+     def _mix_elem_collate(self, output, batch):
+         batch_size = len(batch)
+         lam_batch, use_cutmix = self._params_per_elem(batch_size)
        for i in range(batch_size):
            j = batch_size - i - 1
+             lam = lam_batch[i]
            mixed = batch[i][0]
            if lam != 1.:
-                 if use_cutmix:
+                 if use_cutmix[i]:
                    mixed = mixed.copy()
+                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                     lam_batch[i] = lam
+                 else:
+                     mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                     lam_batch[i] = lam
+             np.round(mixed, out=mixed)
+             output[i] += torch.from_numpy(mixed.astype(np.uint8))
+         return torch.tensor(lam_batch).unsqueeze(1)
+
+     def _mix_batch_collate(self, output, batch):
+         batch_size = len(batch)
+         lam, use_cutmix = self._params_per_batch()
+         if use_cutmix:
+             (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                 output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+         for i in range(batch_size):
+             j = batch_size - i - 1
+             mixed = batch[i][0]
+             if lam != 1.:
+                 if use_cutmix:
+                     mixed = mixed.copy()  # don't want to modify the original while iterating
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                else:
                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
            np.round(mixed, out=mixed)
            output[i] += torch.from_numpy(mixed.astype(np.uint8))
        return lam

-     def __call__(self, batch):
+     def __call__(self, batch, _=None):
        batch_size = len(batch)
        assert batch_size % 2 == 0, 'Batch size should be even when using this'
        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        if self.elementwise:
-             lam = self._mix_elem(output, batch)
+             lam = self._mix_elem_collate(output, batch)
        else:
-             lam = self._mix_batch(output, batch)
+             lam = self._mix_batch_collate(output, batch)
        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
-
        return output, target

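
The collate variant plugs into the DataLoader instead of the training loop: the dataset yields (uint8 CHW numpy image, int label) pairs and mixing happens before the float conversion. A hedged wiring sketch (dataset name and batch size are placeholders):

from torch.utils.data import DataLoader

collate_fn = FastCollateMixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=1000)
loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
for x, target in loader:  # x: uint8 (B, C, H, W); target: soft (B, num_classes) on cpu
    ...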