Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit efb24b2

Browse files
Add _post_quantize to fix torch_params not releasing issue. (#21654)
1 parent 11da67d commit efb24b2

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

‎keras/src/models/model.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -442,6 +442,7 @@ def quantize(self, mode, config=None, **kwargs):
442442
"`keras.quantizers.GPTQConfig`."
443443
)
444444
gptq_quantize(self, config)
445+
self._post_quantize(mode, **kwargs)
445446
return
446447

447448
# For all other modes, verify that a config object was not passed.
@@ -477,6 +478,15 @@ def quantize(self, mode, config=None, **kwargs):
477478
self.train_function = None
478479
self.test_function = None
479480
self.predict_function = None
481+
self._post_quantize(mode, **kwargs)
482+
483+
def _post_quantize(self, mode, **kwargs):
484+
if backend.backend() == "torch":
485+
# We need to manually retrack `torch_params`.
486+
# The reason is that after quantization, the removed variables are
487+
# still referenced by `torch_params` and cannot be garbage collected.
488+
for layer in self._flatten_layers():
489+
layer._track_variables()
480490

481491
def build_from_config(self, config):
482492
if not config:

‎keras/src/models/model_test.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -808,6 +808,14 @@ def test_quantize(self, mode):
808808
layer.dtype_policy.name, f"{mode}_from_float32"
809809
)
810810
self.assertEqual(layer.dtype_policy.quantization_mode, mode)
811+
if mode == "int8":
812+
self.assertLen(model.variables, 6)
813+
if backend.backend() == "torch":
814+
self.assertLen(list(model.named_parameters()), 6)
815+
elif mode == "float8":
816+
self.assertLen(model.variables, 16)
817+
if backend.backend() == "torch":
818+
self.assertLen(list(model.named_parameters()), 16)
811819

812820
@parameterized.named_parameters(
813821
("int8", "int8"),

0 commit comments

Comments
(0)

Page converted by AltStyle (-> original) /