Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 23a54b4

Browse files
Vozfqubvel
andauthored
Genet from timm (#344)
* gernet from regnet * basic gernet * depth set to 5, and requirements+import update * docs * Fix summary error * remove input size * manet fix and test with latest timm Co-authored-by: Pavel Yakubovskiy <qubvel@gmail.com>
1 parent f91cc59 commit 23a54b4

File tree

6 files changed

+162
-5
lines changed

6 files changed

+162
-5
lines changed

‎.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
python -m pip install codecov pytest mock
3030
pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
3131
pip install .
32+
pip install -U git+https://github.com/rwightman/pytorch-image-models
3233
- name: Test
3334
run: |
3435
python -m pytest -s tests

‎README.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The main features of this library are:
1212

1313
- High level API (just two lines to create a neural network)
1414
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
15-
- 106 available encoders
15+
- 109 available encoders
1616
- All encoders have pre-trained weights for faster and better convergence
1717

1818
### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -188,6 +188,19 @@ The following is a list of supported encoders in the SMP. Select the appropriate
188188
</div>
189189
</details>
190190

191+
<details>
192+
<summary style="margin-left: 25px;">GERNet</summary>
193+
<div style="margin-left: 25px;">
194+
195+
|Encoder |Weights |Params, M |
196+
|--------------------------------|:------------------------------:|:------------------------------:|
197+
|timm-gernet_s |imagenet |6M |
198+
|timm-gernet_m |imagenet |18M |
199+
|timm-gernet_l |imagenet |28M |
200+
201+
</div>
202+
</details>
203+
191204
<details>
192205
<summary style="margin-left: 25px;">SE-Net</summary>
193206
<div style="margin-left: 25px;">

‎docs/encoders.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,19 @@ RegNet(x/y)
136136
| timm-regnety\_320 | imagenet | 141M |
137137
+---------------------+------------+-------------+
138138

139+
GERNet
140+
~~~~~~
141+
142+
+-------------------------+------------+-------------+
143+
| Encoder | Weights | Params, M |
144+
+=========================+============+=============+
145+
| timm-gernet\_s | imagenet | 6M |
146+
+-------------------------+------------+-------------+
147+
| timm-gernet\_m | imagenet | 18M |
148+
+-------------------------+------------+-------------+
149+
| timm-gernet\_l | imagenet | 28M |
150+
+-------------------------+------------+-------------+
151+
139152
SE-Net
140153
~~~~~~
141154

‎segmentation_models_pytorch/encoders/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
from .timm_res2net import timm_res2net_encoders
1818
from .timm_regnet import timm_regnet_encoders
1919
from .timm_sknet import timm_sknet_encoders
20+
try:
21+
from .timm_gernet import timm_gernet_encoders
22+
except ImportError as e:
23+
timm_gernet_encoders = {}
24+
print("Current timm version doesn't support GERNet."
25+
"If GERNet support is needed please update timm")
26+
2027
from ._preprocessing import preprocess_input
2128

2229
encoders = {}
@@ -36,6 +43,7 @@
3643
encoders.update(timm_res2net_encoders)
3744
encoders.update(timm_regnet_encoders)
3845
encoders.update(timm_sknet_encoders)
46+
encoders.update(timm_gernet_encoders)
3947

4048

4149
def get_encoder(name, in_channels=3, depth=5, weights=None):
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from timm.models import ByobCfg, BlocksCfg, ByobNet
2+
3+
from ._base import EncoderMixin
4+
import torch.nn as nn
5+
6+
7+
class GERNetEncoder(ByobNet, EncoderMixin):
8+
def __init__(self, out_channels, depth=5, **kwargs):
9+
super().__init__(**kwargs)
10+
self._depth = depth
11+
self._out_channels = out_channels
12+
self._in_channels = 3
13+
14+
del self.head
15+
16+
def get_stages(self):
17+
return [
18+
nn.Identity(),
19+
self.stem,
20+
self.stages[0],
21+
self.stages[1],
22+
self.stages[2],
23+
nn.Sequential(self.stages[3], self.stages[4], self.final_conv)
24+
]
25+
26+
def forward(self, x):
27+
stages = self.get_stages()
28+
29+
features = []
30+
for i in range(self._depth + 1):
31+
x = stages[i](x)
32+
features.append(x)
33+
34+
return features
35+
36+
def load_state_dict(self, state_dict, **kwargs):
37+
state_dict.pop("head.fc.weight")
38+
state_dict.pop("head.fc.bias")
39+
super().load_state_dict(state_dict, **kwargs)
40+
41+
42+
regnet_weights = {
43+
'timm-gernet_s': {
44+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth',
45+
},
46+
'timm-gernet_m': {
47+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth',
48+
},
49+
'timm-gernet_l': {
50+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth',
51+
},
52+
}
53+
54+
pretrained_settings = {}
55+
for model_name, sources in regnet_weights.items():
56+
pretrained_settings[model_name] = {}
57+
for source_name, source_url in sources.items():
58+
pretrained_settings[model_name][source_name] = {
59+
"url": source_url,
60+
'input_range': [0, 1],
61+
'mean': [0.485, 0.456, 0.406],
62+
'std': [0.229, 0.224, 0.225],
63+
'num_classes': 1000
64+
}
65+
66+
timm_gernet_encoders = {
67+
'timm-gernet_s': {
68+
'encoder': GERNetEncoder,
69+
"pretrained_settings": pretrained_settings["timm-gernet_s"],
70+
'params': {
71+
'out_channels': (3, 13, 48, 48, 384, 1920),
72+
'cfg': ByobCfg(
73+
blocks=(
74+
BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
75+
BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
76+
BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
77+
BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
78+
BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
79+
),
80+
stem_chs=13,
81+
num_features=1920,
82+
)
83+
},
84+
},
85+
'timm-gernet_m': {
86+
'encoder': GERNetEncoder,
87+
"pretrained_settings": pretrained_settings["timm-gernet_m"],
88+
'params': {
89+
'out_channels': (3, 32, 128, 192, 640, 2560),
90+
'cfg': ByobCfg(
91+
blocks=(
92+
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
93+
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
94+
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
95+
BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
96+
BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
97+
),
98+
stem_chs=32,
99+
num_features=2560,
100+
)
101+
},
102+
},
103+
'timm-gernet_l': {
104+
'encoder': GERNetEncoder,
105+
"pretrained_settings": pretrained_settings["timm-gernet_l"],
106+
'params': {
107+
'out_channels': (3, 32, 128, 192, 640, 2560),
108+
'cfg': ByobCfg(
109+
blocks=(
110+
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
111+
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
112+
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
113+
BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
114+
BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
115+
),
116+
stem_chs=32,
117+
num_features=2560,
118+
)
119+
},
120+
},
121+
}

‎segmentation_models_pytorch/manet/decoder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,19 @@ def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True,
5656
use_batchnorm=use_batchnorm,
5757
)
5858
)
59+
reduced_channels = max(1, skip_channels // reduction)
5960
self.SE_ll = nn.Sequential(
6061
nn.AdaptiveAvgPool2d(1),
61-
nn.Conv2d(skip_channels, skip_channels//reduction, 1),
62+
nn.Conv2d(skip_channels, reduced_channels, 1),
6263
nn.ReLU(inplace=True),
63-
nn.Conv2d(skip_channels//reduction, skip_channels, 1),
64+
nn.Conv2d(reduced_channels, skip_channels, 1),
6465
nn.Sigmoid(),
6566
)
6667
self.SE_hl = nn.Sequential(
6768
nn.AdaptiveAvgPool2d(1),
68-
nn.Conv2d(skip_channels, skip_channels//reduction, 1),
69+
nn.Conv2d(skip_channels, reduced_channels, 1),
6970
nn.ReLU(inplace=True),
70-
nn.Conv2d(skip_channels//reduction, skip_channels, 1),
71+
nn.Conv2d(reduced_channels, skip_channels, 1),
7172
nn.Sigmoid(),
7273
)
7374
self.conv1 = md.Conv2dReLU(

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /