99
1010from src .utils .fast_init import fast_init
1111from src .utils .logging import format_ms , log_rank_n
12+ from src .utils .utils import parse_revision
1213from transformers import (
1314 CONFIG_MAPPING ,
1415 AutoConfig ,
@@ -41,12 +42,14 @@ def __init__(
4142 self ,
4243 * ,
4344 model_type : Optional [str ] = None ,
45+ pretrained_config : Optional [str ] = None ,
4446 pretrained_model : Optional [str ] = None ,
4547 config_args : Dict [str , Any ],
4648 tokenizer : str ,
4749 device : torch .device ,
4850 dtype : torch .dtype ,
4951 fast_init : bool = True ,
52+ trust_remote_code : bool = False ,
5053 ):
5154 self .initialization_metrics = {}
5255 log_rank_n ("*** Setting up tokenizer" , logger .info )
@@ -60,10 +63,11 @@ def __init__(
6063 self .dtype = dtype
6164 self .is_int8 = self .dtype == torch .int8
6265 self .fast_init = fast_init
66+ self .trust_remote_code = trust_remote_code
6367 if self .is_int8 and self .device != torch .device ("cuda" ):
6468 raise ValueError (f"Model quantization not supported on device { self .device } " )
6569
66- self .config = self ._get_config (model_type , pretrained_model , config_args )
70+ self .config = self ._get_config (model_type , pretrained_config or pretrained_model , config_args )
6771 t2 = time .perf_counter ()
6872
6973 logger .info (f"Model configuration: { self .config } " )
@@ -86,7 +90,9 @@ def _create_model(self) -> PreTrainedModel:
8690 log_rank_n ("*** Creating model" , logger .info )
8791 with fast_init (self .device ) if self .fast_init else contextlib .nullcontext ():
8892 torch_dtype = torch .float16 if self .is_int8 else self .dtype
89- model = AutoModelForCausalLM .from_config (config = self .config , torch_dtype = torch_dtype )
93+ model = AutoModelForCausalLM .from_config (
94+ config = self .config , torch_dtype = torch_dtype , trust_remote_code = self .trust_remote_code
95+ )
9096 t1 = time .perf_counter ()
9197 log_rank_n ("*** Moving to device" , logger .info )
9298 model .to (self .device )
@@ -98,6 +104,7 @@ def _create_model(self) -> PreTrainedModel:
98104 self .initialization_metrics ["model initialization" ] = t1 - t0
99105 self .initialization_metrics ["move to device" ] = t2 - t1
100106 self .initialization_metrics ["initialize weights" ] = t3 - t2
107+ 101108 return model
102109
103110 def _reload_model (self ):
@@ -118,9 +125,12 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
118125 log_rank_n (f"*** Loading model from { pretrained_model } " , logger .info )
119126 kwargs = {"load_in_8bit" : True , "device_map" : "auto" } if self .is_int8 else {"torch_dtype" : self .dtype }
120127 with fast_init (self .device ) if self .fast_init else contextlib .nullcontext ():
128+ pretrained_model , revision = parse_revision (pretrained_model )
121129 model = AutoModelForCausalLM .from_pretrained (
122130 pretrained_model ,
131+ revision = revision ,
123132 config = self .config ,
133+ trust_remote_code = self .trust_remote_code ,
124134 ** kwargs ,
125135 )
126136 t1 = time .perf_counter ()
@@ -135,7 +145,7 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
135145 def _get_config (
136146 self ,
137147 model_type : Optional [str ],
138- pretrained_model : Optional [str ],
148+ pretrained_config : Optional [str ],
139149 config_args : Dict [str , Any ],
140150 ) -> PretrainedConfig :
141151 config_args = {
@@ -145,15 +155,16 @@ def _get_config(
145155 }
146156
147157 if model_type is None :
148- if pretrained_model is None :
158+ if pretrained_config is None :
149159 raise ValueError ("You need to provide either --model_type or --pretrained_model" )
150160 config_class = AutoConfig
151161 elif model_type not in CONFIG_MAPPING :
152162 raise ValueError (f"Unknown model type: { model_type } " )
153163 else :
154164 config_class = CONFIG_MAPPING [model_type ]
165+ config_args ["model_type" ] = model_type
155166
156- if pretrained_model is None :
167+ if pretrained_config is None :
157168 config_args .update (
158169 {
159170 "bos_token_id" : self .tokenizer .bos_token_id ,
@@ -163,7 +174,10 @@ def _get_config(
163174 )
164175 config , unused = config_class .from_dict ({}, ** config_args )
165176 else :
166- config , unused = config_class .from_pretrained (pretrained_model , ** config_args )
177+ pretrained_config , revision = parse_revision (pretrained_config )
178+ config , unused = config_class .from_pretrained (
179+ pretrained_config , revision = revision , trust_remote_code = self .trust_remote_code , ** config_args
180+ )
167181
168182 if unused :
169183 raise ValueError (f"There were unused configuration parameters: { tuple (unused )} " )
@@ -216,7 +230,8 @@ def aggregate_and_format_metrics(self, metrics: List[Dict[str, Any]]):
216230 "Latency (decode)" : format_ms (mean_metrics [DECODE_TIME ]),
217231 "Latency (max)" : format_ms (max (all_metrics [END_TO_END_TIME ])),
218232 "Latency (min)" : format_ms (min (all_metrics [END_TO_END_TIME ])),
219- "Tokens generated" : f"{ mean_metrics [NUM_GENERATED_TOKENS ]:.0f} " ,
233+ "Tokens generated (average)" : f"{ mean_metrics [NUM_GENERATED_TOKENS ]:.0f} " ,
234+ "Tokens generated (total)" : f"{ np .sum (all_metrics [NUM_GENERATED_TOKENS ]).item ():.0f} " ,
220235 "Throughput (model)" : f"{ model_throughput :.2f} tokens/s" ,
221236 "Throughput (end to end)" : f"{ throughput :.2f} tokens/s" ,
222237 "Token time (end to end)" : f"{ format_ms (throughput ** - 1 )} /token" ,
0 commit comments