Commit 914f2bf

Release 0.2.0 (#430)
* new in_channels != 3 initialization
* docs fixes
* version resolving
1 parent 225823b commit 914f2bf

32 files changed: +233 -366 lines changed

‎.github/workflows/tests.yml

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ jobs:
         python -m pip install codecov pytest mock
         pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
         pip install .
-        pip install -U git+https://github.com/rwightman/pytorch-image-models
     - name: Test
       run: |
         python -m pytest -s tests

‎README.md

Lines changed: 10 additions & 21 deletions
@@ -12,7 +12,7 @@ The main features of this library are:
 
 - High level API (just two lines to create a neural network)
 - 9 models architectures for binary and multi class segmentation (including legendary Unet)
-- 115 available encoders
+- 113 available encoders
 - All encoders have pre-trained weights for faster and better convergence
 
 ### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -297,8 +297,12 @@ The following is a list of supported encoders in the SMP. Select the appropriate
 |Encoder |Weights |Params, M |
 |--------------------------------|:------------------------------:|:------------------------------:|
 |mobilenet_v2 |imagenet |2M |
-|mobilenet_v3_large |imagenet |3M |
-|mobilenet_v3_small |imagenet |1M |
+|timm-mobilenetv3_large_075 |imagenet |1.78M |
+|timm-mobilenetv3_large_100 |imagenet |2.97M |
+|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
+|timm-mobilenetv3_small_075 |imagenet |0.57M |
+|timm-mobilenetv3_small_100 |imagenet |0.93M |
+|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |
 
 </div>
 </details>
@@ -337,22 +341,6 @@ The following is a list of supported encoders in the SMP. Select the appropriate
 </div>
 </details>
 
-<details>
-<summary style="margin-left: 25px;">MobileNetV3</summary>
-<div style="margin-left: 25px;">
-
-|Encoder |Weights |Params, M |
-|--------------------------------|:------------------------------:|:------------------------------:|
-|timm-mobilenetv3_large_075 |imagenet |1.78M |
-|timm-mobilenetv3_large_100 |imagenet |2.97M |
-|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
-|timm-mobilenetv3_small_075 |imagenet |0.57M |
-|timm-mobilenetv3_small_100 |imagenet |0.93M |
-|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |
-
-</div>
-</details>
-
 
 \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
 
@@ -367,8 +355,9 @@ The following is a list of supported encoders in the SMP. Select the appropriate
 
 ##### Input channels
 Input channels parameter allows you to create models, which process tensors with arbitrary number of channels.
-If you use pretrained weights from imagenet - weights of first convolution will be reused for
-1- or 2- channels inputs, for input channels > 4 weights of first convolution will be initialized randomly.
+If you use pretrained weights from imagenet - weights of first convolution will be reused. For
+1-channel case it would be a sum of weights of first convolution layer, otherwise channels would be
+populated with weights like `new_weight[:, i] = pretrained_weight[:, i % 3]` and then scaled with `new_weight * 3 / new_in_channels`.
 ```python
 model = smp.FPN('resnet34', in_channels=1)
 mask = model(torch.ones([1, 1, 64, 64]))
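
As a quick illustration of the initialization rule described in the new README text above, here is a standalone sketch; the tensor shapes and the 5-channel count are made up for the example and this is not the library's internal code:

```python
import torch

# Pretend this is the ImageNet-pretrained first-conv weight: (out_channels, 3, k, k)
pretrained_weight = torch.randn(64, 3, 7, 7)

new_in_channels = 5
new_weight = torch.empty(64, new_in_channels, 7, 7)

# Repeat the three pretrained input channels cyclically over the new channels...
for i in range(new_in_channels):
    new_weight[:, i] = pretrained_weight[:, i % 3]

# ...then rescale so the overall magnitude of the convolution response is preserved.
new_weight = new_weight * 3 / new_in_channels
```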

‎docs/encoders.rst

Lines changed: 17 additions & 28 deletions
@@ -265,15 +265,23 @@ EfficientNet
 MobileNet
 ~~~~~~~~~
 
-+---------------------+------------+-------------+
-| Encoder             | Weights    | Params, M   |
-+=====================+============+=============+
-| mobilenet\_v2       | imagenet   | 2M          |
-+---------------------+------------+-------------+
-| mobilenet\_v3_large | imagenet   | 3M          |
-+---------------------+------------+-------------+
-| mobilenet\_v2_small | imagenet   | 1M          |
-+---------------------+------------+-------------+
++---------------------------------------+------------+-------------+
+| Encoder                               | Weights    | Params, M   |
++=======================================+============+=============+
+| mobilenet\_v2                         | imagenet   | 2M          |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_large\_075          | imagenet   | 1.78M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_large\_100          | imagenet   | 2.97M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_large\_minimal\_100 | imagenet   | 1.41M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_small\_075          | imagenet   | 0.57M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_small\_100          | imagenet   | 0.93M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_small\_minimal\_100 | imagenet   | 0.43M       |
++---------------------------------------+------------+-------------+
 
 DPN
 ~~~
@@ -316,22 +324,3 @@ VGG
 +-------------+------------+-------------+
 | vgg19\_bn   | imagenet   | 20M         |
 +-------------+------------+-------------+
-
-MobileNetV3
-~~~~~~~~~
-
-+-----------------------------------+------------+-------------+
-| Encoder                           | Weights    | Params, M   |
-+===================================+============+=============+
-| timm-mobilenetv3_large_075        | imagenet   | 1.78M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_large_100        | imagenet   | 2.97M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_large_minimal_100| imagenet   | 1.41M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_small_075        | imagenet   | 0.57M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_small_100        | imagenet   | 0.93M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_small_minimal_100| imagenet   | 0.43M       |
-+-----------------------------------+------------+-------------+

‎docs/losses.rst

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,10 @@ DiceLoss
 ~~~~~~~~
 .. autoclass:: segmentation_models_pytorch.losses.DiceLoss
 
+TverskyLoss
+~~~~~~~~
+.. autoclass:: segmentation_models_pytorch.losses.TverskyLoss
+
 FocalLoss
 ~~~~~~~~~
 .. autoclass:: segmentation_models_pytorch.losses.FocalLoss
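
The newly documented loss is importable from `segmentation_models_pytorch.losses`; a minimal usage sketch follows, where the `alpha`/`beta` values are arbitrary and only illustrate the false-positive / false-negative weighting:

```python
import torch
from segmentation_models_pytorch.losses import TverskyLoss

# alpha weights false positives, beta weights false negatives.
loss_fn = TverskyLoss(mode="binary", alpha=0.3, beta=0.7)

logits = torch.randn(4, 1, 64, 64)                    # raw model outputs (treated as logits)
target = torch.randint(0, 2, (4, 1, 64, 64)).float()  # binary ground-truth masks
print(loss_fn(logits, target))
```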

‎requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-torchvision>=0.9.0
+torchvision>=0.5.0
 pretrainedmodels==0.7.4
 efficientnet-pytorch==0.6.3
 timm==0.4.12
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-VERSION = (0, 1, 3)
+VERSION = (0, 2, 0)
 
 __version__ = '.'.join(map(str, VERSION))

‎segmentation_models_pytorch/encoders/__init__.py

Lines changed: 2 additions & 9 deletions
@@ -10,20 +10,14 @@
 from .inceptionv4 import inceptionv4_encoders
 from .efficientnet import efficient_net_encoders
 from .mobilenet import mobilenet_encoders
-from .mobilenet_v3 import mobilenet_v3_encoders
 from .xception import xception_encoders
 from .timm_efficientnet import timm_efficientnet_encoders
 from .timm_resnest import timm_resnest_encoders
 from .timm_res2net import timm_res2net_encoders
 from .timm_regnet import timm_regnet_encoders
 from .timm_sknet import timm_sknet_encoders
 from .timm_mobilenetv3 import timm_mobilenetv3_encoders
-try:
-    from .timm_gernet import timm_gernet_encoders
-except ImportError as e:
-    timm_gernet_encoders = {}
-    print("Current timm version doesn't support GERNet."
-          "If GERNet support is needed please update timm")
+from .timm_gernet import timm_gernet_encoders
 
 from ._preprocessing import preprocess_input
 
@@ -37,7 +31,6 @@
 encoders.update(inceptionv4_encoders)
 encoders.update(efficient_net_encoders)
 encoders.update(mobilenet_encoders)
-encoders.update(mobilenet_v3_encoders)
 encoders.update(xception_encoders)
 encoders.update(timm_efficientnet_encoders)
 encoders.update(timm_resnest_encoders)
@@ -68,7 +61,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None):
             ))
         encoder.load_state_dict(model_zoo.load_url(settings["url"]))
 
-    encoder.set_in_channels(in_channels)
+    encoder.set_in_channels(in_channels, pretrained=weights is not None)
 
     return encoder
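
For reference, a minimal sketch of how the new flag reaches `set_in_channels` from the public entry point; the encoder name, channel count, and input size below are only illustrative (loading `weights="imagenet"` downloads the checkpoint):

```python
import torch
from segmentation_models_pytorch.encoders import get_encoder

# weights="imagenet" -> set_in_channels(..., pretrained=True): the pretrained first-conv
# weights are repeated and rescaled for the 2-channel input instead of being re-initialized.
encoder = get_encoder("resnet34", in_channels=2, depth=5, weights="imagenet")
features = encoder(torch.randn(1, 2, 128, 128))
print([f.shape for f in features])
```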

‎segmentation_models_pytorch/encoders/_base.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@ def out_channels(self):
         """Return channels dimensions for each tensor of forward output of encoder"""
         return self._out_channels[: self._depth + 1]
 
-    def set_in_channels(self, in_channels):
+    def set_in_channels(self, in_channels, pretrained=True):
         """Change first convolution channels"""
         if in_channels == 3:
             return
@@ -26,7 +26,7 @@ def set_in_channels(self, in_channels):
         if self._out_channels[0] == 3:
             self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
 
-        utils.patch_first_conv(model=self, in_channels=in_channels)
+        utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)
 
     def get_stages(self):
         """Method should be overridden in encoder"""

‎segmentation_models_pytorch/encoders/_utils.py

Lines changed: 26 additions & 17 deletions
@@ -2,7 +2,7 @@
 import torch.nn as nn
 
 
-def patch_first_conv(model, in_channels):
+def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
     """Change first convolution layer input channels.
     In case:
         in_channels == 1 or in_channels == 2 -> reuse original weights
@@ -11,29 +11,38 @@ def patch_first_conv(model, in_channels):
 
     # get first conv
     for module in model.modules():
-        if isinstance(module, nn.Conv2d):
+        if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
             break
-
-    # change input channels for first conv
-    module.in_channels = in_channels
+
     weight = module.weight.detach()
-    reset = False
-
-    if in_channels == 1:
-        weight = weight.sum(1, keepdim=True)
-    elif in_channels == 2:
-        weight = weight[:, :2] * (3.0 / 2.0)
+    module.in_channels = new_in_channels
+
+    if not pretrained:
+        module.weight = nn.parameter.Parameter(
+            torch.Tensor(
+                module.out_channels,
+                new_in_channels // module.groups,
+                *module.kernel_size
+            )
+        )
+        module.reset_parameters()
+
+    elif new_in_channels == 1:
+        new_weight = weight.sum(1, keepdim=True)
+        module.weight = nn.parameter.Parameter(new_weight)
+
     else:
-        reset = True
-        weight = torch.Tensor(
+        new_weight = torch.Tensor(
             module.out_channels,
-            module.in_channels // module.groups,
+            new_in_channels // module.groups,
             *module.kernel_size
         )
 
-        module.weight = nn.parameter.Parameter(weight)
-        if reset:
-            module.reset_parameters()
+        for i in range(new_in_channels):
+            new_weight[:, i] = weight[:, i % default_in_channels]
+
+        new_weight = new_weight * (default_in_channels / new_in_channels)
+        module.weight = nn.parameter.Parameter(new_weight)
 
 
 def replace_strides_with_dilation(module, dilation_rate):
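
A small usage sketch of the rewritten helper applied to a plain torchvision backbone; the backbone and the 5-channel count are arbitrary choices for illustration, and `_utils` is an internal module:

```python
import torch
from torchvision.models import resnet18
from segmentation_models_pytorch.encoders._utils import patch_first_conv

model = resnet18()
# Replace the 3-channel stem conv with a 5-channel one; with pretrained=True the existing
# kernel weights are repeated channel-wise (i % 3) and rescaled by 3 / 5.
patch_first_conv(model, new_in_channels=5, pretrained=True)

out = model(torch.randn(1, 5, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```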

‎segmentation_models_pytorch/encoders/densenet.py

Lines changed: 2 additions & 2 deletions
@@ -96,8 +96,8 @@ def load_state_dict(self, state_dict):
                 del state_dict[key]
 
         # remove linear
-        state_dict.pop("classifier.bias")
-        state_dict.pop("classifier.weight")
+        state_dict.pop("classifier.bias", None)
+        state_dict.pop("classifier.weight", None)
 
         super().load_state_dict(state_dict)
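
For context, this change makes the classifier-head removal tolerant of checkpoints that never contained those keys; the effect of the added default, shown on a plain dict rather than a real state dict:

```python
state = {"features.conv0.weight": "..."}

# state.pop("classifier.bias") would raise KeyError because the key is absent.
# With a default, the call becomes a safe no-op:
removed = state.pop("classifier.bias", None)
print(removed)  # None, and `state` is unchanged
```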

0 commit comments