Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c23bbb0

Browse files
authored
Update HF mixin (#910)
* Update mixin * Add reqs for hub lib * Add example to save load share * Add filter warning (not relevant) * Fix typo
1 parent f40b6ed commit c23bbb0

File tree

4 files changed

+282
-42
lines changed

4 files changed

+282
-42
lines changed
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import segmentation_models_pytorch as smp"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"## Save to local directory and load back"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": 2,
22+
"metadata": {},
23+
"outputs": [
24+
{
25+
"name": "stdout",
26+
"output_type": "stream",
27+
"text": [
28+
"Loading weights from local directory\n"
29+
]
30+
}
31+
],
32+
"source": [
33+
"model = smp.Unet()\n",
34+
"\n",
35+
"# save the model\n",
36+
"model.save_pretrained(\"saved-model-dir/unet/\")\n",
37+
"\n",
38+
"# load the model\n",
39+
"restored_model = smp.from_pretrained(\"saved-model-dir/unet/\")"
40+
]
41+
},
42+
{
43+
"cell_type": "markdown",
44+
"metadata": {},
45+
"source": [
46+
"## Save model with additional metadata"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 6,
52+
"metadata": {},
53+
"outputs": [],
54+
"source": [
55+
"model = smp.Unet()\n",
56+
"\n",
57+
"# save the model\n",
58+
"model.save_pretrained(\n",
59+
" \"saved-model-dir/unet-with-metadata/\",\n",
60+
"\n",
61+
" # additional information to be saved with the model\n",
62+
" # only \"dataset\" and \"metrics\" are supported\n",
63+
" dataset=\"PASCAL VOC\", # only string name is supported\n",
64+
" metrics={ # should be a dictionary with metric name as key and metric value as value\n",
65+
" \"mIoU\": 0.95,\n",
66+
" \"accuracy\": 0.96\n",
67+
" }\n",
68+
")"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": 7,
74+
"metadata": {},
75+
"outputs": [
76+
{
77+
"name": "stdout",
78+
"output_type": "stream",
79+
"text": [
80+
"---\n",
81+
"library_name: segmentation-models-pytorch\n",
82+
"license: mit\n",
83+
"pipeline_tag: image-segmentation\n",
84+
"tags:\n",
85+
"- semantic-segmentation\n",
86+
"- pytorch\n",
87+
"- segmentation-models-pytorch\n",
88+
"languages:\n",
89+
"- python\n",
90+
"---\n",
91+
"# Unet Model Card\n",
92+
"\n",
93+
"Table of Contents:\n",
94+
"- [Load trained model](#load-trained-model)\n",
95+
"- [Model init parameters](#model-init-parameters)\n",
96+
"- [Model metrics](#model-metrics)\n",
97+
"- [Dataset](#dataset)\n",
98+
"\n",
99+
"## Load trained model\n",
100+
"```python\n",
101+
"import segmentation_models_pytorch as smp\n",
102+
"\n",
103+
"model = smp.from_pretrained(\"<save-directory-or-this-repo>\")\n",
104+
"```\n",
105+
"\n",
106+
"## Model init parameters\n",
107+
"```python\n",
108+
"model_init_params = {\n",
109+
" \"encoder_name\": \"resnet34\",\n",
110+
" \"encoder_depth\": 5,\n",
111+
" \"encoder_weights\": \"imagenet\",\n",
112+
" \"decoder_use_batchnorm\": True,\n",
113+
" \"decoder_channels\": (256, 128, 64, 32, 16),\n",
114+
" \"decoder_attention_type\": None,\n",
115+
" \"in_channels\": 3,\n",
116+
" \"classes\": 1,\n",
117+
" \"activation\": None,\n",
118+
" \"aux_params\": None\n",
119+
"}\n",
120+
"```\n",
121+
"\n",
122+
"## Model metrics\n",
123+
"```json\n",
124+
"{\n",
125+
" \"mIoU\": 0.95,\n",
126+
" \"accuracy\": 0.96\n",
127+
"}\n",
128+
"```\n",
129+
"\n",
130+
"## Dataset\n",
131+
"Dataset name: PASCAL VOC\n",
132+
"\n",
133+
"## More Information\n",
134+
"- Library: https://github.com/qubvel/segmentation_models.pytorch\n",
135+
"- Docs: https://smp.readthedocs.io/en/latest/\n",
136+
"\n",
137+
"This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)"
138+
]
139+
}
140+
],
141+
"source": [
142+
"!cat \"saved-model-dir/unet-with-metadata/README.md\""
143+
]
144+
},
145+
{
146+
"cell_type": "markdown",
147+
"metadata": {},
148+
"source": [
149+
"## Share model with HF Hub"
150+
]
151+
},
152+
{
153+
"cell_type": "code",
154+
"execution_count": 5,
155+
"metadata": {},
156+
"outputs": [
157+
{
158+
"data": {
159+
"application/vnd.jupyter.widget-view+json": {
160+
"model_id": "075ae026811542bdb4030e53b943efc7",
161+
"version_major": 2,
162+
"version_minor": 0
163+
},
164+
"text/plain": [
165+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv..."
166+
]
167+
},
168+
"metadata": {},
169+
"output_type": "display_data"
170+
}
171+
],
172+
"source": [
173+
"from huggingface_hub import notebook_login\n",
174+
"\n",
175+
"# You only need to run this once on the machine,\n",
176+
"# the token will be stored for later use\n",
177+
"notebook_login()"
178+
]
179+
},
180+
{
181+
"cell_type": "code",
182+
"execution_count": 8,
183+
"metadata": {},
184+
"outputs": [
185+
{
186+
"data": {
187+
"application/vnd.jupyter.widget-view+json": {
188+
"model_id": "2921a81d7fd747939b4a425cc17d6104",
189+
"version_major": 2,
190+
"version_minor": 0
191+
},
192+
"text/plain": [
193+
"model.safetensors: 0%| | 0.00/97.8M [00:00<?, ?B/s]"
194+
]
195+
},
196+
"metadata": {},
197+
"output_type": "display_data"
198+
},
199+
{
200+
"data": {
201+
"text/plain": [
202+
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/9f821c7bc3a12db827c0da96a31f354ec6ba5253', commit_message='Push model using huggingface_hub.', commit_description='', oid='9f821c7bc3a12db827c0da96a31f354ec6ba5253', pr_url=None, pr_revision=None, pr_num=None)"
203+
]
204+
},
205+
"execution_count": 8,
206+
"metadata": {},
207+
"output_type": "execute_result"
208+
}
209+
],
210+
"source": [
211+
"model = smp.Unet()\n",
212+
"\n",
213+
"# save the model and share it on the HF Hub (https://huggingface.co/models)\n",
214+
"model.save_pretrained(\n",
215+
" \"qubvel-hf/unet-with-metadata/\",\n",
216+
" push_to_hub=True, # <---------- push the model to the hub\n",
217+
" private=False, # <---------- make the model private or public\n",
218+
" dataset=\"PASCAL VOC\",\n",
219+
" metrics={\n",
220+
" \"mIoU\": 0.95,\n",
221+
" \"accuracy\": 0.96\n",
222+
" }\n",
223+
")\n",
224+
"\n",
225+
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
226+
]
227+
}
228+
],
229+
"metadata": {
230+
"kernelspec": {
231+
"display_name": ".venv",
232+
"language": "python",
233+
"name": "python3"
234+
},
235+
"language_info": {
236+
"codemirror_mode": {
237+
"name": "ipython",
238+
"version": 3
239+
},
240+
"file_extension": ".py",
241+
"mimetype": "text/x-python",
242+
"name": "python",
243+
"nbconvert_exporter": "python",
244+
"pygments_lexer": "ipython3",
245+
"version": "3.10.12"
246+
}
247+
},
248+
"nbformat": 4,
249+
"nbformat_minor": 2
250+
}

‎requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ torchvision>=0.5.0
22
pretrainedmodels==0.7.4
33
efficientnet-pytorch==0.7.1
44
timm==0.9.7
5+
huggingface_hub>=0.24.6
56

67
tqdm
78
pillow

‎segmentation_models_pytorch/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from . import datasets
24
from . import encoders
35
from . import decoders
@@ -20,6 +22,9 @@
2022
from typing import Optional as _Optional
2123
import torch as _torch
2224

25+
# Suppress the specific SyntaxWarning for `pretrainedmodels`
26+
warnings.filterwarnings("ignore", message="is with a literal", category=SyntaxWarning)
27+
2328

2429
def create_model(
2530
arch: str,

‎segmentation_models_pytorch/base/hub_mixin.py

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
```python
2727
import segmentation_models_pytorch as smp
2828
29-
model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}")
29+
model = smp.from_pretrained("<save-directory-or-this-repo>")
3030
```
3131
3232
## Model init parameters
@@ -61,23 +61,22 @@ def _format_parameters(parameters: dict):
6161

6262
class SMPHubMixin(PyTorchModelHubMixin):
6363
def generate_model_card(self, *args, **kwargs) -> ModelCard:
64-
model_parameters_json = _format_parameters(self._hub_mixin_config)
65-
directory = self._save_directory if hasattr(self, "_save_directory") else None
66-
repo_id = self._repo_id if hasattr(self, "_repo_id") else None
67-
repo_or_directory = repo_id if repo_id is not None else directory
68-
69-
metrics = self._metrics if hasattr(self, "_metrics") else None
70-
dataset = self._dataset if hasattr(self, "_dataset") else None
64+
model_parameters_json = _format_parameters(self.config)
65+
metrics = kwargs.get("metrics", None)
66+
dataset = kwargs.get("dataset", None)
7167

7268
if metrics is not None:
7369
metrics = json.dumps(metrics, indent=4)
7470
metrics = f"```json\n{metrics}\n```"
7571

72+
tags = self._hub_mixin_info.model_card_data.get("tags", []) or []
73+
tags.extend(["segmentation-models-pytorch", "semantic-segmentation", "pytorch"])
74+
7675
model_card_data = ModelCardData(
7776
languages=["python"],
7877
library_name="segmentation-models-pytorch",
7978
license="mit",
80-
tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"],
79+
tags=tags,
8180
pipeline_tag="image-segmentation",
8281
)
8382
model_card = ModelCard.from_template(
@@ -86,64 +85,49 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
8685
repo_url="https://github.com/qubvel/segmentation_models.pytorch",
8786
docs_url="https://smp.readthedocs.io/en/latest/",
8887
model_parameters=model_parameters_json,
89-
save_directory=repo_or_directory,
9088
model_name=self.__class__.__name__,
9189
metrics=metrics,
9290
dataset=dataset,
9391
)
9492
return model_card
9593

96-
def _set_attrs_from_kwargs(self, attrs, kwargs):
97-
for attr in attrs:
98-
if attr in kwargs:
99-
setattr(self, f"_{attr}", kwargs.pop(attr))
100-
101-
def _del_attrs(self, attrs):
102-
for attr in attrs:
103-
if hasattr(self, f"_{attr}"):
104-
delattr(self, f"_{attr}")
105-
10694
@wraps(PyTorchModelHubMixin.save_pretrained)
10795
def save_pretrained(
10896
self, save_directory: Union[str, Path], *args, **kwargs
10997
) -> Optional[str]:
110-
# set additional attributes to be used in generate_model_card
111-
self._save_directory = save_directory
112-
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
98+
model_card_kwargs = kwargs.pop("model_card_kwargs", {})
99+
if "dataset" in kwargs:
100+
model_card_kwargs["dataset"] = kwargs.pop("dataset")
101+
if "metrics" in kwargs:
102+
model_card_kwargs["metrics"] = kwargs.pop("metrics")
103+
kwargs["model_card_kwargs"] = model_card_kwargs
113104

114-
# set additional attribute to be used in from_pretrained
115-
self._hub_mixin_config["_model_class"] = self.__class__.__name__
105+
# set additional attribute to be able to deserialize the model
106+
self.config["_model_class"] = self.__class__.__name__
116107

117108
try:
118109
# call the original save_pretrained
119110
result = super().save_pretrained(save_directory, *args, **kwargs)
120111
finally:
121-
# delete the additional attributes
122-
self._del_attrs(["save_directory", "metrics", "dataset"])
123-
self._hub_mixin_config.pop("_model_class", None)
112+
self.config.pop("_model_class", None)
124113

125114
return result
126115

127-
@wraps(PyTorchModelHubMixin.push_to_hub)
128-
def push_to_hub(self, repo_id: str, *args, **kwargs):
129-
self._repo_id = repo_id
130-
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
131-
result = super().push_to_hub(repo_id, *args, **kwargs)
132-
self._del_attrs(["repo_id", "metrics", "dataset"])
133-
return result
134-
135116
@property
136-
def config(self):
117+
def config(self)->dict:
137118
return self._hub_mixin_config
138119

139120

140121
@wraps(PyTorchModelHubMixin.from_pretrained)
141122
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
142-
config_path = hf_hub_download(
143-
pretrained_model_name_or_path,
144-
filename="config.json",
145-
revision=kwargs.get("revision", None),
146-
)
123+
config_path = Path(pretrained_model_name_or_path) / "config.json"
124+
if not config_path.exists():
125+
config_path = hf_hub_download(
126+
pretrained_model_name_or_path,
127+
filename="config.json",
128+
revision=kwargs.get("revision", None),
129+
)
130+
147131
with open(config_path, "r") as f:
148132
config = json.load(f)
149133
model_class_name = config.pop("_model_class")

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /