Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 225823b

Browse files
markson14qubvel
andauthored
add timm-MobileNetV3 as an Encoder (#355)
* add timm-mobilenetv3 as encoder * fix import bug Co-authored-by: Pavel Yakubovskiy <qubvel@gmail.com>
1 parent 23a54b4 commit 225823b

File tree

4 files changed

+202
-1
lines changed

4 files changed

+202
-1
lines changed

‎README.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The main features of this library are:
1212

1313
- High level API (just two lines to create a neural network)
1414
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
15-
- 109 available encoders
15+
- 115 available encoders
1616
- All encoders have pre-trained weights for faster and better convergence
1717

1818
### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -337,6 +337,22 @@ The following is a list of supported encoders in the SMP. Select the appropriate
337337
</div>
338338
</details>
339339

340+
<details>
341+
<summary style="margin-left: 25px;">MobileNetV3</summary>
342+
<div style="margin-left: 25px;">
343+
344+
|Encoder |Weights |Params, M |
345+
|--------------------------------|:------------------------------:|:------------------------------:|
346+
|timm-mobilenetv3_large_075 |imagenet |1.78M |
347+
|timm-mobilenetv3_large_100 |imagenet |2.97M |
348+
|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
349+
|timm-mobilenetv3_small_075 |imagenet |0.57M |
350+
|timm-mobilenetv3_small_100 |imagenet |0.93M |
351+
|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |
352+
353+
</div>
354+
</details>
355+
340356

341357
\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
342358

‎docs/encoders.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,22 @@ VGG
316316
+-------------+------------+-------------+
317317
| vgg19\_bn | imagenet | 20M |
318318
+-------------+------------+-------------+
319+
320+
MobileNetV3
321+
~~~~~~~~~
322+
323+
+-----------------------------------+------------+-------------+
324+
| Encoder | Weights | Params, M |
325+
+===================================+============+=============+
326+
| timm-mobilenetv3_large_075 | imagenet | 1.78M |
327+
+-----------------------------------+------------+-------------+
328+
| timm-mobilenetv3_large_100 | imagenet | 2.97M |
329+
+-----------------------------------+------------+-------------+
330+
| timm-mobilenetv3_large_minimal_100| imagenet | 1.41M |
331+
+-----------------------------------+------------+-------------+
332+
| timm-mobilenetv3_small_075 | imagenet | 0.57M |
333+
+-----------------------------------+------------+-------------+
334+
| timm-mobilenetv3_small_100 | imagenet | 0.93M |
335+
+-----------------------------------+------------+-------------+
336+
| timm-mobilenetv3_small_minimal_100| imagenet | 0.43M |
337+
+-----------------------------------+------------+-------------+

‎segmentation_models_pytorch/encoders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .timm_res2net import timm_res2net_encoders
1818
from .timm_regnet import timm_regnet_encoders
1919
from .timm_sknet import timm_sknet_encoders
20+
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
2021
try:
2122
from .timm_gernet import timm_gernet_encoders
2223
except ImportError as e:
@@ -43,6 +44,7 @@
4344
encoders.update(timm_res2net_encoders)
4445
encoders.update(timm_regnet_encoders)
4546
encoders.update(timm_sknet_encoders)
47+
encoders.update(timm_mobilenetv3_encoders)
4648
encoders.update(timm_gernet_encoders)
4749

4850

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
from timm import create_model
2+
import torch.nn as nn
3+
from ._base import EncoderMixin
4+
5+
6+
def make_divisible(x, divisible_by=8):
7+
import numpy as np
8+
return int(np.ceil(x * 1. / divisible_by) * divisible_by)
9+
10+
11+
class MobileNetV3Encoder(nn.Module, EncoderMixin):
12+
def __init__(self, model, width_mult, depth=5, **kwargs):
13+
super().__init__()
14+
self._depth = depth
15+
if 'small' in str(model):
16+
self.mode = 'small'
17+
self._out_channels = (16*width_mult, 16*width_mult, 24*width_mult, 48*width_mult, 576*width_mult)
18+
self._out_channels = tuple(map(make_divisible, self._out_channels))
19+
elif 'large' in str(model):
20+
self.mode = 'large'
21+
self._out_channels = (16*width_mult, 24*width_mult, 40*width_mult, 112*width_mult, 960*width_mult)
22+
self._out_channels = tuple(map(make_divisible, self._out_channels))
23+
else:
24+
self.mode = 'None'
25+
raise ValueError(
26+
'MobileNetV3 mode should be small or large, got {}'.format(self.mode))
27+
self._out_channels = (3,) + self._out_channels
28+
self._in_channels = 3
29+
# minimal models replace hardswish with relu
30+
model = create_model(model_name=model,
31+
scriptable=True, # torch.jit scriptable
32+
exportable=True, # onnx export
33+
features_only=True)
34+
self.conv_stem = model.conv_stem
35+
self.bn1 = model.bn1
36+
self.act1 = model.act1
37+
self.blocks = model.blocks
38+
39+
def get_stages(self):
40+
if self.mode == 'small':
41+
return [
42+
nn.Identity(),
43+
nn.Sequential(self.conv_stem, self.bn1, self.act1),
44+
self.blocks[0],
45+
self.blocks[1],
46+
self.blocks[2:4],
47+
self.blocks[4:],
48+
]
49+
elif self.mode == 'large':
50+
return [
51+
nn.Identity(),
52+
nn.Sequential(self.conv_stem, self.bn1, self.act1, self.blocks[0]),
53+
self.blocks[1],
54+
self.blocks[2],
55+
self.blocks[3:5],
56+
self.blocks[5:],
57+
]
58+
else:
59+
ValueError('MobileNetV3 mode should be small or large, got {}'.format(self.mode))
60+
61+
def forward(self, x):
62+
stages = self.get_stages()
63+
64+
features = []
65+
for i in range(self._depth + 1):
66+
x = stages[i](x)
67+
features.append(x)
68+
69+
return features
70+
71+
def load_state_dict(self, state_dict, **kwargs):
72+
state_dict.pop('conv_head.weight')
73+
state_dict.pop('conv_head.bias')
74+
state_dict.pop('classifier.weight')
75+
state_dict.pop('classifier.bias')
76+
super().load_state_dict(state_dict, **kwargs)
77+
78+
79+
mobilenetv3_weights = {
80+
'tf_mobilenetv3_large_075': {
81+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth'
82+
},
83+
'tf_mobilenetv3_large_100': {
84+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth'
85+
},
86+
'tf_mobilenetv3_large_minimal_100': {
87+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth'
88+
},
89+
'tf_mobilenetv3_small_075': {
90+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth'
91+
},
92+
'tf_mobilenetv3_small_100': {
93+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth'
94+
},
95+
'tf_mobilenetv3_small_minimal_100': {
96+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth'
97+
},
98+
99+
100+
}
101+
102+
pretrained_settings = {}
103+
for model_name, sources in mobilenetv3_weights.items():
104+
pretrained_settings[model_name] = {}
105+
for source_name, source_url in sources.items():
106+
pretrained_settings[model_name][source_name] = {
107+
"url": source_url,
108+
'input_range': [0, 1],
109+
'mean': [0.485, 0.456, 0.406],
110+
'std': [0.229, 0.224, 0.225],
111+
'input_space': 'RGB',
112+
}
113+
114+
115+
timm_mobilenetv3_encoders = {
116+
'timm-mobilenetv3_large_075': {
117+
'encoder': MobileNetV3Encoder,
118+
'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_075'],
119+
'params': {
120+
'model': 'tf_mobilenetv3_large_075',
121+
'width_mult': 0.75
122+
}
123+
},
124+
'timm-mobilenetv3_large_100': {
125+
'encoder': MobileNetV3Encoder,
126+
'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_100'],
127+
'params': {
128+
'model': 'tf_mobilenetv3_large_100',
129+
'width_mult': 1.0
130+
}
131+
},
132+
'timm-mobilenetv3_large_minimal_100': {
133+
'encoder': MobileNetV3Encoder,
134+
'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_minimal_100'],
135+
'params': {
136+
'model': 'tf_mobilenetv3_large_minimal_100',
137+
'width_mult': 1.0
138+
}
139+
},
140+
'timm-mobilenetv3_small_075': {
141+
'encoder': MobileNetV3Encoder,
142+
'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_075'],
143+
'params': {
144+
'model': 'tf_mobilenetv3_small_075',
145+
'width_mult': 0.75
146+
}
147+
},
148+
'timm-mobilenetv3_small_100': {
149+
'encoder': MobileNetV3Encoder,
150+
'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_100'],
151+
'params': {
152+
'model': 'tf_mobilenetv3_small_100',
153+
'width_mult': 1.0
154+
}
155+
},
156+
'timm-mobilenetv3_small_minimal_100': {
157+
'encoder': MobileNetV3Encoder,
158+
'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_minimal_100'],
159+
'params': {
160+
'model': 'tf_mobilenetv3_small_minimal_100',
161+
'width_mult': 1.0
162+
}
163+
},
164+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /