Commit 4d20629

ricber and qubvel authored

Add encoder_freeze method to SegmentationModel (#1220)

* Add encoder_freeze method
* Add encoder freeze/unfreeze methods and override train()
* Fix typo
* Remove unnecessary print
* Refactor _set_encoder_trainable
* Add tests for encoder freezing
* Add assertion on running_var
* Refactor test_freeze_and_unfreeze_encoder
* Refactor test and add call to eval
* add example for encoder freezing

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

1 parent e76ed01 commit 4d20629

File tree

3 files changed: +163 -0 lines changed

‎docs/insights.rst

Lines changed: 22 additions & 0 deletions

@@ -117,3 +117,25 @@ Example:

    mask.shape, label.shape
    # (N, 4, H, W), (N, 4)

4. Freezing and unfreezing the encoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Sometimes you may want to freeze the encoder during training, e.g. when using pretrained backbones and only fine-tuning the decoder and segmentation head.

All segmentation models in SMP provide two helper methods:

.. code-block:: python

    model = smp.Unet("resnet34", classes=2)

    # Freeze encoder: stops gradient updates and freezes normalization layer stats
    model.freeze_encoder()

    # Unfreeze encoder: re-enables training for encoder parameters and normalization layers
    model.unfreeze_encoder()

.. important::
    - Freezing sets ``requires_grad = False`` for all encoder parameters.
    - Normalization layers that track running statistics (e.g., BatchNorm and InstanceNorm layers) are set to ``.eval()`` mode to prevent updates to ``running_mean`` and ``running_var``.
    - If you later call ``model.train()``, frozen encoders will remain frozen until you call ``unfreeze_encoder()``.
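A minimal fine-tuning sketch built on these helpers (the optimizer choice, loss, and tensor shapes below are illustrative assumptions, not part of this commit):

.. code-block:: python

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet("resnet34", classes=2)
    model.freeze_encoder()  # encoder weights and norm statistics stay fixed

    # Only parameters that still require gradients (decoder + segmentation head)
    # need to be handed to the optimizer.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-4
    )

    model.train()  # frozen encoder norm layers remain in eval mode
    images = torch.randn(4, 3, 256, 256)
    masks = torch.randint(0, 2, (4, 256, 256))

    logits = model(images)                                   # (4, 2, 256, 256)
    loss = torch.nn.functional.cross_entropy(logits, masks)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()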

‎segmentation_models_pytorch/base/model.py

Lines changed: 71 additions & 0 deletions

@@ -24,6 +24,10 @@ def __new__(cls: Type[T], *args, **kwargs) -> T:
        instance = super().__new__(cls, *args, **kwargs)
        return instance

    def __init__(self):
        super().__init__()
        self._is_encoder_frozen = False

    def initialize(self):
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)

@@ -137,3 +141,70 @@ def load_state_dict(self, state_dict, **kwargs):
        warnings.warn(text, stacklevel=-1)

        return super().load_state_dict(state_dict, **kwargs)

    def train(self, mode: bool = True):
        """Set the module in training mode.

        This method behaves like the standard :meth:`torch.nn.Module.train`,
        with one exception: if the encoder has been frozen via
        :meth:`freeze_encoder`, then its normalization layers are not affected
        by this call. In other words, calling ``model.train()`` will not
        re-enable updates to frozen encoder normalization layers
        (e.g., BatchNorm, InstanceNorm).

        To restore the encoder to normal training behavior, use
        :meth:`unfreeze_encoder`.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                mode (``False``). Default: ``True``.

        Returns:
            Module: self
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for name, module in self.named_children():
            # skip encoder if it is frozen
            if self._is_encoder_frozen and name == "encoder":
                continue
            module.train(mode)
        return self

    def _set_encoder_trainable(self, mode: bool):
        for param in self.encoder.parameters():
            param.requires_grad = mode

        for module in self.encoder.modules():
            # _NormBase is the common base of classes like _InstanceNorm
            # and _BatchNorm that track running stats
            if isinstance(module, torch.nn.modules.batchnorm._NormBase):
                module.train(mode)

        self._is_encoder_frozen = not mode

    def freeze_encoder(self):
        """
        Freeze encoder parameters and disable updates to normalization
        layer statistics.

        This method:
            - Sets ``requires_grad = False`` for all encoder parameters,
              preventing them from being updated during backpropagation.
            - Puts normalization layers that track running statistics
              (e.g., BatchNorm, InstanceNorm) into evaluation mode (``.eval()``),
              so their ``running_mean`` and ``running_var`` are no longer updated.
        """
        return self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """
        Unfreeze encoder parameters and restore normalization layers to training mode.

        This method reverts the effect of :meth:`freeze_encoder`. Specifically:
            - Sets ``requires_grad=True`` for all encoder parameters.
            - Restores normalization layers (e.g. BatchNorm, InstanceNorm) to training mode,
              so their running statistics are updated again.
        """
        return self._set_encoder_trainable(True)
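As a quick illustration of the overridden ``train()`` behavior, here is a small sketch that mirrors the tests below (assuming a standard ResNet encoder, whose normalization layers are BatchNorm):

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
    model.freeze_encoder()
    model.train()  # the overridden train() skips the frozen encoder

    # Encoder norm layers stay in eval mode, so running stats are not updated,
    # while the rest of the model is put into training mode as usual.
    bn = next(
        m for m in model.encoder.modules()
        if isinstance(m, torch.nn.modules.batchnorm._NormBase)
    )
    print(bn.training)             # False
    print(model.decoder.training)  # True

    model.unfreeze_encoder()
    model.train()
    print(bn.training)             # True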

‎tests/base/test_freeze_encoder.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
import torch
import segmentation_models_pytorch as smp


def test_freeze_and_unfreeze_encoder():
    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)

    def assert_encoder_params_trainable(expected: bool):
        assert all(p.requires_grad == expected for p in model.encoder.parameters())

    def assert_norm_layers_training(expected: bool):
        for m in model.encoder.modules():
            if isinstance(m, torch.nn.modules.batchnorm._NormBase):
                assert m.training == expected

    # Initially, encoder params should be trainable
    model.train()
    assert_encoder_params_trainable(True)

    # Freeze encoder
    model.freeze_encoder()
    assert_encoder_params_trainable(False)
    assert_norm_layers_training(False)

    # Call train() and ensure encoder norm layers stay frozen
    model.train()
    assert_norm_layers_training(False)

    # Unfreeze encoder
    model.unfreeze_encoder()
    assert_encoder_params_trainable(True)
    assert_norm_layers_training(True)

    # Call train() again — should stay trainable
    model.train()
    assert_norm_layers_training(True)

    # Switch to eval, then freeze
    model.eval()
    model.freeze_encoder()
    assert_encoder_params_trainable(False)
    assert_norm_layers_training(False)

    # Unfreeze again
    model.unfreeze_encoder()
    assert_encoder_params_trainable(True)
    assert_norm_layers_training(True)


def test_freeze_encoder_stops_running_stats():
    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
    model.freeze_encoder()
    model.train()  # overridden train, encoder should remain frozen
    bn = None

    for m in model.encoder.modules():
        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
            bn = m
            break

    assert bn is not None

    orig_mean = bn.running_mean.clone()
    orig_var = bn.running_var.clone()

    x = torch.randn(2, 3, 64, 64)
    _ = model(x)

    torch.testing.assert_close(orig_mean, bn.running_mean)
    torch.testing.assert_close(orig_var, bn.running_var)
