Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 851c141

Browse files
committed
add docstrings to methods, reorder netidxs
1 parent fa4c4ea commit 851c141

File tree

1 file changed

+145
-31
lines changed

1 file changed

+145
-31
lines changed

‎ImageNet/simplenet.py

Lines changed: 145 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,21 @@ def __init__(
107107
):
108108
"""Instantiates a SimpleNet model. SimpleNet is comprised of the most basic building blocks of a CNN architecture.
109109
It uses basic principles to maximize the network performance both in terms of feature representation and speed without
110-
resorting to complex design or operators.
110+
resorting to complex design or operators.
111111
112112
Args:
113113
num_classes (int, optional): number of classes. Defaults to 1000.
114114
in_chans (int, optional): number of input channels. Defaults to 3.
115115
scale (float, optional): scale of the architecture width. Defaults to 1.0.
116116
network_idx (int, optional): the network index indicating the 5 million or 8 million version(0 and 1 respectively). Defaults to 0.
117-
mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
118-
you can choose between 0 and 4. (note for imagenet use 1-4). Defaults to 2.
117+
mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
118+
This is used for larger input sizes such as the 224x224 in imagenet training where the
119+
input size incurs a lot of overhead if not downsampled properly.
120+
you can choose between 0 meaning no change and 4. where each number denotes a specific
121+
downsampling strategy. For imagenet use 1-4.
122+
the larger the stride mode, the higher accuracy and the slower
123+
the network gets. stride mode 1 is the fastest and achives very good accuracy.
124+
Defaults to 2.
119125
drop_rates (Dict[int,float], optional): custom drop out rates specified per layer.
120126
each rate should be paired with the corrosponding layer index(pooling and cnn layers are counted only). Defaults to {}.
121127
"""
@@ -251,12 +257,13 @@ def __init__(
251257
self.in_chans = in_chans
252258
self.scale = scale
253259
self.networks = [
254-
"simplenet_cifar_310k", # 0
255-
"simplenet_cifar_460k", # 1
256-
"simplenet_cifar_5m", # 2
257-
"simplenet_cifar_5m_extra_pool", # 3
258-
"simplenetv1_imagenet", # 4
259-
"simplenetv1_imagenet_9m", # 5
260+
"simplenetv1_imagenet", # 0
261+
"simplenetv1_imagenet_9m", # 1
262+
# other archs
263+
"simplenet_cifar_310k", # 2
264+
"simplenet_cifar_460k", # 3
265+
"simplenet_cifar_5m", # 4
266+
"simplenet_cifar_5m_extra_pool", # 5
260267
]
261268
self.network_idx = network_idx
262269
self.mode = mode
@@ -326,7 +333,7 @@ def _gen_simplenet(
326333
num_classes: int = 1000,
327334
in_chans: int = 3,
328335
scale: float = 1.0,
329-
network_idx: int = 4,
336+
network_idx: int = 0,
330337
mode: int = 2,
331338
pretrained: bool = False,
332339
drop_rates: Dict[int, float] = {},
@@ -349,19 +356,36 @@ def _gen_simplenet(
349356

350357

351358
def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
359+
"""Generic simplenet model builder. by default it returns `simplenetv1_5m_m2` model
360+
but specifying different arguments such as `netidx`, `scale` or `mode` will result in
361+
the corrosponding network variant.
362+
363+
when pretrained is specified, if the combination of settings resemble any known variants
364+
specified in the `default_cfg`, their respective pretrained weights will be loaded, otherwise
365+
an exception will be thrown denoting Unknown model variant being specified.
366+
367+
Args:
368+
pretrained (bool, optional): loads the model with pretrained weights only if the model is a known variant specified in default_cfg. Defaults to False.
369+
370+
Raises:
371+
Exception: if pretrained is used with an unknown/custom model variant and exception is raised.
372+
373+
Returns:
374+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
375+
"""
352376
num_classes = kwargs.get("num_classes", 1000)
353377
in_chans = kwargs.get("in_chans", 3)
354378
scale = kwargs.get("scale", 1.0)
355-
network_idx = kwargs.get("network_idx", 4)
379+
network_idx = kwargs.get("network_idx", 0)
356380
mode = kwargs.get("mode", 2)
357381
drop_rates = kwargs.get("drop_rates", {})
358-
model_variant = "simplenetv1"
382+
model_variant = "simplenetv1_5m_m2"
359383
if pretrained:
360384
# check if the model specified is a known variant
361385
model_base = None
362-
if network_idx == 4:
386+
if network_idx == 0:
363387
model_base = 5
364-
elif network_idx == 5:
388+
elif network_idx == 1:
365389
model_base = 9
366390
config = ""
367391
if math.isclose(scale, 1.0):
@@ -372,37 +396,46 @@ def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
372396
config = f"small_m{mode}_05"
373397
else:
374398
config = f"m{mode}_{scale:.2f}".replace(".", "")
375-
376-
if network_idx == 0:
377-
model_variant = f"simplenetv1_{config}"
378-
else:
379-
model_variant = f"simplenetv1_{config}"
399+
model_variant = f"simplenetv1_{config}"
380400

381401
return _gen_simplenet(model_variant, num_classes, in_chans, scale, network_idx, mode, pretrained, drop_rates)
382402

383403

404+
def remove_network_settings(kwargs: Dict[str, Any]) -> Dict[str, Any]:
405+
"""Removes network related settings passed in kwargs for predefined network configruations below
406+
407+
Returns:
408+
Dict[str,Any]: cleaned kwargs
409+
"""
410+
model_args = {k: v for k, v in kwargs.items() if k not in ["scale", "network_idx", "mode"]}
411+
return model_args
412+
413+
384414
# cifar10/100 models
385415
def simplenet_cifar_310k(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
386416
"""original implementation of smaller variants of simplenet for cifar10/100
387417
that were used in the paper
388418
"""
389419
model_variant = "simplenet_cifar_310k"
390-
return _gen_simplenet(model_variant, network_idx=0, mode=0, pretrained=pretrained, **kwargs)
420+
model_args = remove_network_settings(kwargs)
421+
return _gen_simplenet(model_variant, network_idx=2, mode=0, pretrained=pretrained, **model_args)
391422

392423

393424
def simplenet_cifar_460k(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
394425
"""original implementation of smaller variants of simplenet for cifar10/100
395426
that were used in the paper
396427
"""
397428
model_variant = "simplenet_cifar_460k"
398-
return _gen_simplenet(model_variant, network_idx=1, mode=0, pretrained=pretrained, **kwargs)
429+
model_args = remove_network_settings(kwargs)
430+
return _gen_simplenet(model_variant, network_idx=3, mode=0, pretrained=pretrained, **model_args)
399431

400432

401433
def simplenet_cifar_5m(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
402434
"""The original implementation of simplenet trained on cifar10/100 in caffe.
403435
"""
404436
model_variant = "simplenet_cifar_5m"
405-
return _gen_simplenet(model_variant, network_idx=2, mode=0, pretrained=pretrained, **kwargs)
437+
model_args = remove_network_settings(kwargs)
438+
return _gen_simplenet(model_variant, network_idx=4, mode=0, pretrained=pretrained, **model_args)
406439

407440

408441
def simplenet_cifar_5m_extra_pool(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
@@ -411,48 +444,129 @@ def simplenet_cifar_5m_extra_pool(pretrained: bool = False, **kwargs: Any) -> Si
411444
this is just here to be able to load the weights that were trained using this variation still available on the repository.
412445
"""
413446
model_variant = "simplenet_cifar_5m_extra_pool"
414-
return _gen_simplenet(model_variant, network_idx=3, mode=0, pretrained=pretrained, **kwargs)
447+
model_args = remove_network_settings(kwargs)
448+
return _gen_simplenet(model_variant, network_idx=5, mode=0, pretrained=pretrained, **model_args)
415449

416450

417451
# imagenet models
418452
def simplenetv1_small_m1_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
453+
"""Creates a small variant of simplenetv1_5m, with 1.5m parameters. This uses m1 stride mode
454+
which makes it the fastest variant available.
455+
456+
Args:
457+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
458+
459+
Returns:
460+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
461+
"""
419462
model_variant = "simplenetv1_small_m1_05"
420-
return _gen_simplenet(model_variant, scale=0.5, network_idx=4, mode=1, pretrained=pretrained, **kwargs)
463+
model_args = remove_network_settings(kwargs)
464+
return _gen_simplenet(model_variant, scale=0.5, network_idx=0, mode=1, pretrained=pretrained, **model_args)
421465

422466

423467
def simplenetv1_small_m2_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
468+
"""Creates a second small variant of simplenetv1_5m, with 1.5m parameters. This uses m2 stride mode
469+
which makes it the second fastest variant available.
470+
471+
Args:
472+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
473+
474+
Returns:
475+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
476+
"""
424477
model_variant = "simplenetv1_small_m2_05"
425-
return _gen_simplenet(model_variant, scale=0.5, network_idx=4, mode=2, pretrained=pretrained, **kwargs)
478+
model_args = remove_network_settings(kwargs)
479+
return _gen_simplenet(model_variant, scale=0.5, network_idx=0, mode=2, pretrained=pretrained, **model_args)
426480

427481

428482
def simplenetv1_small_m1_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
483+
"""Creates a third small variant of simplenetv1_5m, with 3m parameters. This uses m1 stride mode
484+
which makes it the third fastest variant available.
485+
486+
Args:
487+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
488+
489+
Returns:
490+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
491+
"""
429492
model_variant = "simplenetv1_small_m1_075"
430-
return _gen_simplenet(model_variant, scale=0.75, network_idx=4, mode=1, pretrained=pretrained, **kwargs)
493+
model_args = remove_network_settings(kwargs)
494+
return _gen_simplenet(model_variant, scale=0.75, network_idx=0, mode=1, pretrained=pretrained, **model_args)
431495

432496

433497
def simplenetv1_small_m2_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
498+
"""Creates a forth small variant of simplenetv1_5m, with 3m parameters. This uses m2 stride mode
499+
which makes it the forth fastest variant available.
500+
501+
Args:
502+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
503+
504+
Returns:
505+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
506+
"""
434507
model_variant = "simplenetv1_small_m2_075"
435-
return _gen_simplenet(model_variant, scale=0.75, network_idx=4, mode=2, pretrained=pretrained, **kwargs)
508+
model_args = remove_network_settings(kwargs)
509+
return _gen_simplenet(model_variant, scale=0.75, network_idx=0, mode=2, pretrained=pretrained, **model_args)
436510

437511

438512
def simplenetv1_5m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
513+
"""Creates the base simplement model known as simplenetv1_5m, with 5m parameters. This variant uses m1 stride mode
514+
which makes it a fast and performant model.
515+
516+
Args:
517+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
518+
519+
Returns:
520+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
521+
"""
439522
model_variant = "simplenetv1_5m_m1"
440-
return _gen_simplenet(model_variant, scale=1.0, network_idx=4, mode=1, pretrained=pretrained, **kwargs)
523+
model_args = remove_network_settings(kwargs)
524+
return _gen_simplenet(model_variant, scale=1.0, network_idx=0, mode=1, pretrained=pretrained, **model_args)
441525

442526

443527
def simplenetv1_5m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
528+
"""Creates the base simplement model known as simplenetv1_5m, with 5m parameters. This variant uses m2 stride mode
529+
which makes it a bit more performant model compared to the m1 variant of the same variant at the expense of a bit slower inference.
530+
531+
Args:
532+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
533+
534+
Returns:
535+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
536+
"""
444537
model_variant = "simplenetv1_5m_m2"
445-
return _gen_simplenet(model_variant, scale=1.0, network_idx=4, mode=2, pretrained=pretrained, **kwargs)
538+
model_args = remove_network_settings(kwargs)
539+
return _gen_simplenet(model_variant, scale=1.0, network_idx=0, mode=2, pretrained=pretrained, **model_args)
446540

447541

448542
def simplenetv1_9m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
543+
"""Creates a variant of the simplenetv1_5m, with 9m parameters. This variant uses m1 stride mode
544+
which makes it run faster.
545+
546+
Args:
547+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
548+
549+
Returns:
550+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
551+
"""
449552
model_variant = "simplenetv1_9m_m1"
450-
return _gen_simplenet(model_variant, scale=1.0, network_idx=5, mode=1, pretrained=pretrained, **kwargs)
553+
model_args = remove_network_settings(kwargs)
554+
return _gen_simplenet(model_variant, scale=1.0, network_idx=1, mode=1, pretrained=pretrained, **model_args)
451555

452556

453557
def simplenetv1_9m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
558+
"""Creates a variant of the simplenetv1_5m, with 9m parameters. This variant uses m2 stride mode
559+
which makes it a bit more performant model compared to the m1 variant of the same variant at the expense of a bit slower inference.
560+
561+
Args:
562+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
563+
564+
Returns:
565+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
566+
"""
454567
model_variant = "simplenetv1_9m_m2"
455-
return _gen_simplenet(model_variant, scale=1.0, network_idx=5, mode=2, pretrained=pretrained, **kwargs)
568+
model_args = remove_network_settings(kwargs)
569+
return _gen_simplenet(model_variant, scale=1.0, network_idx=1, mode=2, pretrained=pretrained, **model_args)
456570

457571

458572
if __name__ == "__main__":

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /