|
51 | 51 |
|
52 | 52 |
|
# Command-line interface for the ImageNet training script.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset / Model parameters
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
|
|
82 | 83 | help='input batch size for training (default: 32)')
|
# Validation batches may be a multiple of the training batch size.
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')
|
85 |
| -parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', |
86 |
| - help='Dropout rate (default: 0.)') |
87 |
| -parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', |
88 |
| - help='Drop connect rate, DEPRECATED, use drop-path (default: None)') |
89 |
| -parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', |
90 |
| - help='Drop path rate (default: None)') |
91 |
| -parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', |
92 |
| - help='Drop block rate (default: None)') |
93 |
| -parser.add_argument('--jsd', action='store_true', default=False, |
94 |
| - help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') |

# Optimizer parameters
# FIX: help string previously read '(default: "sgd"' — closing paren was missing.
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
|
|
101 | 93 | help='SGD momentum (default: 0.9)')
|
# L2 regularization strength applied by the optimizer.
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
|

# Learning rate schedule parameters
# FIX: help string previously read '(default: "step"' — closing paren was missing.
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "step")')
|
|
134 | 127 | help='patience epochs for Plateau LR scheduler (default: 10')
|
# Multiplicative LR decay factor used by step-style schedulers.
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
|
137 |
| -# Augmentation parameters |

# Augmentation & regularization parameters
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
# FIX: removed a stray trailing comma after this call, which made the
# statement a discarded 1-tuple (harmless at runtime, but a typo).
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                    help='Random erase prob (default: 0.)')
146 | 152 | parser.add_argument('--remode', type=str, default='const',
|
|
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
# Mixup / label smoothing regularization.
parser.add_argument('--mixup', type=float, default=0.0,
                    help='Mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
# Dropout-family regularizers passed through to model construction.
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')

# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
|
|
170 | 185 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
|
# One BN layer per augmentation split (pairs with --aug-splits).
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')
|

# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                    help='decay factor for model weights moving average (default: 0.9998)')
|

# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
|
@@ -378,20 +395,28 @@ def main():
|
378 | 395 | if num_aug_splits > 1:
|
379 | 396 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
|
380 | 397 |
|
| 398 | + train_interpolation = args.train_interpolation |
| 399 | + if args.no_aug or not train_interpolation: |
| 400 | + train_interpolation = data_config['interpolation'] |
381 | 401 | loader_train = create_loader(
|
382 | 402 | dataset_train,
|
383 | 403 | input_size=data_config['input_size'],
|
384 | 404 | batch_size=args.batch_size,
|
385 | 405 | is_training=True,
|
386 | 406 | use_prefetcher=args.prefetcher,
|
| 407 | + no_aug=args.no_aug, |
387 | 408 | re_prob=args.reprob,
|
388 | 409 | re_mode=args.remode,
|
389 | 410 | re_count=args.recount,
|
390 | 411 | re_split=args.resplit,
|
| 412 | + scale=args.scale, |
| 413 | + ratio=args.ratio, |
| 414 | + hflip=args.hflip, |
| 415 | + vflip=args.vflip, |
391 | 416 | color_jitter=args.color_jitter,
|
392 | 417 | auto_augment=args.aa,
|
393 | 418 | num_aug_splits=num_aug_splits,
|
394 |
| - interpolation=args.train_interpolation, |
| 419 | + interpolation=train_interpolation, |
395 | 420 | mean=data_config['mean'],
|
396 | 421 | std=data_config['std'],
|
397 | 422 | num_workers=args.workers,
|
|
0 commit comments