
Commit df58c80

up
1 parent 2a827ec commit df58c80

File tree

2 files changed: +41 -106 lines changed


src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py

Lines changed: 41 additions & 25 deletions
@@ -20,11 +20,6 @@
 if is_torch_available():
     import torch
 
-if is_accelerate_available():
-    pass
-
-if is_nunchaku_available():
-    from .utils import replace_with_nunchaku_linear
 
 logger = logging.get_logger(__name__)
 
@@ -79,13 +74,14 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        from nunchaku.models.linear import SVDQW4A4Linear
-
-        module, tensor_name = get_module_from_name(model, param_name)
-        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
-            return True
-
-        return False
+        # TODO: revisit
+        # Check if the param_name is not in self.modules_to_not_convert
+        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+            return False
+        else:
+            # We only quantize the weight of nn.Linear
+            module, _ = get_module_from_name(model, param_name)
+            return isinstance(module, torch.nn.Linear)
 
     def create_quantized_param(
         self,
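
Note: the rewritten check treats each entry of `modules_to_not_convert` as either an exact module name or a dotted prefix of the parameter path. A minimal standalone sketch of that matching rule (the module names below are made up for illustration):

    # Sketch of the prefix-matching rule in check_if_quantized_param; names are hypothetical.
    modules_to_not_convert = ["proj_out", "time_embed"]

    def is_excluded(param_name: str) -> bool:
        return any((key + "." in param_name) or (key == param_name) for key in modules_to_not_convert)

    print(is_excluded("transformer_blocks.0.proj_out.weight"))   # True: "proj_out." appears in the path
    print(is_excluded("transformer_blocks.0.attn.to_q.weight"))  # False: falls through to the nn.Linear check
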
@@ -112,13 +108,32 @@ def create_quantized_param(
                 module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))
 
         elif isinstance(module, torch.nn.Linear):
-            if tensor_name in module._parameters:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            if tensor_name in module._buffers:
-                module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device)
-
-            new_module = SVDQW4A4Linear.from_linear(module)
-            setattr(model, param_name, new_module)
+            # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
+            # But we need to have a utility that can take a pretrained param value and quantize it. Not sure
+            # how to do that yet.
+            # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
+            # way to do it?
+            is_param = tensor_name in module._parameters
+            is_buffer = tensor_name in module._buffers
+            new_module = SVDQW4A4Linear.from_linear(
+                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
+            )
+            module_name = ".".join(param_name.split(".")[:-1])
+            if "." in module_name:
+                parent_name, leaf = module_name.rsplit(".", 1)
+                parent = model.get_submodule(parent_name)
+            else:
+                parent, leaf = model, module_name
+
+            # rebind
+            # this will result into
+            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'.
+            if is_param:
+                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+            elif is_buffer:
+                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+
+            setattr(parent, leaf, new_module)
 
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
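
Note: the parent/leaf resolution added here rebinds the converted layer on its immediate parent module, instead of setting it under the full dotted parameter name on the root model as the removed `setattr(model, param_name, new_module)` did. A minimal sketch with plain `torch.nn` modules (a second `nn.Linear` stands in for `SVDQW4A4Linear.from_linear(...)`, and the model layout is hypothetical):

    from torch import nn

    # Hypothetical model: blocks.0.to_q is the nn.Linear we want to swap out.
    model = nn.Module()
    model.blocks = nn.ModuleList([nn.ModuleDict({"to_q": nn.Linear(4, 4)})])

    param_name = "blocks.0.to_q.weight"
    module_name = ".".join(param_name.split(".")[:-1])  # "blocks.0.to_q"
    parent_name, leaf = module_name.rsplit(".", 1)       # "blocks.0", "to_q"
    parent = model.get_submodule(parent_name)

    new_module = nn.Linear(4, 4)       # placeholder for SVDQW4A4Linear.from_linear(...)
    setattr(parent, leaf, new_module)  # rebinds the child on its parent, not on the root model
    assert model.get_submodule(module_name) is new_module

As the TODO comments in the hunk note, loading the original `weight` tensor into the converted layer still fails, because `SVDQW4A4Linear` exposes `qweight` rather than `weight`.
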
@@ -157,24 +172,25 @@ def _process_model_before_weight_loading(
         keep_in_fp32_modules: List[str] = [],
         **kwargs,
     ):
-        # TODO: deal with `device_map`
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
 
         if not isinstance(self.modules_to_not_convert, list):
             self.modules_to_not_convert = [self.modules_to_not_convert]
 
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        # TODO: revisit
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+        #     self.modules_to_not_convert.extend(keys_on_cpu)
+
         # Purge `None`.
         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
         # in case of diffusion transformer models. For language models and others alike, `lm_head`
         # and tied modules are usually kept in FP32.
        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
-        model = replace_with_nunchaku_linear(
-            model,
-            modules_to_not_convert=self.modules_to_not_convert,
-            quantization_config=self.quantization_config,
-        )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model, **kwargs):
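
Note: the commented-out block added in this hunk sketches how `device_map` entries offloaded to "cpu" or "disk" would be appended to `modules_to_not_convert` so they are skipped during conversion. A standalone illustration of that (still-disabled) filter, with a made-up `device_map`:

    # Illustrates the commented-out device_map handling; the mapping below is hypothetical.
    device_map = {"transformer_blocks.0": 0, "transformer_blocks.1": "cpu", "proj_out": "disk"}
    modules_to_not_convert = ["norm_out"]

    if isinstance(device_map, dict) and len(device_map.keys()) > 1:
        keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
        modules_to_not_convert.extend(keys_on_cpu)

    print(modules_to_not_convert)  # ['norm_out', 'transformer_blocks.1', 'proj_out']
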

src/diffusers/quantizers/nunchaku/utils.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

0 commit comments

