Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d0b5217

Browse files
committed
update Simplenet constructors docstring and remove commented codes
1 parent 851c141 commit d0b5217

File tree

1 file changed

+112
-38
lines changed
  • ImageNet/training_scripts/imagenet_training/timm/models

1 file changed

+112
-38
lines changed

‎ImageNet/training_scripts/imagenet_training/timm/models/simplenet.py

Lines changed: 112 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,12 @@
1515
Official Pythorch impl at https://github.com/Coderx7/SimpleNet_Pytorch
1616
Seyyed Hossein Hasanpour
1717
"""
18-
# import os
1918
import math
2019

2120
import torch
2221
import torch.nn as nn
2322
import torch.nn.functional as F
2423

25-
# from torch.hub import download_url_to_file
26-
2724
from typing import Union, Tuple, List, Dict, Any, cast, Optional
2825

2926
from .helpers import build_model_with_cfg
@@ -79,10 +76,10 @@ def _cfg(url="", **kwargs):
7976
"simplenetv1_5m_m2": _cfg(
8077
url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0-alpha/simv1_5m_m2-324ba7cc.pth"
8178
),
82-
"simplenetv1_m1_9m": _cfg(
79+
"simplenetv1_9m_m1": _cfg(
8380
url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0-alpha/simv1_9m_m1-00000000.pth"
8481
),
85-
"simplenetv1_m2_9m": _cfg(
82+
"simplenetv1_9m_m2": _cfg(
8683
url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0-alpha/simv1_9m_m2-00000000.pth"
8784
),
8885
}
@@ -106,15 +103,21 @@ def __init__(
106103
):
107104
"""Instantiates a SimpleNet model. SimpleNet is comprised of the most basic building blocks of a CNN architecture.
108105
It uses basic principles to maximize the network performance both in terms of feature representation and speed without
109-
resorting to complex design or operators.
106+
resorting to complex design or operators.
110107
111108
Args:
112109
num_classes (int, optional): number of classes. Defaults to 1000.
113110
in_chans (int, optional): number of input channels. Defaults to 3.
114111
scale (float, optional): scale of the architecture width. Defaults to 1.0.
115112
network_idx (int, optional): the network index indicating the 5 million or 8 million version(0 and 1 respectively). Defaults to 0.
116-
mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
117-
you can choose between 0 and 4. (note for imagenet use 1-4). Defaults to 2.
113+
mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
114+
This is used for larger input sizes such as the 224x224 in imagenet training where the
115+
input size incurs a lot of overhead if not downsampled properly.
116+
you can choose between 0 meaning no change and 4. where each number denotes a specific
117+
downsampling strategy. For imagenet use 1-4.
118+
the larger the stride mode, the higher accuracy and the slower
119+
the network gets. stride mode 1 is the fastest and achives very good accuracy.
120+
Defaults to 2.
118121
drop_rates (Dict[int,float], optional): custom drop out rates specified per layer.
119122
each rate should be paired with the corrosponding layer index(pooling and cnn layers are counted only). Defaults to {}.
120123
"""
@@ -333,22 +336,23 @@ def set_grad_checkpointing(self, enable=True):
333336
def get_classifier(self):
334337
return self.classifier
335338

336-
def reset_classifier(self, num_classes, network_idx=0, scale=1.0):
339+
def reset_classifier(self, num_classes: int):
337340
self.num_classes = num_classes
338-
self.classifier = nn.Linear(round(self.cfg[self.networks[network_idx]][-1][1] * scale), num_classes)
341+
self.classifier = nn.Linear(round(self.cfg[self.networks[self.network_idx]][-1][1] * self.scale), num_classes)
339342

340343
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
341-
x = self.features(x)
342-
x = F.max_pool2d(x, kernel_size=x.size()[2:])
343-
x = F.dropout2d(x, self.last_dropout_rate, training=self.training)
344-
x = x.view(x.size(0), -1)
345-
return x
344+
return self.features(x)
346345

347346
def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
348-
x = self.forward_features(x)
349-
return x if pre_logits else self.classifier(x)
347+
x_ = self.forward_features(x)
348+
if pre_logits:
349+
return x
350+
else:
351+
x = F.max_pool2d(x, kernel_size=x.size()[2:])
352+
x = F.dropout2d(x, self.last_dropout_rate, training=self.training)
353+
x = x.view(x.size(0), -1)
354+
return self.classifier(x)
350355

351-
!Test this after this change, and update the pure pytorch version and cifar10 versions as well- test classification test extensively again
352356
def _gen_simplenet(
353357
model_variant: str = "simplenetv1_m2",
354358
num_classes: int = 1000,
@@ -371,26 +375,28 @@ def _gen_simplenet(
371375
**kwargs,
372376
)
373377
model = build_model_with_cfg(SimpleNet, model_variant, pretrained, **model_args)
374-
375-
# model = SimpleNet(num_classes, in_chans, scale=scale, network_idx=network_idx, mode=mode, drop_rates=drop_rates)
376-
# if pretrained:
377-
# cfg = default_cfgs.get(model_variant, None)
378-
# if cfg is None:
379-
# raise Exception(f"Unknown model variant ('{model_variant}') specified!")
380-
# url = cfg["url"]
381-
# checkpoint_filename = url.split("/")[-1]
382-
# checkpoint_path = f"tmp/{checkpoint_filename}"
383-
# print(f"saving in checkpoint_path:{checkpoint_path}")
384-
# if not os.path.exists(checkpoint_path):
385-
# os.makedirs("tmp")
386-
# download_url_to_file(url, checkpoint_path)
387-
# checkpoint = torch.load(checkpoint_path, map_location="cpu",)
388-
# model.load_state_dict(checkpoint)
389378
return model
390379

391380

392381
@register_model
393382
def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
383+
"""Generic simplenet model builder. by default it returns `simplenetv1_5m_m2` model
384+
but specifying different arguments such as `netidx`, `scale` or `mode` will result in
385+
the corrosponding network variant.
386+
387+
when pretrained is specified, if the combination of settings resemble any known variants
388+
specified in the `default_cfg`, their respective pretrained weights will be loaded, otherwise
389+
an exception will be thrown denoting Unknown model variant being specified.
390+
391+
Args:
392+
pretrained (bool, optional): loads the model with pretrained weights only if the model is a known variant specified in default_cfg. Defaults to False.
393+
394+
Raises:
395+
Exception: if pretrained is used with an unknown/custom model variant and exception is raised.
396+
397+
Returns:
398+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
399+
"""
394400
num_classes = kwargs.get("num_classes", 1000)
395401
in_chans = kwargs.get("in_chans", 3)
396402
scale = kwargs.get("scale", 1.0)
@@ -414,11 +420,7 @@ def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
414420
config = f"small_m{mode}_05"
415421
else:
416422
config = f"m{mode}_{scale:.2f}".replace(".", "")
417-
418-
if network_idx == 0:
419-
model_variant = f"simplenetv1_{config}"
420-
else:
421-
model_variant = f"simplenetv1_{config}"
423+
model_variant = f"simplenetv1_{config}"
422424

423425
cfg = default_cfgs.get(model_variant, None)
424426
if cfg is None:
@@ -477,55 +479,127 @@ def remove_network_settings(kwargs: Dict[str,Any]) -> Dict[str,Any]:
477479
# imagenet models
478480
@register_model
479481
def simplenetv1_small_m1_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
482+
"""Creates a small variant of simplenetv1_5m, with 1.5m parameters. This uses m1 stride mode
483+
which makes it the fastest variant available.
484+
485+
Args:
486+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
487+
488+
Returns:
489+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
490+
"""
480491
model_variant = "simplenetv1_small_m1_05"
481492
model_args = remove_network_settings(kwargs)
482493
return _gen_simplenet(model_variant, scale=0.5, network_idx=0, mode=1, pretrained=pretrained, **model_args)
483494

484495

485496
@register_model
486497
def simplenetv1_small_m2_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
498+
"""Creates a second small variant of simplenetv1_5m, with 1.5m parameters. This uses m2 stride mode
499+
which makes it the second fastest variant available.
500+
501+
Args:
502+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
503+
504+
Returns:
505+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
506+
"""
487507
model_variant = "simplenetv1_small_m2_05"
488508
model_args = remove_network_settings(kwargs)
489509
return _gen_simplenet(model_variant, scale=0.5, network_idx=0, mode=2, pretrained=pretrained, **model_args)
490510

491511

492512
@register_model
493513
def simplenetv1_small_m1_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
514+
"""Creates a third small variant of simplenetv1_5m, with 3m parameters. This uses m1 stride mode
515+
which makes it the third fastest variant available.
516+
517+
Args:
518+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
519+
520+
Returns:
521+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
522+
"""
494523
model_variant = "simplenetv1_small_m1_075"
495524
model_args = remove_network_settings(kwargs)
496525
return _gen_simplenet(model_variant, scale=0.75, network_idx=0, mode=1, pretrained=pretrained, **model_args)
497526

498527

499528
@register_model
500529
def simplenetv1_small_m2_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
530+
"""Creates a forth small variant of simplenetv1_5m, with 3m parameters. This uses m2 stride mode
531+
which makes it the forth fastest variant available.
532+
533+
Args:
534+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
535+
536+
Returns:
537+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
538+
"""
501539
model_variant = "simplenetv1_small_m2_075"
502540
model_args = remove_network_settings(kwargs)
503541
return _gen_simplenet(model_variant, scale=0.75, network_idx=0, mode=2, pretrained=pretrained, **model_args)
504542

505543

506544
@register_model
507545
def simplenetv1_5m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
546+
"""Creates the base simplement model known as simplenetv1_5m, with 5m parameters. This variant uses m1 stride mode
547+
which makes it a fast and performant model.
548+
549+
Args:
550+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
551+
552+
Returns:
553+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
554+
"""
508555
model_variant = "simplenetv1_5m_m1"
509556
model_args = remove_network_settings(kwargs)
510557
return _gen_simplenet(model_variant, scale=1.0, network_idx=0, mode=1, pretrained=pretrained, **model_args)
511558

512559

513560
@register_model
514561
def simplenetv1_5m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
562+
"""Creates the base simplement model known as simplenetv1_5m, with 5m parameters. This variant uses m2 stride mode
563+
which makes it a bit more performant model compared to the m1 variant of the same variant at the expense of a bit slower inference.
564+
565+
Args:
566+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
567+
568+
Returns:
569+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
570+
"""
515571
model_variant = "simplenetv1_5m_m2"
516572
model_args = remove_network_settings(kwargs)
517573
return _gen_simplenet(model_variant, scale=1.0, network_idx=0, mode=2, pretrained=pretrained, **model_args)
518574

519575

520576
@register_model
521577
def simplenetv1_9m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
578+
"""Creates a variant of the simplenetv1_5m, with 9m parameters. This variant uses m1 stride mode
579+
which makes it run faster.
580+
581+
Args:
582+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
583+
584+
Returns:
585+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
586+
"""
522587
model_variant = "simplenetv1_9m_m1"
523588
model_args = remove_network_settings(kwargs)
524589
return _gen_simplenet(model_variant, scale=1.0, network_idx=1, mode=1, pretrained=pretrained, **model_args)
525590

526591

527592
@register_model
528593
def simplenetv1_9m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
594+
"""Creates a variant of the simplenetv1_5m, with 9m parameters. This variant uses m2 stride mode
595+
which makes it a bit more performant model compared to the m1 variant of the same variant at the expense of a bit slower inference.
596+
597+
Args:
598+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
599+
600+
Returns:
601+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
602+
"""
529603
model_variant = "simplenetv1_9m_m2"
530604
model_args = remove_network_settings(kwargs)
531605
return _gen_simplenet(model_variant, scale=1.0, network_idx=1, mode=2, pretrained=pretrained, **model_args)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /