Skip to content

Commit 851c141

Browse files
committed
add docstrings to methods, reorder netidxs
1 parent fa4c4ea commit 851c141

File tree

1 file changed

+145
-31
lines changed

1 file changed

+145
-31
lines changed

ImageNet/simplenet.py

Lines changed: 145 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,21 @@ def __init__(
107107
):
108108
"""Instantiates a SimpleNet model. SimpleNet is comprised of the most basic building blocks of a CNN architecture.
109109
It uses basic principles to maximize the network performance both in terms of feature representation and speed without
110-
resorting to complex design or operators.
110+
resorting to complex design or operators.
111111
112112
Args:
113113
num_classes (int, optional): number of classes. Defaults to 1000.
114114
in_chans (int, optional): number of input channels. Defaults to 3.
115115
scale (float, optional): scale of the architecture width. Defaults to 1.0.
116116
network_idx (int, optional): the network index indicating the 5 million or 8 million version(0 and 1 respectively). Defaults to 0.
117-
mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
118-
you can choose between 0 and 4. (note for imagenet use 1-4). Defaults to 2.
117+
mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
118+
This is used for larger input sizes such as the 224x224 in imagenet training where the
119+
input size incurs a lot of overhead if not downsampled properly.
120+
you can choose between 0 meaning no change and 4. where each number denotes a specific
121+
downsampling strategy. For imagenet use 1-4.
122+
the larger the stride mode, the higher accuracy and the slower
123+
the network gets. stride mode 1 is the fastest and achives very good accuracy.
124+
Defaults to 2.
119125
drop_rates (Dict[int,float], optional): custom drop out rates specified per layer.
120126
each rate should be paired with the corrosponding layer index(pooling and cnn layers are counted only). Defaults to {}.
121127
"""
@@ -251,12 +257,13 @@ def __init__(
251257
self.in_chans = in_chans
252258
self.scale = scale
253259
self.networks = [
254-
"simplenet_cifar_310k", # 0
255-
"simplenet_cifar_460k", # 1
256-
"simplenet_cifar_5m", # 2
257-
"simplenet_cifar_5m_extra_pool", # 3
258-
"simplenetv1_imagenet", # 4
259-
"simplenetv1_imagenet_9m", # 5
260+
"simplenetv1_imagenet", # 0
261+
"simplenetv1_imagenet_9m", # 1
262+
# other archs
263+
"simplenet_cifar_310k", # 2
264+
"simplenet_cifar_460k", # 3
265+
"simplenet_cifar_5m", # 4
266+
"simplenet_cifar_5m_extra_pool", # 5
260267
]
261268
self.network_idx = network_idx
262269
self.mode = mode
@@ -326,7 +333,7 @@ def _gen_simplenet(
326333
num_classes: int = 1000,
327334
in_chans: int = 3,
328335
scale: float = 1.0,
329-
network_idx: int = 4,
336+
network_idx: int = 0,
330337
mode: int = 2,
331338
pretrained: bool = False,
332339
drop_rates: Dict[int, float] = {},
@@ -349,19 +356,36 @@ def _gen_simplenet(
349356

350357

351358
def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
359+
"""Generic simplenet model builder. by default it returns `simplenetv1_5m_m2` model
360+
but specifying different arguments such as `netidx`, `scale` or `mode` will result in
361+
the corrosponding network variant.
362+
363+
when pretrained is specified, if the combination of settings resemble any known variants
364+
specified in the `default_cfg`, their respective pretrained weights will be loaded, otherwise
365+
an exception will be thrown denoting Unknown model variant being specified.
366+
367+
Args:
368+
pretrained (bool, optional): loads the model with pretrained weights only if the model is a known variant specified in default_cfg. Defaults to False.
369+
370+
Raises:
371+
Exception: if pretrained is used with an unknown/custom model variant and exception is raised.
372+
373+
Returns:
374+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
375+
"""
352376
num_classes = kwargs.get("num_classes", 1000)
353377
in_chans = kwargs.get("in_chans", 3)
354378
scale = kwargs.get("scale", 1.0)
355-
network_idx = kwargs.get("network_idx", 4)
379+
network_idx = kwargs.get("network_idx", 0)
356380
mode = kwargs.get("mode", 2)
357381
drop_rates = kwargs.get("drop_rates", {})
358-
model_variant = "simplenetv1"
382+
model_variant = "simplenetv1_5m_m2"
359383
if pretrained:
360384
# check if the model specified is a known variant
361385
model_base = None
362-
if network_idx == 4:
386+
if network_idx == 0:
363387
model_base = 5
364-
elif network_idx == 5:
388+
elif network_idx == 1:
365389
model_base = 9
366390
config = ""
367391
if math.isclose(scale, 1.0):
@@ -372,37 +396,46 @@ def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
372396
config = f"small_m{mode}_05"
373397
else:
374398
config = f"m{mode}_{scale:.2f}".replace(".", "")
375-
376-
if network_idx == 0:
377-
model_variant = f"simplenetv1_{config}"
378-
else:
379-
model_variant = f"simplenetv1_{config}"
399+
model_variant = f"simplenetv1_{config}"
380400

381401
return _gen_simplenet(model_variant, num_classes, in_chans, scale, network_idx, mode, pretrained, drop_rates)
382402

383403

404+
def remove_network_settings(kwargs: Dict[str, Any]) -> Dict[str, Any]:
405+
"""Removes network related settings passed in kwargs for predefined network configruations below
406+
407+
Returns:
408+
Dict[str,Any]: cleaned kwargs
409+
"""
410+
model_args = {k: v for k, v in kwargs.items() if k not in ["scale", "network_idx", "mode"]}
411+
return model_args
412+
413+
384414
# cifar10/100 models
385415
def simplenet_cifar_310k(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
386416
"""original implementation of smaller variants of simplenet for cifar10/100
387417
that were used in the paper
388418
"""
389419
model_variant = "simplenet_cifar_310k"
390-
return _gen_simplenet(model_variant, network_idx=0, mode=0, pretrained=pretrained, **kwargs)
420+
model_args = remove_network_settings(kwargs)
421+
return _gen_simplenet(model_variant, network_idx=2, mode=0, pretrained=pretrained, **model_args)
391422

392423

393424
def simplenet_cifar_460k(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
394425
"""original implementation of smaller variants of simplenet for cifar10/100
395426
that were used in the paper
396427
"""
397428
model_variant = "simplenet_cifar_460k"
398-
return _gen_simplenet(model_variant, network_idx=1, mode=0, pretrained=pretrained, **kwargs)
429+
model_args = remove_network_settings(kwargs)
430+
return _gen_simplenet(model_variant, network_idx=3, mode=0, pretrained=pretrained, **model_args)
399431

400432

401433
def simplenet_cifar_5m(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
402434
"""The original implementation of simplenet trained on cifar10/100 in caffe.
403435
"""
404436
model_variant = "simplenet_cifar_5m"
405-
return _gen_simplenet(model_variant, network_idx=2, mode=0, pretrained=pretrained, **kwargs)
437+
model_args = remove_network_settings(kwargs)
438+
return _gen_simplenet(model_variant, network_idx=4, mode=0, pretrained=pretrained, **model_args)
406439

407440

408441
def simplenet_cifar_5m_extra_pool(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
@@ -411,48 +444,129 @@ def simplenet_cifar_5m_extra_pool(pretrained: bool = False, **kwargs: Any) -> Si
411444
this is just here to be able to load the weights that were trained using this variation still available on the repository.
412445
"""
413446
model_variant = "simplenet_cifar_5m_extra_pool"
414-
return _gen_simplenet(model_variant, network_idx=3, mode=0, pretrained=pretrained, **kwargs)
447+
model_args = remove_network_settings(kwargs)
448+
return _gen_simplenet(model_variant, network_idx=5, mode=0, pretrained=pretrained, **model_args)
415449

416450

417451
# imagenet models
418452
def simplenetv1_small_m1_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
453+
"""Creates a small variant of simplenetv1_5m, with 1.5m parameters. This uses m1 stride mode
454+
which makes it the fastest variant available.
455+
456+
Args:
457+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
458+
459+
Returns:
460+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
461+
"""
419462
model_variant = "simplenetv1_small_m1_05"
420-
return _gen_simplenet(model_variant, scale=0.5, network_idx=4, mode=1, pretrained=pretrained, **kwargs)
463+
model_args = remove_network_settings(kwargs)
464+
return _gen_simplenet(model_variant, scale=0.5, network_idx=0, mode=1, pretrained=pretrained, **model_args)
421465

422466

423467
def simplenetv1_small_m2_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
468+
"""Creates a second small variant of simplenetv1_5m, with 1.5m parameters. This uses m2 stride mode
469+
which makes it the second fastest variant available.
470+
471+
Args:
472+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
473+
474+
Returns:
475+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
476+
"""
424477
model_variant = "simplenetv1_small_m2_05"
425-
return _gen_simplenet(model_variant, scale=0.5, network_idx=4, mode=2, pretrained=pretrained, **kwargs)
478+
model_args = remove_network_settings(kwargs)
479+
return _gen_simplenet(model_variant, scale=0.5, network_idx=0, mode=2, pretrained=pretrained, **model_args)
426480

427481

428482
def simplenetv1_small_m1_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
483+
"""Creates a third small variant of simplenetv1_5m, with 3m parameters. This uses m1 stride mode
484+
which makes it the third fastest variant available.
485+
486+
Args:
487+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
488+
489+
Returns:
490+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
491+
"""
429492
model_variant = "simplenetv1_small_m1_075"
430-
return _gen_simplenet(model_variant, scale=0.75, network_idx=4, mode=1, pretrained=pretrained, **kwargs)
493+
model_args = remove_network_settings(kwargs)
494+
return _gen_simplenet(model_variant, scale=0.75, network_idx=0, mode=1, pretrained=pretrained, **model_args)
431495

432496

433497
def simplenetv1_small_m2_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
498+
"""Creates a forth small variant of simplenetv1_5m, with 3m parameters. This uses m2 stride mode
499+
which makes it the forth fastest variant available.
500+
501+
Args:
502+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
503+
504+
Returns:
505+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
506+
"""
434507
model_variant = "simplenetv1_small_m2_075"
435-
return _gen_simplenet(model_variant, scale=0.75, network_idx=4, mode=2, pretrained=pretrained, **kwargs)
508+
model_args = remove_network_settings(kwargs)
509+
return _gen_simplenet(model_variant, scale=0.75, network_idx=0, mode=2, pretrained=pretrained, **model_args)
436510

437511

438512
def simplenetv1_5m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
513+
"""Creates the base simplement model known as simplenetv1_5m, with 5m parameters. This variant uses m1 stride mode
514+
which makes it a fast and performant model.
515+
516+
Args:
517+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
518+
519+
Returns:
520+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
521+
"""
439522
model_variant = "simplenetv1_5m_m1"
440-
return _gen_simplenet(model_variant, scale=1.0, network_idx=4, mode=1, pretrained=pretrained, **kwargs)
523+
model_args = remove_network_settings(kwargs)
524+
return _gen_simplenet(model_variant, scale=1.0, network_idx=0, mode=1, pretrained=pretrained, **model_args)
441525

442526

443527
def simplenetv1_5m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
528+
"""Creates the base simplement model known as simplenetv1_5m, with 5m parameters. This variant uses m2 stride mode
529+
which makes it a bit more performant model compared to the m1 variant of the same variant at the expense of a bit slower inference.
530+
531+
Args:
532+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
533+
534+
Returns:
535+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
536+
"""
444537
model_variant = "simplenetv1_5m_m2"
445-
return _gen_simplenet(model_variant, scale=1.0, network_idx=4, mode=2, pretrained=pretrained, **kwargs)
538+
model_args = remove_network_settings(kwargs)
539+
return _gen_simplenet(model_variant, scale=1.0, network_idx=0, mode=2, pretrained=pretrained, **model_args)
446540

447541

448542
def simplenetv1_9m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
543+
"""Creates a variant of the simplenetv1_5m, with 9m parameters. This variant uses m1 stride mode
544+
which makes it run faster.
545+
546+
Args:
547+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
548+
549+
Returns:
550+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
551+
"""
449552
model_variant = "simplenetv1_9m_m1"
450-
return _gen_simplenet(model_variant, scale=1.0, network_idx=5, mode=1, pretrained=pretrained, **kwargs)
553+
model_args = remove_network_settings(kwargs)
554+
return _gen_simplenet(model_variant, scale=1.0, network_idx=1, mode=1, pretrained=pretrained, **model_args)
451555

452556

453557
def simplenetv1_9m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
558+
"""Creates a variant of the simplenetv1_5m, with 9m parameters. This variant uses m2 stride mode
559+
which makes it a bit more performant model compared to the m1 variant of the same variant at the expense of a bit slower inference.
560+
561+
Args:
562+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
563+
564+
Returns:
565+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
566+
"""
454567
model_variant = "simplenetv1_9m_m2"
455-
return _gen_simplenet(model_variant, scale=1.0, network_idx=5, mode=2, pretrained=pretrained, **kwargs)
568+
model_args = remove_network_settings(kwargs)
569+
return _gen_simplenet(model_variant, scale=1.0, network_idx=1, mode=2, pretrained=pretrained, **model_args)
456570

457571

458572
if __name__ == "__main__":

0 commit comments

Comments
 (0)