if is_torch_available():
    import torch

-if is_accelerate_available():
-    pass
-
-if is_nunchaku_available():
-    from .utils import replace_with_nunchaku_linear

logger = logging.get_logger(__name__)
@@ -79,13 +74,14 @@ def check_if_quantized_param(
        state_dict: Dict[str, Any],
        **kwargs,
    ):
-        from nunchaku.models.linear import SVDQW4A4Linear
-
-        module, tensor_name = get_module_from_name(model, param_name)
-        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
-            return True
-
-        return False
+        # TODO: revisit
+        # Check if the param_name is not in self.modules_to_not_convert
+        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+            return False
+        else:
+            # We only quantize the weight of nn.Linear
+            module, _ = get_module_from_name(model, param_name)
+            return isinstance(module, torch.nn.Linear)

    def create_quantized_param(
        self,
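The rewritten `check_if_quantized_param` above filters by module path first and then only accepts parameters that live on an `nn.Linear`. A minimal standalone sketch of that filtering logic (the toy model and the `modules_to_not_convert` values are made up for illustration, not part of this diff):

```python
from torch import nn

# Hypothetical exclusion list and toy model, mirroring the check above.
modules_to_not_convert = ["proj_out"]

model = nn.Sequential()
model.add_module("proj_in", nn.Linear(8, 8))
model.add_module("proj_out", nn.Linear(8, 8))
model.add_module("norm", nn.LayerNorm(8))

def should_quantize(param_name: str) -> bool:
    # Skip anything under an excluded module.
    if any((key + "." in param_name) or (key == param_name) for key in modules_to_not_convert):
        return False
    # Only quantize parameters owned by an nn.Linear.
    module = model.get_submodule(param_name.rsplit(".", 1)[0])
    return isinstance(module, nn.Linear)

print(should_quantize("proj_in.weight"))   # True
print(should_quantize("proj_out.weight"))  # False -> excluded by name
print(should_quantize("norm.weight"))      # False -> not nn.Linear
```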
@@ -112,13 +108,32 @@ def create_quantized_param(
            module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))

        elif isinstance(module, torch.nn.Linear):
-            if tensor_name in module._parameters:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            if tensor_name in module._buffers:
-                module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device)
-
-            new_module = SVDQW4A4Linear.from_linear(module)
-            setattr(model, param_name, new_module)
+            # TODO: `SVDQW4A4Linear.from_linear` returns a quantized layer initialized from the corresponding
+            # `linear` module, but we still need a utility that can take a pretrained param value and quantize
+            # it. Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a
+            # better way to do it?
+            is_param = tensor_name in module._parameters
+            is_buffer = tensor_name in module._buffers
+            new_module = SVDQW4A4Linear.from_linear(
+                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
+            )
+            module_name = ".".join(param_name.split(".")[:-1])
+            if "." in module_name:
+                parent_name, leaf = module_name.rsplit(".", 1)
+                parent = model.get_submodule(parent_name)
+            else:
+                parent, leaf = model, module_name
+
+            # Rebind the loaded tensor on the new module. Note: this currently results in
+            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'?
+            if is_param:
+                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+            elif is_buffer:
+                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+
+            setattr(parent, leaf, new_module)

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
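The parent lookup plus `setattr(parent, leaf, new_module)` added above is the usual way to swap a leaf module in place, since `setattr(model, param_name, ...)` on a dotted name would only attach a new top-level attribute. A small sketch of the same rebinding pattern, with a plain `nn.Linear` standing in for `SVDQW4A4Linear.from_linear(...)` (the toy model is illustrative, not from this PR):

```python
from torch import nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nn.Linear(4, 4)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

model = ToyModel()
param_name = "block.ff.weight"

# Same parent resolution as in `create_quantized_param` above.
module_name = ".".join(param_name.split(".")[:-1])  # "block.ff"
if "." in module_name:
    parent_name, leaf = module_name.rsplit(".", 1)   # "block", "ff"
    parent = model.get_submodule(parent_name)
else:
    parent, leaf = model, module_name

old = model.get_submodule(module_name)
# A plain bias-free nn.Linear stands in for the quantized replacement layer.
setattr(parent, leaf, nn.Linear(old.in_features, old.out_features, bias=False))
print(type(model.block.ff).__name__)  # Linear
```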
@@ -157,24 +172,25 @@ def _process_model_before_weight_loading(
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
-        # TODO: deal with `device_map`
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        # TODO: revisit
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+        #     self.modules_to_not_convert.extend(keys_on_cpu)
+
        # Purge `None`.
        # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
        # in case of diffusion transformer models. For language models and others alike, `lm_head`
        # and tied modules are usually kept in FP32.
        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

-        model = replace_with_nunchaku_linear(
-            model,
-            modules_to_not_convert=self.modules_to_not_convert,
-            quantization_config=self.quantization_config,
-        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
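The commented-out `device_map` handling above follows the pattern other quantization backends use: modules that `accelerate` places on `cpu` or `disk` are added to the exclusion list so they are never converted. A standalone sketch with a made-up `device_map`:

```python
# Hypothetical device_map, e.g. as produced by accelerate's infer_auto_device_map.
device_map = {"transformer_blocks": 0, "proj_out": "cpu", "norm_out": "disk"}
modules_to_not_convert = ["proj_out"]

if isinstance(device_map, dict) and len(device_map.keys()) > 1:
    keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
    modules_to_not_convert.extend(keys_on_cpu)

print(sorted(set(modules_to_not_convert)))  # ['norm_out', 'proj_out']
```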