@@ -1,5 +1,12 @@
-""" Mixup
-Paper: `mixup: Beyond Empirical Risk Minimization` - https://arxiv.org/abs/1710.09412
+""" Mixup and Cutmix
+
+Papers:
+mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+Code Reference:
+CutMix: https://github.com/clovaai/CutMix-PyTorch

Hacked together by / Copyright 2020 Ross Wightman
"""
@@ -17,40 +24,230 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
-    return lam * y1 + (1. - lam) * y2
+    return y1 * lam + y2 * (1. - lam)
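+
+# Note: mixup_target mixes each sample's smoothed one-hot target with that of the
+# reverse-ordered batch (target.flip(0)). Worked example, assuming smoothing=0. gives
+# on_value=1. and off_value=0.: with num_classes=3, lam=0.7 and targets [0, 2], the first
+# sample's soft target is 0.7 * [1, 0, 0] + 0.3 * [0, 0, 1] = [0.7, 0.0, 0.3].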
+
+
+def rand_bbox(img_shape, lam, margin=0., count=None):
+    """ Standard CutMix bounding-box
+    Generates a random square bbox based on lambda value. This impl includes
+    support for enforcing a border margin as percent of bbox dimensions.
+
+    Args:
+        img_shape (tuple): Image shape as tuple
+        lam (float): Cutmix lambda value
+        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+        count (int): Number of bbox to generate
+    """
+    ratio = np.sqrt(1 - lam)
+    img_h, img_w = img_shape[-2:]
+    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+    yl = np.clip(cy - cut_h // 2, 0, img_h)
+    yh = np.clip(cy + cut_h // 2, 0, img_h)
+    xl = np.clip(cx - cut_w // 2, 0, img_w)
+    xh = np.clip(cx + cut_w // 2, 0, img_w)
+    return yl, yh, xl, xh
+
+
+def rand_bbox_minmax(img_shape, minmax, count=None):
+    """ Min-Max CutMix bounding-box
+    Inspired by Darknet cutmix impl, generates a random rectangular bbox
+    based on min/max percent values applied to each dimension of the input image.

+    Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.

-def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
-    lam = 1.
-    if not disable:
-        lam = np.random.beta(alpha, alpha)
-    input = input.mul(lam).add_(1 - lam, input.flip(0))
-    target = mixup_target(target, num_classes, lam, smoothing)
-    return input, target
+    Args:
+        img_shape (tuple): Image shape as tuple
+        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+        count (int): Number of bbox to generate
+    """
+    assert len(minmax) == 2
+    img_h, img_w = img_shape[-2:]
+    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+    yl = np.random.randint(0, img_h - cut_h, size=count)
+    xl = np.random.randint(0, img_w - cut_w, size=count)
+    yu = yl + cut_h
+    xu = xl + cut_w
+    return yl, yu, xl, xu


-class FastCollateMixup:
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+    """ Generate bbox and apply lambda correction.
+    """
+    if ratio_minmax is not None:
+        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+    else:
+        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+    if correct_lam or ratio_minmax is not None:
+        bbox_area = (yu - yl) * (xu - xl)
+        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
+    return (yl, yu, xl, xu), lam

-    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
+
+class Mixup:
+    """ Mixup/Cutmix that applies different params to each element or whole batch
+
+    Args:
+        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+        prob (float): probability of applying mixup or cutmix per batch or element
+        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+        elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+        label_smoothing (float): apply label smoothing to the mixed target tensor
+        num_classes (int): number of classes for target
+    """
+    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
        self.mixup_alpha = mixup_alpha
+        self.cutmix_alpha = cutmix_alpha
+        self.cutmix_minmax = cutmix_minmax
+        if self.cutmix_minmax is not None:
+            assert len(self.cutmix_minmax) == 2
+            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+            self.cutmix_alpha = 1.0
+        self.mix_prob = prob
+        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
-        self.mixup_enabled = True
+        self.elementwise = elementwise
+        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
+        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)

-    def __call__(self, batch):
-        batch_size = len(batch)
-        lam = 1.
+    def _params_per_elem(self, batch_size):
+        lam = np.ones(batch_size, dtype=np.float32)
+        use_cutmix = np.zeros(batch_size, dtype=np.bool)
        if self.mixup_enabled:
-            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand(batch_size) < self.switch_prob
+                lam_mix = np.where(
+                    use_cutmix,
+                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = np.ones(batch_size, dtype=np.bool)
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+        return lam, use_cutmix

-        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+    def _params_per_batch(self):
+        lam = 1.
+        use_cutmix = False
+        if self.mixup_enabled and np.random.rand() < self.mix_prob:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand() < self.switch_prob
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = True
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = float(lam_mix)
+        return lam, use_cutmix

-        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+    def _mix_elem(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
        for i in range(batch_size):
-            mixed = batch[i][0].astype(np.float32) * lam + \
-                batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
-            np.round(mixed, out=mixed)
-            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+    def _mix_batch(self, x):
+        lam, use_cutmix = self._params_per_batch()
+        if lam == 1.:
+            return 1.
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
+        else:
+            x_flipped = x.flip(0).mul_(1. - lam)
+            x.mul_(lam).add_(x_flipped)
+        return lam
+
+    def __call__(self, x, target):
+        assert len(x) % 2 == 0, 'Batch size should be even when using this'
+        lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+        return x, target
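+
+# Usage sketch: construct once, then apply to each batch inside the training loop, e.g.
+#   mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=1000)
+#   x, target = mixup_fn(x, target)
+# The returned target is a dense (mixed, optionally smoothed) label tensor, so the training
+# loss must accept soft targets rather than hard class indices.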
+
+
+class FastCollateMixup(Mixup):
+    """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
+
+    A Mixup impl that's performed while collating the batches.
+    """
+
+    def _mix_elem_collate(self, output, batch):
+        batch_size = len(batch)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    mixed = mixed.copy()
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    lam_batch[i] = lam
+                    np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
+    def _mix_batch_collate(self, output, batch):
+        batch_size = len(batch)
+        lam, use_cutmix = self._params_per_batch()
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix:
+                    mixed = mixed.copy()  # don't want to modify the original while iterating
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return lam
+
+    def __call__(self, batch, _=None):
+        batch_size = len(batch)
+        assert batch_size % 2 == 0, 'Batch size should be even when using this'
+        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        if self.elementwise:
+            lam = self._mix_elem_collate(output, batch)
+        else:
+            lam = self._mix_batch_collate(output, batch)
+        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        return output, target
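+
+# Usage sketch: FastCollateMixup is intended to be passed as a DataLoader collate_fn so the
+# mixing happens on the uint8 numpy images while the batch tensor is assembled, e.g.
+#   collate_fn = FastCollateMixup(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=1000)
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=128, collate_fn=collate_fn)
+# assuming the dataset yields (CHW uint8 numpy image, int label) pairs as the methods above expect.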

-        return tensor, target