Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit bdc58db

Browse files
committed
Fix typehints, improve error raising for encoders
1 parent aa705da commit bdc58db

File tree

4 files changed

+15
-5
lines changed

4 files changed

+15
-5
lines changed

‎segmentation_models_pytorch/encoders/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,23 @@
3737

3838

3939
def get_encoder(name, in_channels=3, depth=5, weights=None):
40-
Encoder = encoders[name]["encoder"]
40+
41+
try:
42+
Encoder = encoders[name]["encoder"]
43+
except KeyError:
44+
raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
45+
4146
params = encoders[name]["params"]
4247
params.update(depth=depth)
4348
encoder = Encoder(**params)
4449

4550
if weights is not None:
46-
settings = encoders[name]["pretrained_settings"][weights]
51+
try:
52+
settings = encoders[name]["pretrained_settings"][weights]
53+
except KeyError:
54+
raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Avaliable options are: {}".format(
55+
weights, name, list(encoders[name]["pretrained_settings"].keys()),
56+
))
4757
encoder.load_state_dict(model_zoo.load_url(settings["url"]))
4858

4959
encoder.set_in_channels(in_channels)

‎segmentation_models_pytorch/pan/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class PAN(SegmentationModel):
4444
def __init__(
4545
self,
4646
encoder_name: str = "resnet34",
47-
encoder_weights: str = "imagenet",
47+
encoder_weights: Optional[str] = "imagenet",
4848
encoder_dilation: bool = True,
4949
decoder_channels: int = 32,
5050
in_channels: int = 3,

‎segmentation_models_pytorch/unet/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
self,
5252
encoder_name: str = "resnet34",
5353
encoder_depth: int = 5,
54-
encoder_weights: str = "imagenet",
54+
encoder_weights: Optional[str] = "imagenet",
5555
decoder_use_batchnorm: bool = True,
5656
decoder_channels: List[int] = (256, 128, 64, 32, 16),
5757
decoder_attention_type: Optional[str] = None,

‎segmentation_models_pytorch/unetplusplus/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
self,
5252
encoder_name: str = "resnet34",
5353
encoder_depth: int = 5,
54-
encoder_weights: str = "imagenet",
54+
encoder_weights: Optional[str] = "imagenet",
5555
decoder_use_batchnorm: bool = True,
5656
decoder_channels: List[int] = (256, 128, 64, 32, 16),
5757
decoder_attention_type: Optional[str] = None,

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /