diff --git a/README.md b/README.md index df32214..468cc88 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,10 @@ ## News +🔥🔥🔥 [2024年10月31日] We released **MFTCoder v0.5** mainly for MFTCoder-accelerate, which is now supporting preference alignment methods like **DPO/RPO/ORPO** in the new **xxpo** module, adding full-parameter continue-training in the additional **mpt** module along with its **offline_tokenization** module, updating selfpaced method to new convergence balance(CoBa) method for MFT in the original **pefts** module. + +🔥🔥🔥 [2024年10月31日] Our paper [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) has been accepted by EMNLP-2024, which achieves balanced convergence across various tasks. + 🔥🔥🔥 [2024年05月20日] We released **MFTCoder v0.4**, mainly for MFTCoder-accelerate. It supports **QLoRA + DeepSpeed Zero3** and **QLoRA + FSDP** as options allowing you training very large models. It now supports new models like Qwen2, Qwen2-MoE, Starcoder2, Gemma, etc. 🔥🔥🔥 [2024年05月20日] Our paper [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) has been accepted by KDD2024. diff --git a/README_cn.md b/README_cn.md index 6b31613..3102d9f 100644 --- a/README_cn.md +++ b/README_cn.md @@ -45,6 +45,10 @@ ## 新闻 +🔥🔥🔥 [2024年10月31日] **MFTCoder-v0.5**发布,新增**xxpo**模块支持偏好对齐DPO/RPO/ORPO;新增**mpt**和**offline_tokenization**模块支持全量参数的加训;在原本的**pefts**模块(MFT)更新selfpaced收敛均衡技术并更名CoBa。 + +🔥🔥🔥 [2024年10月31日] 我们的论文 [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) 已被 EMNLP 2024 接收,可以实现多任务收敛均衡。 + 🔥🔥🔥 [2024年05月20日] **MFTCoder-v0.4**发布。新增支持**QLoRA+ DeepSpeed Zero3**, **QLoRA + FSDP**训练模式,可以更好的支持微调更大的模型,比如Qwen1.5-70B等。新增对Qwen2, Qwen2-MoE, Starcoder2, Gemma等模型的支持。 🔥🔥🔥 [2024年05月20日] 我们的论文 [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) 已被 KDD 2024 接收. diff --git a/mftcoder_accelerate/README.md b/mftcoder_accelerate/README.md index f65a21e..87b4b63 100644 --- a/mftcoder_accelerate/README.md +++ b/mftcoder_accelerate/README.md @@ -7,22 +7,28 @@ [[中文]](README_cn.md) [**English**] ## 1. Updates +🔥 MFTCoder-accelerate now supports DPO/ORPO training through xxpo module. + +🔥 MFTCoder-accelerate now supports continue training through mpt module along with offline_tokenization module. + +🔥 MFTCoder-accelerate supports MFT with latest implementation of CoBa Loss (selfpaced Loss) for better Convergence Balance. + 🔥 MFTCoder-accelerate now support these modes: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, Full-parameter + DeepSpeed ZeRO3, QLoRA + FSDP, Full-parameter + FSDP. -🔥 MFTCoder-accelerate supports QLoRA + DeepSpeed ZeRO3 and QLoRA + FSDP, which both work for larger models; +🔥 MFTCoder-accelerate supports QLoRA + DeepSpeed ZeRO3 and QLoRA + FSDP, which both work for larger models. -🔥 MFTCoder-accelerate supports MFT/SFT on more new mainstream open-source base models: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3; +🔥 MFTCoder-accelerate supports MFT/SFT on more new mainstream open-source base models: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3. -🔥 MFTCoder-accelerate supports Self-Paced Loss for Convergence Balance; +🔥 MFTCoder-accelerate supports Self-Paced Loss for Convergence Balance. 
-🔥 MFTCoder-accelerate supports Full-parameters/QLoRA/LoRA using accelerate + DeepSpeed Framework; +🔥 MFTCoder-accelerate supports Full-parameters/QLoRA/LoRA using accelerate + DeepSpeed Framework. 🔥 MFTCoder-accelerate supports Multitask Fine-Tuning(MFT), which is able to balance diffenrent tasks in data level. 🔥 MFTCoder-accelerate supports finetuning most of mainstream open-source base models: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen. ## 2. Data Format -### 2.1 Training Data Format +### 2.1 MFT Training Data Format The training data is required to be a uniformed JSONL format, in which each line of data has the following "chatML"-style JSON format. The "chat_rounds" field is required, and other fields can be added or removed based on specific needs. The reason why we selected "chatML" style as our training and inference data format is that "chatML" style is compatible with both "conversation" and "instruction/response" scenarios. @@ -57,7 +63,7 @@ For the keys of roles in "chat_rounds", you could use "system/human/bot" tuple o } ``` -### 2.2 Default Inference Data Format +### 2.2 Default MFTCoder Inference Template Inference data format is the real string format consumed by tokenizers and then LLMs. It is also the string format to which the training data is converted before tokenization. The default inference data format contains strings concatenated by conversation data(system, human and bot contents) in the training data format. It is used as the data "seen"(before tokenization) by the model in training process. @@ -87,6 +93,56 @@ User nth round input ``` When applying inference, you always make your input string end with ```(削除) bot\n``` to request the model generating answers. +### 2.3 DPO训练数据格式 +The training data is required to be a uniformed JSONL format, in which each line of data has the following JSON format. The "chosen" and "rejected" fields are required as ```chosen``` and ```rejected``` in DPO training and both includes "chatml-style" contents(only last content of bot differs). +```json +{ + "chosen":[ + { + "role": "system", + "content": "You are a expert in coding and help answer code questions" + }, + { + "role": "human", + "content": "Write a python function of quick sort" + }, + { + "role": "bot", + "content": "Below is the function of quick sort: ..." + }, + { + "role": "human", + "content": "Explain the code" + }, + { + "role": "bot", + "content": "OK, this code ..." + } + ], + "rejected":[ + { + "role": "system", + "content": "You are a expert in coding and help answer code questions" + }, + { + "role": "human", + "content": "Write a python function of quick sort" + }, + { + "role": "bot", + "content": "Below is the function of quick sort: ..." + }, + { + "role": "human", + "content": "Explain the code" + }, + { + "role": "bot", + "content": "Sorry, I can not answer..." + } + ] +} +``` ## 3. 
Model Training @@ -114,6 +170,12 @@ mftcoder_accelerate | *pefts* | + *xxpo* + | + *mpt* + | + *offline_tokenization* + | tokenizer | utils @@ -122,7 +184,11 @@ mftcoder_accelerate ``` 我们将训练中使用的各种组件抽取出来,以便后续的扩展和优化, 详见```src```目录下的实现。 -训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` +MFT训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` + +DPO/ORPO训练入口文件是```mftcoder_accelerate/src/xxpo/xxpo_accelerate.py``` + +MPT(全量加训)训练入口文件是```mftcoder_accelerate/src/mpt/mpt_accelerate.py``` 参数配置存储在```mftcoder_accelerate/src/configs```目录下,方便统一管理和更改。 @@ -131,8 +197,13 @@ mftcoder_accelerate cd mftcoder_accelerate/src ``` -You can find the implementations in the ```mftcoder_accelerate/src``` directory. -The entry directory for fine-tuning training is ```mftcoder_accelerate/src```, and the entry file for training is ```mftcoder_accelerate/src/pefts/mft_accelerate.py```. +You can find the implementations in the ```mftcoder_accelerate/src``` directory +The entry file for MFT training is ```mftcoder_accelerate/src/pefts/mft_accelerate.py```. + +The entry file for DPO/ORPO training is ```mftcoder_accelerate/src/xxpo/xxpo_accelerate.py```. + +The entry file for MPT(Continue Training) is ```mftcoder_accelerate/src/mpt/mpt_accelerate.py```. You need finish offline tokenization of your data via ```mftcoder_accelerate/src/run_offline_tokenization.sh```, which is different from the online tokenizaion used in MFT/DPO. + Configurations are stored in the ```mftcoder_accelerate/src/configs``` directory for easy management and modification. **_As a result, before you start training, you should first change your dir by_** @@ -140,7 +211,7 @@ Configurations are stored in the ```mftcoder_accelerate/src/configs``` directory cd mftcoder_accelerate/src ``` -### 3.1 Tokenization +### 3.1 MFT Tokenization During training, we concatenate multi-turn dialogues into the following format (also known as the inference data format mentioned before) and then tokenize it. In default format, ```human\n``` starts the user's input (i.e., prompt),```bot\n``` starts the assistant's output (i.e., response) @@ -271,6 +342,17 @@ Frequently used arguments are provided in ```configs/***_train_config``` and exp - **role_markers**: {"system": "\system\n", "user": "\human\n", "assistant": "\bot\n} as default(null). You could set your preferred role_markers as the templates startting "system", "user" and "assistant". e.g. {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} +#### CoBa Arguments Configuration +- **coba_warmup_steps**: The number of warm-up steps for CoBa. During the warm-up period, all task weights are equal, and after the warm-up, weights begin to be adjusted dynamically. It is generally recommended to set this close to the total number of validation batches. +- **coba_history_length**: The historical window length of validation loss maintained by CoBa, used to fit the convergence slope at the current step. It is generally recommended to set this between 2 times and 5 times the **coba_warmup_steps**. Typically, the larger this value, the smaller the changes in weights will be. +- **coba_tau**: The temperature coefficient for the Divergence Factor (DF). It is generally set to 5. +- **coba_update_interval**: The frequency at which CoBa updates weights. It is commonly set to 1, meaning weights are updated at every step. +- **coba_sample_valid_num**: The number of validation batches to be sampled by CoBa at each step. 
Theoretically, when this value equals the total number of validation batches, the fitted convergence slope most closely approximates the actual situation. However, considering computational requirements, it is recommended to set it to 1. + +#### DPO Arguments Configuration +- **xxpo**: preference optimization type, "dpo" or "orpo". +- **beta**: DPO beta, smaller beta allows larger distance between dpo model and ref model. +- **rpo_alpha**: The coefficient of the ```chosen``` NLL loss added to dpo loss. ## 4. Model Usage diff --git a/mftcoder_accelerate/README_cn.md b/mftcoder_accelerate/README_cn.md index 0acb8d8..39631c5 100644 --- a/mftcoder_accelerate/README_cn.md +++ b/mftcoder_accelerate/README_cn.md @@ -7,24 +7,30 @@ [**中文**] [[English]](README.md) ## 1. 更新 +🔥 MFTCoder-accelerate 增加了xxpo模块,支持dpo训练。 + +🔥 MFTCoder-accelerate 增加了mpt模块,借助offline_tokenization模块,支持全量参数加训。 + +🔥 MFTCoder-accelerate 增加了CoBa Loss的最新实现(原selfpaced Loss), 让收敛均衡更进一步。 + 🔥 MFTCoder-accelerate 最新支持的训练模式包括: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, 全量 + DeepSpeed ZeRO3, QLoRA + FSDP, 全量 + FSDP。 -🔥 MFTCoder-accelerate 新增支持QLoRA + DeepSpeed ZeRO3, 支持QLoRA + FSDP, 可以训练更大的模型; +🔥 MFTCoder-accelerate 新增支持QLoRA + DeepSpeed ZeRO3, 支持QLoRA + FSDP, 可以训练更大的模型。 -🔥 MFTCoder-accelerate 新增支持accelerate + FSDP框架, 支持全量微调和LoRA; +🔥 MFTCoder-accelerate 新增支持accelerate + FSDP框架, 支持全量微调和LoRA。 -🔥 MFTCoder-accelerate 支持最新更多主流开源模型: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3; +🔥 MFTCoder-accelerate 支持最新更多主流开源模型: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3。 -🔥 MFTCoder-accelerate 新增self-paced Loss, 用于收敛均衡; +🔥 MFTCoder-accelerate 新增self-paced Loss, 用于收敛均衡。 -🔥 MFTCoder-accelerate 支持使用accelerate + DeepSpeed框架下支持 全量参数/QLoRA/LoRA微调; +🔥 MFTCoder-accelerate 支持使用accelerate + DeepSpeed框架下支持 全量参数/QLoRA/LoRA微调。 -🔥 MFTCoder-accelerate 在训练中支持了多任务微调MFT, 可以同时平衡多个任务的训练,训练的模型支持多任务推理; +🔥 MFTCoder-accelerate 在训练中支持了多任务微调MFT, 可以同时平衡多个任务的训练,训练的模型支持多任务推理。 -🔥 MFTCoder-accelerate 在训练中支持多种模型基座: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen等 +🔥 MFTCoder-accelerate 在训练中支持多种模型基座: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen等。 ## 2. 数据格式 -### 2.1 训练数据格式 +### 2.1 MFT训练数据格式 训练数据为jsonl格式,每一行的数据格式如下,其中chat_rounds字段是必需的,可以根据实际需求添加或删除其他字段。 可以参考项目中的xxx.jsonl文件。 ```json @@ -80,6 +86,57 @@ """ ``` +### 2.3 DPO训练数据格式 +训练数据为jsonl格式,每一行的数据格式如下,其中chosen字段和rejected字段分别代表偏好对齐中的```chosen```和```rejected```,其内部依然是MFT的chatml格式,并且只有最后一轮对话的bot content不同。 +```json +{ + "chosen":[ + { + "role": "system", + "content": "你是一个智能代码助手,可以回复用户与代码相关的问题" + }, + { + "role": "human", + "content": "写一个快速排序" + }, + { + "role": "bot", + "content": "以下是一个快速排序算法xxxxxx" + }, + { + "role": "human", + "content": "解释一下这段代码" + }, + { + "role": "bot", + "content": "好的,这段代码xxx" + } + ], + "rejected":[ + { + "role": "system", + "content": "你是一个智能代码助手,可以回复用户与代码相关的问题" + }, + { + "role": "human", + "content": "写一个快速排序" + }, + { + "role": "bot", + "content": "以下是一个快速排序算法xxxxxx" + }, + { + "role": "human", + "content": "解释一下这段代码" + }, + { + "role": "bot", + "content": "对不起,我不会" + } + ] +} +``` + ## 3. 
模型训练 目前支持全量参数(Full-parameters)指令微调、QLoRA指令微调,LoRA指令微调。 @@ -104,6 +161,12 @@ mftcoder_accelerate | *pefts* | + *xxpo* + | + *mpt* + | + *offline_tokenization* + | tokenizer | utils @@ -112,7 +175,11 @@ mftcoder_accelerate ``` 我们将训练中使用的各种组件抽取出来,以便后续的扩展和优化, 详见```src```目录下的实现。 -训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` +MFT训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` + +DPO/ORPO训练入口文件是```mftcoder_accelerate/src/xxpo/xxpo_accelerate.py``` + +MPT(全量加训)训练入口文件是```mftcoder_accelerate/src/mpt/mpt_accelerate.py```. MPT加训需要提前做好数据的tokenziation,通过```mftcoder_accelerate/src/run_offline_tokenization.sh```,你可以将数据通过cpu进行离线的tokenization。这和MFT/DPO中使用的在线tokenziation不同。 参数配置存储在```mftcoder_accelerate/src/configs```目录下,方便统一管理和更改。 @@ -124,7 +191,7 @@ cd mftcoder_accelerate/src ### 3.1 数据tokenization -训练时,我们将多轮对话拼接成如下格式(也是上文中的推理数据格式),然后进行tokenize。 +MFT/DPO训练时,我们将多轮对话拼接成如下格式(也是上文中的推理数据格式),然后进行tokenize。 其中,默认情况下: ```human\n```作为human/user的起始符,```bot\n```作为bot/assistant的起始符,```{EOS_TOKEN}``` 表示eos_token。 @@ -217,6 +284,18 @@ _**训练需要的参数配置在```configs/*_train_config```中,主要参数 - **saving_limit**:整数,ckpt存储数量上限, 全量训练必须设置。默认null即不限制数量。 - **role_markers**: null,即使用{"system": "\system\n", "user": "\human\n", "assistant": "\bot\n"}。 你可以自定义 "system", "user" and "assistant"的模板, 用于定制自己的问答或者对话模板,比如 {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} +#### CoBa相关参数配置 +- **coba_warmup_steps**: CoBa的warm-up步数。在warm-up期间,各任务权重相等,warm-up之后,开始动态调整权重。一般建议设置为与valid batch总数量相近即可。 +- **coba_history_length**: CoBa维护的valid loss的历史窗口长度,用于拟合当前步收敛斜率。一般建议设置为2倍**coba_warmup_steps**至5倍**coba_warmup_steps**之间。一般该值越大,权重的变化幅度就会越小。 +- **coba_tau**: 发散因子(DF)的温度系数。一般设置为5即可。 +- **coba_update_interval**: CoBa更新权重的频率。一般设置为1,即每一步都对权重做更新。 +- **coba_sample_valid_num**: CoBa每一步要取的valid batch数。理论上当该值等于valid batch总数量时,拟合出的收敛斜率最逼近真实情况,但考虑到计算需求,建议设置为1。 + +#### DPO 相关参数配置 +- **xxpo**: 偏好对齐方法, "dpo" 或者 "orpo"。 +- **beta**: DPO beta, beta 越小,允许对齐后的dpo模型与ref模型的距离越远。 +- **rpo_alpha**: 加到dop损失的```chosen``` NLL损失的系数,0的话就是原始DPO。 +- ## 4. 
模型使用 ### 4.1 权重合并 diff --git a/mftcoder_accelerate/inference/hf_inference.py b/mftcoder_accelerate/inference/hf_inference.py index 16b8933..67f9ba0 100644 --- a/mftcoder_accelerate/inference/hf_inference.py +++ b/mftcoder_accelerate/inference/hf_inference.py @@ -2,91 +2,85 @@ # @author Chaoyu Chen # @date 2024年1月4日 # @module hf_inference.py - +""" +# @author qumu +# @date 2023年9月19日 +# @module hf_inference.py +""" import os import sys import torch import textwrap -from transformers import ( - AutoConfig, - AutoTokenizer, - AutoModelForCausalLM, - StoppingCriteria, - StoppingCriteriaList -) +from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList from peft import PeftModel -def load_model_tokenizer(path, model_type=None, peft_path=None, torch_dtype=torch.bfloat16, quantization=None, - eos_token=None, pad_token=None): +def load_model_tokenizer( + path, + model_type=None, + peft_path=None, + torch_dtype=torch.bfloat16, + quantization=None, + eos_token=None, + pad_token=None, + batch_size=1, +): """ - load model and tokenizer by transfromers + load model and tokenizer by transfromers """ # load tokenizer first tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) tokenizer.padding_side = "left" - config, unused_kwargs = AutoConfig.from_pretrained( - path, - trust_remote_code=True, - return_unused_kwargs=True - ) + config, unused_kwargs = AutoConfig.from_pretrained(path, trust_remote_code=True, return_unused_kwargs=True) print("unused_kwargs:", unused_kwargs) print("config input:\n", config) - # eos token优先级: 1. 用户输入eos_token 2. config中的eos_token_id 3. config中的eos_token + # eos token parsing if eos_token: eos_token = eos_token eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) print(f"eos_token {eos_token} from user input") + elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id: + print(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer") + eos_token_id = tokenizer.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(eos_token_id) + elif hasattr(tokenizer, "eos_token") and tokenizer.eos_token: + print(f"Initial eos_token {tokenizer.eos_token} from tokenizer") + eos_token = tokenizer.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) + elif hasattr(config, "eos_token_id") and config.eos_token_id: + print(f"Initial eos_token_id {config.eos_token_id} from config.json") + eos_token_id = config.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(config.eos_token_id) + elif hasattr(config, "eos_token") and config.eos_token: + print(f"Initial eos_token {config.eos_token} from config.json") + eos_token = config.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(config.eos_token) else: - if hasattr(config, "eos_token_id") and config.eos_token_id: - print(f"eos_token_id {config.eos_token_id} from config.json") - eos_token_id = config.eos_token_id - eos_token = tokenizer.convert_ids_to_tokens(config.eos_token_id) - elif hasattr(config, "eos_token") and config.eos_token: - print(f"eos_token {config.eos_token} from config.json") - eos_token = config.eos_token - eos_token_id = tokenizer.convert_tokens_to_ids(config.eos_token) - else: - raise ValueError( - "No available eos_token or eos_token_id, please provide eos_token by params or eos_token_id by config.json") - - # pad token优先级: 1. 用户输入 pad_token 2. config中的pad_token_id 3. 
config中的pad_token - if pad_token: - pad_token = pad_token - pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) - print(f"pad_token {pad_token} from user input") - else: - if hasattr(config, "pad_token_id") and config.pad_token_id: - print(f"pad_token_id {config.pad_token_id} from config.json") - pad_token_id = config.pad_token_id - pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id) - elif hasattr(config, "pad_token") and config.pad_token: - print(f"pad_token {config.pad_token} from config.json") - pad_token = config.pad_token - pad_token_id = tokenizer.convert_tokens_to_ids(config.pad_token) - else: - print(f"pad_token {eos_token} duplicated from eos_token") - pad_token = eos_token - pad_token_id = eos_token_id - - # update tokenizer eos_token and pad_token - tokenizer.eos_token_id = eos_token_id - tokenizer.eos_token = eos_token - tokenizer.pad_token_id = pad_token_id - tokenizer.pad_token = pad_token + raise ValueError( + "No available eos_token or eos_token_id, please provide eos_token by params or eos_token_id by config.json" + ) + + try: + tokenizer.eos_token = eos_token + tokenizer.eos_token_id = eos_token_id + # set pad_token to be same as eos_token, it is ok because is will be masked out. + tokenizer.pad_token = eos_token + tokenizer.pad_token_id = eos_token_id + except: + print(f"[WARNING]Cannot set tokenizer.eos_token") print(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}") print(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}") - print(tokenizer) + print(type(tokenizer)) base_model = AutoModelForCausalLM.from_pretrained( path, config=config, - load_in_8bit=(quantization == '8bit'), - load_in_4bit=(quantization == '4bit'), + load_in_8bit=(quantization == "8bit"), + load_in_4bit=(quantization == "4bit"), device_map="auto", torch_dtype=torch_dtype, trust_remote_code=True, @@ -114,9 +108,10 @@ def load_model_tokenizer(path, model_type=None, peft_path=None, torch_dtype=torc def hf_inference(model, tokenizer, text_list, args=None, max_new_tokens=512, do_sample=True, **kwargs): """ - transformers models inference by huggingface + transformers models inference by huggingface """ - inputs = tokenizer(text_list, return_tensors='pt', padding=True, add_special_tokens=False).to("cuda") + # text_list = [tokenizer.apply_chat_template([{"role": "user", "content": text}], tokenize=False) for text in text_list] + inputs = tokenizer(text_list, return_tensors="pt", padding=True, add_special_tokens=False).to("cuda") # inputs["attention_mask"][0][:100] = 0 # print(inputs) print("================================Prompts and Generations=============================") @@ -128,15 +123,15 @@ def hf_inference(model, tokenizer, text_list, args=None, max_new_tokens=512, do_ do_sample=do_sample, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, - **kwargs + **kwargs, ) - gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True) + gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True) for i in range(len(text_list)): - print('=========' * 10) - print(f'Prompt:\n{text_list[i]}') - gen_text[i] = gen_text[i].replace(tokenizer.pad_token, '') - print(f'Generation:\n{gen_text[i]}') + print("=========" * 10) + print(f"Prompt:\n{text_list[i]}") + gen_text[i] = gen_text[i].replace(tokenizer.pad_token, "") + print(f"Generation:\n{gen_text[i]}") # print(f"Outputs ids:\n{outputs[i]}") sys.stdout.flush() @@ 
-155,11 +150,9 @@ def hf_inference(model, tokenizer, text_list, args=None, max_new_tokens=512, do_ # if you use base + adaptor for inference, provide peft_path or left it None for normal inference base_model = "path/to/basemodel" peft_path = None - model, tokenizer = load_model_tokenizer(base_model, - model_type='', - peft_path=peft_path, - eos_token=' (削除ここまで)', - pad_token='') + model, tokenizer = load_model_tokenizer( + base_model, model_type="", peft_path=peft_path, eos_token="", pad_token="" + ) # hf_inference(model, tokenizer, prompts, do_sample=False, num_beams=1, num_return_sequences=1) hf_inference(model, tokenizer, prompts, do_sample=True, temperature=0.8) diff --git a/mftcoder_accelerate/src/configs/selfpaced_train_config.json b/mftcoder_accelerate/src/configs/coba_train_config.json similarity index 78% rename from mftcoder_accelerate/src/configs/selfpaced_train_config.json rename to mftcoder_accelerate/src/configs/coba_train_config.json index 98007e7..63167f1 100644 --- a/mftcoder_accelerate/src/configs/selfpaced_train_config.json +++ b/mftcoder_accelerate/src/configs/coba_train_config.json @@ -5,16 +5,17 @@ "pretrained_model_path": "$MODEL_NAME_OR_PATH", "model_type": "$MODEL_TYPE", "load_raw_dataset": true, - "data_split": "98,2,0", + "data_split": "95,5,0", "padding_mode": "padding", "use_dynamic_padding": true, "tokenize_mode": "sft", "tokenizer_type": "AutoTokenizer", - "weighted_loss_mode": "selfpaced", - "selfpaced_interval": 1, - "selfpaced_history_length": 100, - "selfpaced_sample_valid_num": 1, - "selfpaced_scale_factor": 50, + "weighted_loss_mode": "coba", + "coba_warmup_steps": 100, + "coba_history_length": 200, + "coba_tau": 5, + "coba_update_interval": 1, + "coba_sample_valid_num": 1, "attn_implementation": "flash_attention_2", "seq_length": 4096, "seed": 1234, @@ -23,8 +24,8 @@ "lora_rank": 96, "lora_alpha": 32, "lora_dropout": 0.05, - "per_device_train_batch_size": 2, - "per_device_eval_batch_size": 2, + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, "learning_rate": 5e-5, "min_lr": 5e-6, "weight_decay": 0.1, @@ -42,4 +43,4 @@ "early_stopping": true, "early_stopping_stall_num": 5, "saving_limit": null -} \ No newline at end of file + } \ No newline at end of file diff --git a/mftcoder_accelerate/src/configs/dpo_train_config.json b/mftcoder_accelerate/src/configs/dpo_train_config.json new file mode 100644 index 0000000..5a93db9 --- /dev/null +++ b/mftcoder_accelerate/src/configs/dpo_train_config.json @@ -0,0 +1,34 @@ +{ + "xxpo": "dpo", + "data_paths": "$DATA_PATHS", + "output_dir": "$OUTPUT_DIR", + "tb_dir": "$TensorBoard_DIR", + "pretrained_model_path": "$MODEL_NAME_OR_PATH", + "model_type": "$MODEL_TYPE", + "data_split": "99,1", + "attn_implementation": "flash_attention_2", + "beta": 0.1, + "rpo_alpha": 0.5, + "peft_type": "lora", + "lora_rank": 64, + "lora_alpha": 128, + "lora_dropout": 0.0, + "per_device_train_batch_size": 1, + "per_device_eval_batch_size": 1, + "tokenizer_type": "AutoTokenizer", + "dataset_num_proc": 1, + "learning_rate": 5e-7, + "weight_decay": 0.01, + "gradient_accumulation_steps": 8, + "lr_scheduler_type": "cosine", + "warmup_steps": 100, + "num_train_epochs": 2, + "seed": 1105, + "max_prompt_length": 2048, + "max_length": 4096, + "logging_steps": 20, + "save_steps": 500, + "eval_steps": 500, + "epoch_checkpointing": false, + "saving_limit": 5 +} \ No newline at end of file diff --git a/mftcoder_accelerate/src/data/blendable_dataset.py b/mftcoder_accelerate/src/data/blendable_dataset.py index 3dd6139..84b9756 100644 
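The `dpo_train_config.json` added above consumes preference data in the chosen/rejected chatml format documented in section 2.3 of the README diff. As a quick illustration (not part of the patch; the helper name and output path are invented for this sketch), one line of such a JSONL file could be assembled like this:

```python
import json

def make_dpo_sample(system, prompt, good_answer, bad_answer):
    """Build one preference record in the chosen/rejected chatml format:
    both sides share the same context and only the last bot turn differs."""
    shared = [
        {"role": "system", "content": system},
        {"role": "human", "content": prompt},
    ]
    return {
        "chosen": shared + [{"role": "bot", "content": good_answer}],
        "rejected": shared + [{"role": "bot", "content": bad_answer}],
    }

sample = make_dpo_sample(
    "You are an expert in coding and help answer code questions",
    "Write a python function of quick sort",
    "Below is the function of quick sort: ...",
    "Sorry, I can not answer...",
)

# One JSON object per line, as required by the JSONL training data format.
with open("dpo_data.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```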
--- a/mftcoder_accelerate/src/data/blendable_dataset.py +++ b/mftcoder_accelerate/src/data/blendable_dataset.py @@ -43,7 +43,7 @@ def __init__(self, datasets, weights): # recompute weights weights = self.calc_weights() - + # Build indices. start_time = time.time() assert num_datasets < 255 @@ -63,9 +63,7 @@ def __init__(self, datasets, weights): print( "> RANK {} elapsed time for building blendable dataset indices: " - "{:.2f} (sec)".format( - torch.distributed.get_rank(), time.time() - start_time - ) + "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time) ) def calc_weights(self): @@ -73,7 +71,7 @@ def calc_weights(self): total_cnt = sum(dataset_sample_cnt) weights = np.array([(cnt + 0.0) / total_cnt for cnt in dataset_sample_cnt], dtype=np.float64) return weights - + def __len__(self): return self.size diff --git a/mftcoder_accelerate/src/data/data_utils.py b/mftcoder_accelerate/src/data/data_utils.py index fa79f32..8d168bd 100644 --- a/mftcoder_accelerate/src/data/data_utils.py +++ b/mftcoder_accelerate/src/data/data_utils.py @@ -32,10 +32,7 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): start_time = time.time() indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) - print_rank_0( - "> finished creating indexed dataset in {:4f} " - "seconds".format(time.time() - start_time) - ) + print_rank_0("> finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time)) print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) return indexed_dataset @@ -53,20 +50,22 @@ def build_train_valid_test_datasets( build_index_mappings=True, shuffle_before_split=False, weighted_loss_mode=None, - ds_weights=[1., 1., 1.], - train_mode='sft', + ds_weights=[1.0, 1.0, 1.0], + train_mode="sft", ): """Build train, valid, and test datasets.""" # Indexed dataset. - assert os.path.exists(data_prefix + "_input_ids.bin"), f"Input tokens datafile not found: {data_prefix}_input_ids.bin" + assert os.path.exists( + data_prefix + "_input_ids.bin" + ), f"Input tokens datafile not found: {data_prefix}_input_ids.bin" # Indexed dataset. 
input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup) - if train_mode == 'sft': + if train_mode == "sft": loss_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_loss_mask", data_impl, skip_warmup) else: - print(f'pretrain mode, loss mask is ones') + print(f"pretrain mode, loss mask is ones") loss_mask_indexed_dataset = None total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0] @@ -79,9 +78,7 @@ def print_split_stats(name, index): print_rank_0(" {}:".format(name)) print_rank_0( " document indices in [{}, {}) total of {} " - "documents".format( - splits[index], splits[index + 1], splits[index + 1] - splits[index] - ) + "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) ) print_split_stats("train", 0) @@ -100,11 +97,9 @@ def build_dataset(index, name, ds_weight=1.0): dataset = None if splits[index + 1]> splits[index]: if shuffle_before_split: - documents = shuffle_doc_index[splits[index]:splits[index + 1]] + documents = shuffle_doc_index[splits[index] : splits[index + 1]] else: - documents = np.arange( - start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 - ) + documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) dataset = GPT2PromptDataset( name, @@ -130,11 +125,13 @@ def build_dataset(index, name, ds_weight=1.0): return train_dataset, valid_dataset, test_dataset, total_num_of_documents -def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False): +def build_multiple_train_valid_test_datasets( + args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False +): """Build multiple train, valid, and test datasets.""" - data_prefixes = list(args.data_paths[1:-1].split(',')) + data_prefixes = list(args.data_paths[1:-1].split(",")) - data_weights = list(map(float, args.data_weights[1:-1].split(','))) + data_weights = list(map(float, args.data_weights[1:-1].split(","))) print("data weights: ") print(data_weights) use_shared_fs = use_shared_fs @@ -143,7 +140,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, seq_length = args.seq_length # seq_length = args.block_size seed = args.seed - skip_warmup = (not mmap_warmup) + skip_warmup = not mmap_warmup weight_by_num_documents = args.weight_by_num_documents shuffle_before_split = args.shuffle_before_split weighted_loss_mode = args.weighted_loss_mode @@ -183,9 +180,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, factor = 1 if weight_by_num_documents: # gets the number of documents in each data path - get_num_docs_list = lambda datasets: [ - dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets - ] + get_num_docs_list = lambda datasets: [dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets] train_num_docs, valid_num_docs, test_num_docs = ( get_num_docs_list(train_datasets), get_num_docs_list(valid_datasets), @@ -201,7 +196,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, ) assert sum(train_weights) != 0.0, "found train weights to be 0.0" assert sum(valid_weights) != 0.0, "found valid weights to be 0.0" - + train_weights, train_num_samples = get_normalized_weights_and_num_samples( train_weights, train_valid_test_num_samples[0] ) @@ -265,7 +260,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, if num_tokens: factor 
= sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length) factor /= sum([1.0 / w for w in train_ds_weights]) / len(train_ds_weights) - + print_rank_0(f"> common denomination factor for CE loss: {factor}") # Blend. @@ -274,7 +269,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, i = 0 for ds in train_datasets: ds.update_ds_weight(ds.ds_weight / factor) - print(f'loss weight of dataset {i} after update: {ds.ds_weight}') + print(f"loss weight of dataset {i} after update: {ds.ds_weight}") i += 1 blending_train_dataset = BlendableDataset(train_datasets, train_weights) blending_valid_dataset = None @@ -318,9 +313,7 @@ def get_train_valid_test_split_(splits_string, size): return splits_index -def get_normalized_weights_and_num_samples( - weights: List[float], num_samples: int -) -> Tuple[List[float], List[int]]: +def get_normalized_weights_and_num_samples(weights: List[float], num_samples: int) -> Tuple[List[float], List[int]]: # Normalize weights weight_sum = sum(weights) assert weight_sum> 0.0 @@ -346,12 +339,7 @@ def get_datasets_normalized_weights_and_num_samples( # samples left to feed to the network. weighted_num_samples = [] for weight in weights: - weighted_num_samples.append( - [ - int(math.ceil(val * weight * 1.005)) - for val in num_samples - ] - ) + weighted_num_samples.append([int(math.ceil(val * weight * 1.005)) for val in num_samples]) return weights, weighted_num_samples diff --git a/mftcoder_accelerate/src/data/gpt2_dataset.py b/mftcoder_accelerate/src/data/gpt2_dataset.py index 05aa632..12eeb87 100644 --- a/mftcoder_accelerate/src/data/gpt2_dataset.py +++ b/mftcoder_accelerate/src/data/gpt2_dataset.py @@ -41,7 +41,7 @@ def __init__( use_shared_fs=True, weighted_loss_mode=None, ds_weight=1.0, - train_mode='sft', + train_mode="sft", ): self.name = name @@ -50,9 +50,9 @@ def __init__( self.weighted_loss_mode = weighted_loss_mode self.ds_weight = ds_weight - - self.task_name = data_prefix.split('/')[-1] - + + self.task_name = data_prefix.split("/")[-1] + self.task_id = TASK2ID[self.task_name] # Checks @@ -114,14 +114,10 @@ def __getitem__(self, idx): else: # Otherwise, get the rest of the initial document. - input_ids_list = [ - self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + input_ids_list = [self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] if self.loss_mask_indexed_dataset is not None: - loss_mask_list = [ - self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + loss_mask_list = [self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] else: loss_mask_list = [] @@ -133,16 +129,12 @@ def __getitem__(self, idx): # And finally add the relevant portion of last document. input_ids_list.append( - self.input_ids_indexed_dataset.get( - self.doc_idx[doc_index_l], length=offset_l + 1 - ) + self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) ) if self.loss_mask_indexed_dataset is not None: loss_mask_list.append( - self.loss_mask_indexed_dataset.get( - self.doc_idx[doc_index_l], length=offset_l + 1 - ) + self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) ) input_ids = np.concatenate(input_ids_list) @@ -246,18 +238,12 @@ def __getitem__(self, idx): ) else: # Otherwise, get the rest of the initial document. 
- sample_list = [ - self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) # And finally add the relevant portion of last document. - sample_list.append( - self.indexed_dataset.get( - self.doc_idx[doc_index_l], length=offset_l + 1 - ) - ) + sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)) sample = np.concatenate(sample_list) return {"text": np.array(sample, dtype=np.int64)} @@ -313,10 +299,7 @@ def _build_index_mappings( or (not os.path.isfile(sample_idx_filename)) or (not os.path.isfile(shuffle_idx_filename)) ): - print_rank_0( - "> WARNING: could not find index map files, building " - "the indices on rank 0 ..." - ) + print_rank_0("> WARNING: could not find index map files, building " "the indices on rank 0 ...") # doc-idx. start_time = time.time() doc_idx = _build_doc_idx(documents, num_epochs, np_rng) @@ -338,13 +321,9 @@ def _build_index_mappings( # 我理解这里的num_samples应该是和入参的num_samples重名,这里只是为了计算构建所有索引的长度,从而决定是用int64还是int32 num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length if 2 * (num_samples + 1) < np.iinfo(np.int32).max: - sample_idx = helpers.build_sample_idx_int32( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch - ) + sample_idx = helpers.build_sample_idx_int32(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) else: - sample_idx = helpers.build_sample_idx_int64( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch - ) + sample_idx = helpers.build_sample_idx_int64(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) np.save(sample_idx_filename, sample_idx, allow_pickle=True) print_rank_0( "> elapsed time to build and save sample-idx mapping " @@ -360,7 +339,7 @@ def _build_index_mappings( "> elapsed time to build and save shuffle-idx mapping" " (seconds): {:4f}".format(time.time() - start_time) ) - + torch.distributed.barrier() # TODO: model parallel # This should be a barrier but nccl barrier assumes @@ -370,7 +349,7 @@ def _build_index_mappings( # torch.distributed.all_reduce(counts) # torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group()) # assert counts[0].item() == torch.distributed.get_world_size( - # group=mpu.get_io_parallel_group() + # group=mpu.get_io_parallel_group() # ) # Load mappings. 
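The hunks around `_build_index_mappings` follow a build-once, memory-map-later pattern: rank 0 materializes the doc/sample/shuffle index arrays, saves them as `.npy` files, and every rank then loads them with `mmap_mode="r"` after a barrier. A minimal, self-contained sketch of that pattern (not part of the patch; the file name and the toy index built here are placeholders, not the arrays produced by `_build_index_mappings`):

```python
import os
import numpy as np
import torch.distributed as dist

def load_or_build_index(filename, build_fn):
    rank = dist.get_rank() if dist.is_initialized() else 0
    # Only one process materializes the (potentially large) index array;
    # the others wait at the barrier and then memory-map the saved result.
    if rank == 0 and not os.path.isfile(filename):
        np.save(filename, build_fn(), allow_pickle=True)
    if dist.is_initialized():
        dist.barrier()
    # mmap_mode="r" keeps the array on disk and pages it in lazily, so a
    # large shuffle/sample index does not have to fit in host memory.
    return np.load(filename, allow_pickle=True, mmap_mode="r")

if __name__ == "__main__":
    # Toy usage: a shuffled range of document ids standing in for shuffle_idx.
    shuffle_idx = load_or_build_index(
        "toy_shuffle_idx.npy",
        lambda: np.random.RandomState(1234).permutation(10_000).astype(np.uint32),
    )
    print(shuffle_idx[:10])
```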
@@ -381,9 +360,7 @@ def _build_index_mappings( sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") print_rank_0("> loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") - print_rank_0( - " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) - ) + print_rank_0(" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)) print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) print_rank_0(" total number of epochs: {}".format(num_epochs)) diff --git a/mftcoder_accelerate/src/data/indexed_dataset.py b/mftcoder_accelerate/src/data/indexed_dataset.py index 9a26379..12ea9c2 100644 --- a/mftcoder_accelerate/src/data/indexed_dataset.py +++ b/mftcoder_accelerate/src/data/indexed_dataset.py @@ -44,17 +44,13 @@ def infer_dataset_impl(path): return None else: print(f"Dataset does not exist: {path}") - print( - "Path should be a basename that both .idx and .bin can be appended to get full filenames." - ) + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") return None def make_builder(out_file, impl, vocab_size=None): if impl == "mmap": - return MMapIndexedDatasetBuilder( - out_file, dtype=__best_fitting_dtype(vocab_size) - ) + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) else: return IndexedDatasetBuilder(out_file) @@ -62,9 +58,7 @@ def make_builder(out_file, impl, vocab_size=None): def make_dataset(path, impl, skip_warmup=False): if not IndexedDataset.exists(path): print(f"Dataset does not exist: {path}") - print( - "Path should be a basename that both .idx and .bin can be appended to get full filenames." - ) + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") return None if impl == "infer": impl = infer_dataset_impl(path) @@ -145,8 +139,7 @@ def read_index(self, path): with open(index_file_path(path), "rb") as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( - "Index file doesn't match expected format. " - "Make sure that --dataset-impl is configured properly." + "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." 
) version = f.read(8) assert struct.unpack(" 0: - self.i = max(map(lambda x: int(x.split('_')[1].split('.')[0]), os.listdir(out_dir))) + 1 - + self.i = max(map(lambda x: int(x.split("_")[1].split(".")[0]), os.listdir(out_dir))) + 1 + def add_data(self, data): self.data.append(data) - + def commit(self, archive_name=None): # TODO: streaming cctx = zstandard.ZstdCompressor(level=3) @@ -354,15 +373,18 @@ def commit(self, archive_name=None): if archive_name is None: archive_name = str(int(time.time())) - res = b''.join(map(lambda x: ("%016d" % len(x)).encode('UTF-8') + x, map(lambda x: x.encode('UTF-8'), self.data))) + res = b"".join( + map(lambda x: ("%016d" % len(x)).encode("UTF-8") + x, map(lambda x: x.encode("UTF-8"), self.data)) + ) cdata = cctx.compress(res) - with open(self.out_dir + '/data_' + str(self.i) + '_' + archive_name + '.dat.zst', 'wb') as fh: + with open(self.out_dir + "/data_" + str(self.i) + "_" + archive_name + ".dat.zst", "wb") as fh: fh.write(cdata) self.i += 1 self.data = [] + class JSONArchive: def __init__(self, out_dir): self.out_dir = out_dir @@ -370,17 +392,17 @@ def __init__(self, out_dir): self.data = [] self.i = 0 if os.path.exists(out_dir) and len(os.listdir(out_dir))> 0: - self.i = max(map(lambda x: int(x.split('_')[1].split('.')[0]), os.listdir(out_dir))) + 1 - + self.i = max(map(lambda x: int(x.split("_")[1].split(".")[0]), os.listdir(out_dir))) + 1 + def add_data(self, data): self.data.append(data) - + def commit(self): cctx = zstandard.ZstdCompressor(level=3) - - cdata = cctx.compress(json.dumps(self.data).encode('UTF-8')) - with open(self.out_dir + '/data_' + str(self.i) + '_' + str(int(time.time())) + '.json.zst', 'wb') as fh: + + cdata = cctx.compress(json.dumps(self.data).encode("UTF-8")) + with open(self.out_dir + "/data_" + str(self.i) + "_" + str(int(time.time())) + ".json.zst", "wb") as fh: fh.write(cdata) self.i += 1 - self.data = [] \ No newline at end of file + self.data = [] diff --git a/mftcoder_accelerate/src/data/multi_task_dataset.py b/mftcoder_accelerate/src/data/multi_task_dataset.py index fde298b..63c4b27 100644 --- a/mftcoder_accelerate/src/data/multi_task_dataset.py +++ b/mftcoder_accelerate/src/data/multi_task_dataset.py @@ -2,11 +2,14 @@ # @author Chaoyu Chen # @date 2023年8月18日 +Load dataset in a distributed way. 
""" + import os import json import math import time +import glob import numpy as np import torch from functools import partial @@ -192,18 +195,21 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, # 不同数据集在不同文件夹下 for dataset_index in range(len(data_prefixes)): - files = os.listdir(data_prefixes[dataset_index]) + # files = os.listdir(data_prefixes[dataset_index]) + # get all jsonl files and corresponding reading handler + if data_prefixes[dataset_index].endswith(".jsonl"): + files = [data_prefixes[dataset_index]] + else: + files = glob.glob(os.path.join(data_prefixes[dataset_index], "**/*.jsonl"), recursive=True) + cur_dataset_input_ids = [] cur_dataset_loss_mask = [] # support multiple jsonl files under task dir - for file in files: - file_name = data_prefixes[dataset_index] + "/" + file - if os.path.isdir(file_name): - continue + for file_name in files: fin = open(file_name, "r") print(f"[Global Rank {global_rank}] open file {file_name}") - if args.padding_mode == "padding" or args.padding_mode == "pack": + if args.padding_mode == "padding" or args.padding_mode == "pack" or args.padding_mode == "concat": for i, line in enumerate(fin): # pre-sharding if shard_data and i % world_size != global_rank: @@ -254,7 +260,8 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, cur_train_dataset = {"input_ids": cur_train_input_ids, "loss_mask": cur_train_loss_mask} cur_valid_dataset = {"input_ids": cur_valid_input_ids, "loss_mask": cur_valid_loss_mask} print(f"[Global Rank {global_rank}]shape of cur train dataset: {cur_train_dataset['input_ids'].shape}") - print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}") + if local_valid_num> 0: + print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}") cur_train_ds = GPT2FromRawDataset( "train", @@ -264,32 +271,32 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, weighted_loss_mode=args.weighted_loss_mode, ds_weight=splits[0], ) - cur_valid_ds = GPT2FromRawDataset( - "valid", - data_prefixes[dataset_index], - cur_valid_dataset, - args.seq_length, - weighted_loss_mode=args.weighted_loss_mode, - ds_weight=splits[1], - ) - all_train_datasets.append(cur_train_ds) - all_valid_datasets.append(cur_valid_ds) all_train_datasets_length.append(len(cur_train_ds)) - all_valid_datasets_length.append(len(cur_valid_ds)) + if local_valid_num> 0: + cur_valid_ds = GPT2FromRawDataset( + "valid", + data_prefixes[dataset_index], + cur_valid_dataset, + args.seq_length, + weighted_loss_mode=args.weighted_loss_mode, + ds_weight=splits[1], + ) + all_valid_datasets.append(cur_valid_ds) + all_valid_datasets_length.append(len(cur_valid_ds)) + else: + cur_valid_ds = None print(f"[Global Rank {global_rank}]num tokens: {num_tokens}") print(f"[Global Rank {global_rank}]effective token rate: {effective_token_rate}") num_tokens = [] ds_fn = partial(ds_weights_by_num_docs_sft) - train_loss_weights, valid_loss_weights = ( - ds_fn(all_train_datasets_length), - ds_fn(all_valid_datasets_length), - ) - + train_loss_weights = ds_fn(all_train_datasets_length) print(f"> train loss weights in rank {global_rank}: {train_loss_weights}") - print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}") + if all_valid_datasets_length: + valid_loss_weights = ds_fn(all_valid_datasets_length) + print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}") factor = 1 # calcualte common factor based 
on token cnt and total sample cnt @@ -299,9 +306,10 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, print(f"> common denomination factor for CE loss in rank {global_rank}: {factor}") train_sample_weights = [x / sum(all_train_datasets_length) for x in all_train_datasets_length] - valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length] print(f"> train sample weights in rank {global_rank}: {train_sample_weights}") - print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}") + if all_valid_datasets_length: + valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length] + print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}") # recompute global_train_num and global_valid_num @@ -312,22 +320,23 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, global_train_num_samples_tensor = global_train_num_samples_tensor.to(device) torch.distributed.all_reduce(global_train_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) global_train_num = global_train_num_samples_tensor.item() - - global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32) - global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device) - torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) - global_valid_num = global_valid_num_samples_tensor.item() print(f"> global train num in rank {global_rank}: {global_train_num}") - print(f"> global valid num in rank {global_rank}: {global_valid_num}") + + if local_valid_num> 0: + global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32) + global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device) + torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) + global_valid_num = global_valid_num_samples_tensor.item() + print(f"> global valid num in rank {global_rank}: {global_valid_num}") torch.distributed.barrier() - for i in range(len(all_train_datasets)): - print( - f"loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}" - ) blending_train_dataset = None if all_train_datasets: + for i in range(len(all_train_datasets)): + print( + f"loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}" + ) args.do_train = True for i in range(len(all_train_datasets)): all_train_datasets[i].update_ds_weight(train_loss_weights[i] / factor) @@ -338,12 +347,12 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, all_train_datasets, train_sample_weights, global_train_num, local_train_num ) - for i in range(len(all_valid_datasets)): - print( - f"loss weight of valid dataset {i} before update in rank {global_rank}: {all_valid_datasets[i].ds_weight}" - ) blending_valid_dataset = None if all_valid_datasets: + for i in range(len(all_valid_datasets)): + print( + f"loss weight of valid dataset {i} before update in rank {global_rank}: {all_valid_datasets[i].ds_weight}" + ) args.do_valid = True for i in range(len(all_valid_datasets)): all_valid_datasets[i].update_ds_weight(valid_loss_weights[i] / factor) diff --git a/mftcoder_accelerate/src/data/preprocess_data.py b/mftcoder_accelerate/src/data/preprocess_data.py index 3c912e6..f7226bd 100644 --- a/mftcoder_accelerate/src/data/preprocess_data.py +++ b/mftcoder_accelerate/src/data/preprocess_data.py @@ 
-9,7 +9,7 @@ import ftfy import glob -# print("In preprocess_data.py, sys path:", sys.path) +# print("In preprocess_data_new.py, sys path:", sys.path) from tokenizer import build_tokenizer @@ -32,7 +32,7 @@ def content_format(content: str): # change chinese punctuation to english ones # text = text.translate(table) - + # if not content.endswith("\n"): content += "\n" return content @@ -101,6 +101,13 @@ def is_question_answer_format(data): return False +def is_query_answer_format(data): + if "query" in data and "answer" in data: + return True + else: + return False + + class Encoder(object): tokenizer = None @@ -125,7 +132,7 @@ def encode(self, text): if len(text_ids)> 0: doc_ids.append(text_ids) if self.args.append_eod: - doc_ids[-1].append(Encoder.tokenizer.eod_id) + doc_ids[-1].append(Encoder.tokenizer.eos_token_id) ids[key] = doc_ids return ids, len(text) @@ -163,6 +170,8 @@ def encode(self, data, verbose=False): data_type = "question_response" elif is_question_answer_format(data): data_type = "question_answer" + elif is_query_answer_format(data): + data_type = "query_answer" elif is_chatml_format(data): data_type = "chatML" elif is_text_format(data): @@ -209,7 +218,7 @@ def _tokenize_fields(self, data, data_type): else: raise ValueError(f"tokenize_mode does not support {self.mode}, please use sft or pretrain") - sft_end_marker_ids = [Encoder.tokenizer.eod_id] + sft_end_marker_ids = [Encoder.tokenizer.eos_token_id] # uniform SST,SFT,MFT input_ids = [] loss_mask = [] @@ -236,7 +245,7 @@ def _tokenize_fields(self, data, data_type): content_ids = self.pure_encode(user_marker + content + assistant_marker) input_ids += content_ids loss_mask += [0] * len(content_ids) - elif role == "bot" or role == "assistant": + elif role == "bot" or role == "assistant" or role == "gpt": content_ids = self.pure_encode(content) + sft_end_marker_ids input_ids += content_ids loss_mask += [1] * len(content_ids) @@ -324,16 +333,16 @@ def _tokenize_fields(self, data, data_type): yield {} def padding(self, input_ids, loss_mask): - pad_id = Encoder.tokenizer.pad_id + pad_id = Encoder.tokenizer.pad_token_id assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)}> {self.seq_length}" input_ids += [pad_id] * (self.seq_length - len(input_ids)) loss_mask += [0] * (self.seq_length - len(loss_mask)) return {"input_ids": input_ids, "loss_mask": loss_mask} -def find_jsonl_fnames(inputs): +def find_jsonl_fnames(paths): fnames = [] - for p in inputs.split(","): + for p in paths: if not os.path.isdir(p): if p.endswith(".jsonl"): print(f"loading from {p}") diff --git a/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py b/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py new file mode 100644 index 0000000..82e0f5d --- /dev/null +++ b/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py @@ -0,0 +1,206 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V2. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 102400): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+
+    ```python
+>>> from transformers import DeepseekV2Model, DeepseekV2Config
+
+>>> # Initializing a Deepseek-V2 style configuration
+>>> configuration = DeepseekV2Config()
+
+>>> # Initializing a model from the Deepseek-V2 style configuration
+>>> model = DeepseekV2Model(configuration)
+
+>>> # Accessing the model configuration
+>>> configuration = model.config
+    ```"""
+
+    model_type = "deepseek_v2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=102400,
+        hidden_size=4096,
+        intermediate_size=11008,
+        moe_intermediate_size = 1407,
+        num_hidden_layers=30,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        n_shared_experts = None,
+        n_routed_experts = None,
+        ep_size = 1,
+        routed_scaling_factor = 1.0,
+        kv_lora_rank = 512,
+        q_lora_rank = 1536,
+        qk_rope_head_dim = 64,
+        v_head_dim = 128,
+        qk_nope_head_dim = 128,
+        topk_method = 'greedy',
+        n_group = None,
+        topk_group = None,
+        num_experts_per_tok = None,
+        moe_layer_freq = 1,
+        first_k_dense_replace = 0,
+        norm_topk_prob = False,
+        scoring_func = 'softmax',
+        aux_loss_alpha = 0.001,
+        seq_aux = True,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=100000,
+        eos_token_id=100001,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.n_shared_experts = n_shared_experts
+        self.n_routed_experts = n_routed_experts
+        self.ep_size = ep_size
+        self.routed_scaling_factor = routed_scaling_factor
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.topk_method = topk_method
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.num_experts_per_tok = num_experts_per_tok
+        self.moe_layer_freq = moe_layer_freq
+        self.first_k_dense_replace = first_k_dense_replace
+        self.norm_topk_prob = norm_topk_prob
+        self.scoring_func = scoring_func
+        self.aux_loss_alpha = aux_loss_alpha
+        self.seq_aux = seq_aux
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
\ No newline at end of file
diff --git a/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py b/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py
new file mode 100644
index 0000000..d1d5e88
--- /dev/null
+++ b/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py
@@ -0,0 +1,1925 @@
+# coding=utf-8
+# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_deepseek import DeepseekV2Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
+if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV2Config" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class DeepseekV2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) + + +class DeepseekV2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len> self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len> self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + 
beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
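+
+    Note:
+        Unlike the reference Llama implementation, this variant first de-interleaves the last dimension of `q`
+        and `k` (reordering interleaved (x0, x1), (x2, x3), ... pairs into two contiguous halves) before applying
+        the rotation, as done by the reshape/transpose at the top of the function body; presumably the DeepseekV2
+        projections emit their rotary dimensions in interleaved order.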
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV2MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None + ) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1, dtype=torch.float32) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) + elif self.topk_method == "group_limited_greedy": + group_scores = ( + scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weight, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + + ### norm gate to sum 1 + if self.top_k> 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * 
self.routed_scaling_factor + ### expert-level computation auxiliary loss + if self.training and self.alpha> 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, self.n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( + dim=1 + ).mean() * self.alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class DeepseekV2MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size> 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i>= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + # save dtype before computation + input_dtype = hidden_states.dtype + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave( + self.num_experts_per_tok, dim=0 + ) + y = torch.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y 
= self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + # keep dtype same after moe forward + return y.to(input_dtype) + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size> 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = ( + tokens_per_expert_group.view(self.ep_size, -1) + .sum(1) + .cpu() + .numpy() + .tolist() + ) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank + ).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size> 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = 
DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 +class DeepseekV2FlashAttention2(DeepseekV2Attention): + """ + DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV2FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + # print(f"dtype of hidden_states: {hidden_states.dtype}") + # print(f"dtype of q_proj: {self.q_proj.weight.dtype}") + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV2Attention, + "flash_attention_2": DeepseekV2FlashAttention2, +} + + +class DeepseekV2DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV2MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx>= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV2MLP(config) + ) + self.input_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + # print(f"1. dtype of residual: {residual.dtype}") + + hidden_states = self.input_layernorm(hidden_states) + # print(f"2. dtype of hidden_states before attn: {hidden_states.dtype}") + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + # print(f"3. dtype of hidden_states after attn: {hidden_states.dtype}") + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + # print(f"4. dtype of hidden_states after post layernorm: {hidden_states.dtype}") + hidden_states = self.mlp(hidden_states) + # print(f"5. 
dtype of hidden_states after mlp: {hidden_states.dtype}") + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2PreTrainedModel(PreTrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2Model(DeepseekV2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." 
+ ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, 
config_class=_CONFIG_FOR_DOC
+    )
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
+
+        >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs,
+    ):
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = 
past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1]> input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length>= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1]> max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels> 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes> 1 if no padding token is defined." 
+ ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels> 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py b/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py new file mode 100644 index 0000000..d243771 --- /dev/null +++ b/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py @@ -0,0 +1,38 @@ +from typing import List, Optional, Union + + +from transformers.models.llama import LlamaTokenizerFast + + +class DeepseekTokenizerFast(LlamaTokenizerFast): + + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[str, List[str]]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `List[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `List[str]`: The decoded token(s). 
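+
+        Example (illustrative sketch only; the model path and token ids below are placeholders, not values
+        taken from this repository):
+
+        ```python
+        >>> tokenizer = DeepseekTokenizerFast.from_pretrained("path/to/deepseek-v2-model")
+        >>> tokenizer.convert_ids_to_tokens([100, 200, 300])
+        >>> tokenizer.convert_ids_to_tokens(100)
+        >>> tokenizer.convert_ids_to_tokens([100, 200, 300], skip_special_tokens=True)
+        ```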
+ """ + if isinstance(ids, int): + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + token = self._tokenizer.id_to_token(index) + tokens.append(token if token is not None else "") + return tokens + + def _convert_id_to_token(self, index: int) -> Optional[str]: + token = self._tokenizer.id_to_token(int(index)) + return token if token is not None else "" diff --git a/mftcoder_accelerate/src/mpt/mpt_accelerate.py b/mftcoder_accelerate/src/mpt/mpt_accelerate.py new file mode 100644 index 0000000..5d187c9 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_accelerate.py @@ -0,0 +1,494 @@ +""" +# @author Chaoyu Chen +# @date 2024年6月1日 +# @module mpt_accelerate.py + +Accelerate + DeepSpeed + Full-parameter + Multi-task + Pre-training/Continue Training/Finetuning + +Entry +""" + +import os +import sys +import argparse +import math +import logging +import json +import time +from tqdm.auto import tqdm +import transformers +import numpy as np +import torch +from torch import nn +from dataclasses import dataclass +from datasets import Dataset +import datasets +from torch.utils.data import DataLoader +from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + get_linear_schedule_with_warmup, + set_seed, + BitsAndBytesConfig, + get_scheduler, +) + +from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration +from accelerate.logging import get_logger +from datetime import timedelta +from accelerate.utils import InitProcessGroupKwargs +from transformers.optimization import Adafactor + +# insert src as import path +current_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_path)) +sys.path.insert(0, parent_dir) + +from tokenizer import build_tokenizer +from data.multi_task_dataset import load_dataset_from_jsonl, compile_helper +from data.data_utils import load_dataset_from_bin +from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK +from mpt.mpt_trainer import MptTrainer +from mpt.mpt_arguments import MptTrainArgs +from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS + + +logger = get_logger(__name__) + + +def get_task_mask(args, task_id): + task_num = len(TASK2ID) + task_mask = torch.zeros(task_id.shape[0], task_num) + task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1 + + return task_mask + + +def get_attention_mask_and_position_ids(data): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + batch_size, seq_length = data.size() + + attention_mask = torch.ones((batch_size, seq_length), device=data.device) + + # Position ids. 
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data).clone()
+
+    return attention_mask, position_ids
+
+
+@dataclass
+class DataCollatorForMFTDataset(object):
+    args: None
+
+    def __call__(self, instances):
+        (input_ids, loss_mask, weights, task_id) = tuple(
+            [instance.get(key, None) for instance in instances]
+            for key in ("input_ids", "loss_mask", "weight", "task_id")
+        )
+
+        result_batch = {}
+        """
+        outputs = model(
+            input_ids=batch['input_ids'],
+            attention_mask=batch['attention_mask'],
+            # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']),
+            # labels=(batch['labels'], batch['loss_mask']),
+            position_ids=batch['position_ids'])
+        """
+
+        # if loss_mask is not None:
+        loss_mask = torch.tensor(np.array(loss_mask)).long()
+        last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1)
+        if self.args.use_dynamic_padding:
+            # get last non-padding position
+            max_pos = last_one_pos.max().item() + 1
+        else:
+            max_pos = loss_mask.shape[-1]
+
+        if self.args.tokenize_mode == "sst" and self.args.padding_mode == "pack":
+            # Compatible with sst + pack tokenization: the last position is dirty data and needs to be dropped
+            result_batch["loss_mask"] = loss_mask.float()[:, 1 : max_pos - 1].contiguous()
+            input_ids = torch.tensor(np.array(input_ids)).long()
+            result_batch["input_ids"] = input_ids[:, : max_pos - 2].contiguous()
+            result_batch["labels"] = input_ids[:, 1 : max_pos - 1].contiguous()
+        else:
+            result_batch["loss_mask"] = loss_mask.float()[:, 1:max_pos].contiguous()
+            input_ids = torch.tensor(np.array(input_ids)).long()
+            # print(f"shape of input_ids: {input_ids.shape}")
+            result_batch["input_ids"] = input_ids[:, : max_pos - 1].contiguous()
+            result_batch["labels"] = input_ids[:, 1:max_pos].contiguous()
+
+        # Get the masks and position ids.
+
+        # If you need to stay compatible with non-GPT models, this is the place to handle it
+        if self.args.model_type in ["antglm"]:
+            (result_batch["attention_mask"], result_batch["position_ids"]) = get_attention_mask_and_position_ids(
+                data=result_batch["input_ids"]
+            )
+        elif self.args.model_type in ["mixtral", "mtx-qwen2", "qwen2_moe"]:
+            batch_size, seq_length = result_batch["input_ids"].shape
+            # bsz * seq_length
+            range_tensor = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1)
+            # attention_mask for padding tokens
+            attention_mask = (range_tensor <= last_one_pos.reshape(batch_size, 1)).long()
+            result_batch["attention_mask"], result_batch["position_ids"] = attention_mask, None
+        else:
+            # For decoder-only models, transformers will create them.
+ result_batch["attention_mask"], result_batch["position_ids"] = None, None + + if task_id is not None: + task_id = torch.tensor(np.array(task_id)) + result_batch["task_mask"] = get_task_mask(self.args, task_id) # bsz * task_num + result_batch["task_id"] = task_id + + return result_batch + + +def pprint_args(args, accelerator): + # 计算所有键的最大字符串长度 + max_key_length = max(len(str(key)) for key in vars(args).keys()) + + message = "" + message += "====" * 60 + "\n" + message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n" + message += "====" * 60 + "\n" + accelerator.print(message) + accelerator.print("GPU: {}".format(torch.cuda.current_device())) + + +def prepare_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_config", type=str, default=None) + + parser.add_argument("--data_paths", type=str, default=None) + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument("--tb_dir", type=str, default=None) + parser.add_argument("--pretrained_model_path", type=str, default=None) + parser.add_argument("--micro_batch_size", type=int, default=None) + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--distributed_type", type=str, default="deepspeed") + + parsed = parser.parse_args() + # get json configs + with open(parsed.train_config, "r") as f: + train_config = json.load(f) + + # parse args from cofig.json + # args = argparse.Namespace(**train_config) + args = MptTrainArgs(**train_config) + + # override args by cli arguments + if parsed.data_paths: + args.data_paths = parsed.data_paths + if parsed.output_dir: + args.output_dir = parsed.output_dir + if parsed.tb_dir: + args.tb_dir = parsed.tb_dir + if parsed.pretrained_model_path: + args.pretrained_model_path = parsed.pretrained_model_path + args.vocab_file = parsed.pretrained_model_path + if parsed.micro_batch_size: + args.per_device_train_batch_size = parsed.micro_batch_size + args.per_device_eval_batch_size = parsed.micro_batch_size + if parsed.model_type: + args.model_type = parsed.model_type + + args.distributed_type = parsed.distributed_type + + # refactor args + + args.vocab_file = args.pretrained_model_path + + args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(","))) + "]" + + # generate TASK2ID, ID2TASK + generate_task_id(args.data_paths) + + if args.weighted_loss_mode == "coba": + args.task_weights = [1.0] * len(ID2TASK) + elif args.task_weights is not None: + args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")] + assert len(args.task_weights) == len(ID2TASK), f"length of task_weights must equal to length of data_paths" + else: + args.task_weights = [1.0] * len(ID2TASK) + + return args + + +def main(): + t0 = time.time() + os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["HF_HUB_OFFLINE"] = "false" + # get input args, set TASK2ID, ID2TASK, refactor args + args = prepare_args() + + # fix randomness + if args.seed is not None: + set_seed(args.seed) + + # define accelerator + init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds)) + + if args.distributed_type and args.distributed_type.lower() == "fsdp": + fsdp_plugin = FullyShardedDataParallelPlugin( + # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + limit_all_gathers=True, + sync_module_states=True, + use_orig_params=True, + cpu_offload=False, + ) + 
accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + fsdp_plugin=fsdp_plugin, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + else: + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + + # print key infos + accelerator.print("In mft_accelerate.py, sys path:", sys.path) + accelerator.print(f"transformers.__version__: {transformers.__version__}") + + # get world_size + args.world_size = accelerator.num_processes + + # backup args + pprint_args(args, accelerator) + if accelerator.is_main_process: + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "args.json"), "w") as f: + json.dump(args.dict(), f, indent=2) + + # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest + latest = None + if os.path.exists(os.path.join(args.output_dir, "latest")): + with open(os.path.join(args.output_dir, "latest"), "r") as fl: + latest = json.load(fl) + accelerator.print(f"[INFO] Existing latest: {latest}") + + if args.auto_resume and args.resume_from_checkpoint is None and latest: + args.resume_from_checkpoint = latest["latest_ckpt"] + + # logger + logging.basicConfig( + format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + # compile Cpp helper + compile_helper() + time.sleep(10) + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # get global_rank and local rank for current process + global_rank = accelerator.process_index + local_rank = accelerator.local_process_index + print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}") + + # TASK2ID, ID2TASK + # generate_task_id(args.data_paths) + + # multi task blendable dataset(sharded) + if args.load_raw_dataset: + print_rank_0("> load raw jsonl dataset") + train_dataset, valid_dataset = load_dataset_from_jsonl( + args=args, shard_data=True, world_size=args.world_size, global_rank=global_rank, local_rank=local_rank + ) + else: + print_rank_0("> load tokenized bin dataset, refer to gpt_neox indexed dataset") + train_dataset, valid_dataset, _ = load_dataset_from_bin(args=args) + + t1 = time.time() + logger.info(f"dataset loading time: {t1 - t0:.4f}") + + # cuda memory + free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3) + max_memory = f"{free_in_GB - 2}GB" + n_gpus = torch.cuda.device_count() + max_memory = {i: max_memory for i in range(n_gpus)} + accelerator.print("max memory: ", max_memory, n_gpus) + + # # 是否要加入新的special tokens + # num_added_toks = tokenizer.tokenizer.add_special_tokens(["", ""]) + # accelerator.print("We have added", num_added_toks, "tokens") + # accelerator.print(f"role marker tokens {tokenizer.convert_tokens_to_ids('')} {tokenizer.convert_tokens_to_ids('')}, resized tokenizer_size: {len(tokenizer)}") + + # creating model + ModelClass = MODEL_TYPES[args.model_type] + if args.model_type in SUPPORT_IN_TRANSFORMERS: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers") + model = 
ModelClass.from_pretrained( + args.pretrained_model_path, + attn_implementation=args.attn_implementation, + torch_dtype=torch.bfloat16, + ) + else: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code") + model = ModelClass.from_pretrained( + args.pretrained_model_path, + torch_dtype=torch.bfloat16, + ) + + # build a tokenizer for possible resizing or saving + tokenizer = build_tokenizer(args) + # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, + # i.e. the length of the tokenizer. + # 如果新增special tokens, 需要resize input embedding 和output embedding + # model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) + + model.gradient_checkpointing_enable() + + if args.saving_limit is None or not isinstance(args.saving_limit, int) or args.saving_limit < 1: + # saving_limit is set automatically if needed + args.saving_limit = 2 + accelerator.print( + "[WARNING]saving_limit must be a integer greater than 1 in Full-Parameters Training, we set it to 2" + ) + + t2 = time.time() + if accelerator.is_main_process: + logging.info(f"model loading time: {t2 - t1:.4f}") + + model.config.use_cache = False # silence the warnings. Please re-enable for inference! + if hasattr(model.config, "use_logn_attn"): + model.config.use_logn_attn = False # special for qwen model + # load balance for moe training + if hasattr(model.config, "output_router_logits"): + model.config.output_router_logits = True + model_config = model.config + accelerator.print(model.config) + + # dataloader + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_train_batch_size, + pin_memory=True, + drop_last=True, + ) + if valid_dataset: + valid_dataloader = DataLoader( + valid_dataset, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_eval_batch_size, + pin_memory=True, + drop_last=True, + ) + else: + valid_dataloader = None + + # optimizer + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.print("DISTRIBUTED TRAINING USING DEEPSPEED") + # from deepspeed.ops.adam import FusedAdam as Adam + # adam_optimizer = Adam + adam_optimizer = torch.optim.AdamW + elif accelerator.distributed_type == DistributedType.FSDP: + accelerator.print("DISTRIBUTED TRAINING USING FSDP") + model = accelerator.prepare(model) + adam_optimizer = torch.optim.AdamW + else: + raise ValueError("Only support DeepSpeed and FSDP") + + optimizer = adam_optimizer( + model.parameters(), + weight_decay=args.weight_decay, + lr=args.learning_rate, + betas=(0.9, 0.999), + ) + # for group in optimizer.param_groups: + # group.setdefault("initial_lr", group["lr"]) + + # Scheduler and math around the number of training steps. 
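+    # Worked example with illustrative numbers (not taken from any config in this repo): with 1000 train batches
+    # and gradient_accumulation_steps = 4, num_update_steps_per_epoch = ceil(1000 / 4) = 250, so 4 epochs give
+    # max_train_steps = 1000 update steps; a float num_warmup_steps such as 0.05 is converted below to
+    # int(max_train_steps * 0.05) // num_processes, i.e. it is read as a ratio rather than an absolute step count.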
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + if isinstance(args.num_warmup_steps, float) and args.num_warmup_steps < 1.0: + args.num_warmup_steps = int(args.max_train_steps * args.num_warmup_steps) // accelerator.num_processes + accelerator.print(f"num_warmup_steps: {args.num_warmup_steps}") + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + # scheduler_specific_kwargs={"last_epoch": scheduler_last_ep} + ) + # prepare all + if accelerator.distributed_type == DistributedType.DEEPSPEED: + if valid_dataloader: + (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, valid_dataloader, optimizer, lr_scheduler + ) + else: + (model, train_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, optimizer, lr_scheduler + ) + + # prepare all except model, which is prepared before + elif accelerator.distributed_type == DistributedType.FSDP: + if valid_dataloader: + (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, valid_dataloader, lr_scheduler + ) + else: + (optimizer, train_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, lr_scheduler + ) + print(model.device) + accelerator.print(model) + # accelerator.print(model.config) + + # Recalculate our total training steps as the size of the training dataloader may have changed. 
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterward we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # zero 3 flag + is_ds_zero_3 = False + if getattr(accelerator.state, "deepspeed_plugin", None): + is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3 + accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}") + elif getattr(accelerator.state, "fsdp_plugin", None): + accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}") + + trainer = MptTrainer( + accelerator=accelerator, + model=model, + model_config=model_config, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + tokenizer=tokenizer, + num_update_steps_per_epoch=num_update_steps_per_epoch, + total_train_dataset_size=len(train_dataset), + args=args, + ) + trainer.accelerate_train() + + +if __name__ == "__main__": + main() diff --git a/mftcoder_accelerate/src/mpt/mpt_arguments.py b/mftcoder_accelerate/src/mpt/mpt_arguments.py new file mode 100644 index 0000000..8045421 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_arguments.py @@ -0,0 +1,161 @@ +""" +# @author Chaoyu Chen +# @date 2024/6/1 + +MPT training arguments +""" + +from dataclasses import dataclass, asdict +from typing import List, Union + + +@dataclass +class MptTrainArgs: + # train data paths on shared FS + data_paths: Union[str, List[str]] + + # output dir for saving adaptors in peft or full ckpts in full-parameter training + output_dir: str + + # tensorboard dir for saving tensorboard logs + tb_dir: str + + # pretrained_model_path, on which is the model you want to train + pretrained_model_path: str + + # model type of pretrained_model_path, support llama|qwen|starcoder|baichuan|chatglm2 + model_type: str + + # load from raw jsonl file or tokenized binary file + load_raw_dataset: bool = True + + # weights of loss calculation for each task, None means equal weights + task_weights: Union[None, str] = None + + # weights of data sampling, leave it None + data_weights: Union[None, str] = None + + # hf loading model low_cpu_mem_usage + low_cpu_mem_usage: bool = True + + # train/valid/test split + data_split: str = "98,2,0" + + # padding or pack or concat + padding_mode: str = "padding" + + # sft or sst + tokenize_mode: str = "sft" + + # case3 or case4 + weighted_loss_mode: str = "case3" + + # mircro train batch size + per_device_train_batch_size: int = 8 + + # micro eval batch size, always same as micro train batch size + per_device_eval_batch_size: int = 8 + + # HF AutoTokenizer is supported, maybe more types + tokenizer_type: str = "AutoTokenizer" + + # initial lr + learning_rate: float = 5e-5 + + # minimum lr + min_lr: float = 5e-6 + + # weight decay + weight_decay: float = 0.01 + + # gradient_accumulation_steps + gradient_accumulation_steps: int = 1 + + # lr_scheduler_type + lr_scheduler_type: str = "cosine" + + # num_warmup_steps + num_warmup_steps: Union[int, float] = 0.05 + + # num_train_epochs + num_train_epochs: int = 4 + + # seed for reproducing + seed: int = 1234 + + # seq_length, context length + seq_length: int = 4096 + + # path of adaptor which is resumed from, None for not resuming training + resume_from_checkpoint: Union[None, str] = None + + # auto resume from latest ckpt if job restarted + 
auto_resume: bool = True + + # num of steps for logging training loss + log_interval: int = 10 + + # num of steps for saving ckpt + checkpointing_steps: int = 100 + + # num of steps for evaluation(eval_loss), better same as checkpointing steps + evaluation_steps: int = 100 + + # max train steps, if None, depends on num_train_epochs + max_train_steps: Union[None, int] = None + + # if checkpointing every epoch, maybe True in sst + epoch_checkpointing: bool = False + + # save transformers model(safetensors) + save_transformers_model: bool = False + + # shuffle before train/valid split + shuffle_before_split: bool = True + + # DDP random sampler + use_random_sampler: bool = True + + # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point + early_stopping: bool = True + early_stopping_stall_num: int = 5 + + # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota. + saving_limit: Union[None, int] = None + + # if dynamic padding + use_dynamic_padding: bool = True + + # warm-up steps for CoBa, recommand the number of valid batches + coba_warmup_steps: int = 100 + # history length of sample valid loss used to fit the slope curve in CoBa, recommand [2*coba_warmup_steps,5*coba_warmup_steps] + coba_history_length: int = 200 + # temperature for divergence factor in CoBa + coba_tau: int = 5 + # iteration interval of update per task train weight in CoBa + coba_update_interval: int = 1 + # the number of mini valid batches sampled at each updated iteration interval + coba_sample_valid_num: int = 1 + + # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2} + attn_implementation: str = "flash_attention_2" + + # role markers, which are prompt template before each role: system, user and assistant + # role_markers: {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} + role_markers: Union[None, dict] = None + + distributed_type: Union[None, str] = None + + init_timeout_seconds: Union[None, int] = 3600 + + # legacy, leave them + use_xformers: bool = True + trust_remote_code: bool = True + weight_by_num_documents: bool = True + make_vocab_size_divisible_by: int = 32 + model_parallel_size: int = 1 + use_slow_tokenizer: bool = False + world_size: int = 8 + + def dict(self): + return {k: str(v) for k, v in asdict(self).items()} diff --git a/mftcoder_accelerate/src/mpt/mpt_trainer.py b/mftcoder_accelerate/src/mpt/mpt_trainer.py new file mode 100644 index 0000000..b5e2da8 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_trainer.py @@ -0,0 +1,606 @@ +""" +# @author qumu +# @date 2024/6/6 +# @module mpt_trainer.py + +MPT/MCT/MFT Full-parameter Trainer +""" + +import gc +import os +import sys +import threading +import argparse +import math +import logging +import json +import time +import transformers +import numpy as np +import psutil +import shutil +import torch +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from typing import List, Optional, Tuple, Union +from tqdm.auto import tqdm +from accelerate.logging import get_logger +from accelerate import Accelerator +from transformers import set_seed + +# sys.path.append("..") +from utils.common_utils import generate_task_id, TASK2ID, ID2TASK +from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func + +logger = get_logger(__name__) + + +def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): + # create path if not exist + if 
not os.path.exists(save_path): + os.makedirs(save_path) + + # copy each file in files_list to save_path + for filename in files_list: + src_file = os.path.join(mode_path, filename) + + # copy only if src exists + if os.path.exists(src_file): + dest_file = os.path.join(save_path, filename) + + # copy + shutil.copy(src_file, dest_file) + print(f"Copied {filename} to {save_path}") + else: + print(f"File {filename} does not exist in {mode_path}") + + +def check_existing_ckpts(output_dir): + prefix = "step_" + + if not os.path.exists(output_dir): + return [] + # list all files and dirs + contents = os.listdir(output_dir) + + # find dirs starts with "step_" + matching_folders = [ + folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix) + ] + + return matching_folders + + +def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps): + """ + extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training + """ + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + logger.info(f"Resume from exact Epoch {starting_epoch}: completed_steps {completed_steps}") + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + completed_steps = int(training_difference.replace("step_", "")) + starting_epoch = completed_steps // num_update_steps_per_epoch + resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps + logger.info(f"Resume from Epoch {starting_epoch} + step {resume_step}: completed_steps {completed_steps}") + + return starting_epoch, completed_steps, resume_step + + +def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): + for key, value in log_dict.items(): + summary_writer.add_scalar(f"{key}", value, completed_steps) + + +def delete_ckpts_over_limits(output_dir, saving_limit, best_step): + """delete ckpts more than saving_limits except for the best_step ckpt""" + existing_ckpts = check_existing_ckpts(output_dir) + logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}") + # sorted only step num ascendingly + ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts]) + # delete the oldest steps except for the best step at present + if len(ckpt_steps)> saving_limit: + deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step] + # print(deletable_steps[:len(ckpt_steps) - saving_limit]) + for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]: + shutil.rmtree(os.path.join(output_dir, f"step_{del_step}")) + logger.info(f"Removed ckpt step_{del_step}") + + +class MptTrainer: + """ + Multitask Pre-train/Continue-train Trainer with Full-parameters training. 
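+
+    It drives the loop over the prepared accelerator, model, dataloaders, optimizer, lr_scheduler and tokenizer
+    passed to __init__, and handles periodic evaluation, checkpointing, early stopping and tensorboard logging.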
+ """ + + def __init__( + self, + accelerator: Accelerator, + model, + model_config, + train_dataloader, + valid_dataloader, + optimizer, + lr_scheduler, + tokenizer, + num_update_steps_per_epoch, + total_train_dataset_size, + args, + ): + self.accelerator = accelerator + self.model = model + # hf model config + self.model_config = model_config + self.train_dataloader = train_dataloader + self.valid_dataloader = valid_dataloader + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.tokenizer = tokenizer + self.num_update_steps_per_epoch = num_update_steps_per_epoch + self.total_train_dataset_size = total_train_dataset_size + # training arguments + self.args = args + # tensorboard writer + self.summary_writer = SummaryWriter(log_dir=args.tb_dir) + + def print(self, msg: str): + """ + accelerator print, default on main process + Args: + msg: + + Returns: + + """ + self.accelerator.print(msg) + + def touch(self, batch, num_tokens=10): + """touch first and last tokens and labels for debugging usage""" + self.print( + f"step 1 batch shape: {batch['input_ids'].shape},\n" + f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}" + f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}" + ) + self.print(f"first {num_tokens} input_ids and loss_mask") + for pt in range(1): + self.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") + self.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") + + @staticmethod + def format_tensor(tensor, n): + return list(map(lambda x: round(x, n), tensor.tolist())) + + def accelerate_saving_states(self, output_dir: str, completed_steps: int): + """ + Saving lora adaptor or full checkpoint using accelerator + Args: + output_dir: exact dir for saving ckpt + completed_steps: + + Returns: + + """ + self.accelerator.wait_for_everyone() + logger.info(f"[CHECKPOINT] Saving checkpoint states") + self.accelerator.save_state(output_dir) + self.accelerator.wait_for_everyone() + + # save safetensors for direct inference if needed + if self.args.save_transformers_model: + logger.info(f"[CHECKPOINT] Saving transformers(hf) model", main_process_only=True) + unwrapped_model = self.accelerator.unwrap_model(self.model) + # self.print(f"unwrapped model type {type(unwrapped_model)}") + unwrapped_model.save_pretrained( + output_dir, + is_main_process=self.accelerator.is_main_process, + save_function=self.accelerator.save, + state_dict=self.accelerator.get_state_dict(self.model), + ) + self.accelerator.wait_for_everyone() + + # tokenizer saving and bug dummy ckpt cleaning. 
+ if self.accelerator.is_main_process: + if self.args.model_type.lower() == "deepseek": + copy_tokenizer_files( + self.args.pretrained_model_path, ["tokenizer.json", "tokenizer_config.json"], output_dir + ) + else: + self.tokenizer.save_pretrained(output_dir) + + sf = os.path.join(output_dir, "model.safetensors") + index_file = os.path.join(output_dir, "model.safetensors.index.json") + if os.path.isfile(sf) and os.path.isfile(index_file): + self.print(f"Remove bug dummy ckpt {sf}") + os.remove(sf) + + # save latest info + if self.accelerator.is_main_process: + latest = { + "latest_ckpt": output_dir, + "lr": self.optimizer.param_groups[0]["lr"], + } + with open(os.path.join(self.args.output_dir, "latest"), "w") as f: + json.dump(latest, f, indent=2) + + logger.info( + f"[CHECKPOINT][complete_steps={completed_steps}], states {output_dir} saved, latest: {latest}", + main_process_only=True, + ) + self.accelerator.wait_for_everyone() + + def accelerate_monitor( + self, + reduce_loss, + reduce_task_loss, + reduce_task_exist, + completed_steps, + coba_status=None, + ): + """ + gather reduce_loss and reduce_task_loss from all N devices. + train logging and tensorboarding. + """ + # gather reduce_loss and reduce_task_loss from all N devices + reduce_losses = self.accelerator.gather(reduce_loss).detach().float() + reduce_task_losses = self.accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK)) + reduce_task_exists = self.accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK)) + + # get train loss and per-task train loss + train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps) + # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (self.args.log_interval * self.args.gradient_accumulation_steps) + train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0) + + # logging and writing tensorboard + logger.info( + f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]" + f"[train_task_loss={self.format_tensor(train_task_loss, 4)}]" + f"[gather shape={list(reduce_losses.shape)}]" + f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]", + main_process_only=True, + ) + if coba_status is not None: + if completed_steps> coba_status.coba_warmup_steps: + coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum( + coba_status.log_per_task_weight + ) + else: + coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) + logger.info( + f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True + ) + train_log_dict = {"Loss/train": train_loss} + for i in range(len(ID2TASK)): + train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i] + if coba_status is not None: + train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item() + + if self.accelerator.is_main_process: + write_tensorboard(self.summary_writer, train_log_dict, completed_steps) + + if coba_status is not None: + coba_status.log_per_task_weight = torch.zeros(len(ID2TASK)) + + def accelerate_evaluate( + self, + completed_steps, + step, + min_eval_loss, + stall_num, + best_step, + ): + """ + evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices. + eval logging and tensorboarding. 
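+
+        Returns:
+            a tuple (eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step): the gathered mean valid
+            loss, the per-task valid losses, and the updated early-stopping bookkeeping values.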
+ """ + losses = [] + accumulated_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + accumulated_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + for valid_step, valid_batch in enumerate(self.valid_dataloader): + with torch.no_grad(): + outputs = self.model( + input_ids=valid_batch["input_ids"], + attention_mask=valid_batch["attention_mask"], + position_ids=valid_batch["position_ids"], + return_dict=True, + ) + + loss, task_loss, _ = loss_func_mft( + outputs=outputs, + labels=valid_batch["labels"], + task_mask=valid_batch["task_mask"], + task_id=valid_batch["task_id"], + weighted_loss_mode=self.args.weighted_loss_mode, + loss_mask=valid_batch["loss_mask"], + task_weights=self.args.task_weights, + ) + + losses.append(self.accelerator.gather(loss.repeat(self.args.per_device_eval_batch_size))) + accumulated_task_loss += task_loss.detach().float() + accumulated_task_exist += (task_loss != 0.0).detach().float() + + self.accelerator.wait_for_everyone() + valid_batch_num = len(losses) + gathered_size = losses[0].shape + losses = torch.cat(losses) + # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK)) + task_losses = self.accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK)) + task_exists = self.accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK)) + + try: + eval_loss = torch.mean(losses) + # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num + eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0) + if eval_loss <= min_eval_loss: + min_eval_loss = eval_loss + stall_num = 0 + best_step = completed_steps + else: + stall_num += 1 + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info( + f"[EVAL][completed_steps={completed_steps}]" + f"[eval_loss={eval_loss:.6f}][eval_task_loss={self.format_tensor(eval_task_loss, 4)}]" + f"[perplexity={perplexity:.4f}][valid_batch_num={valid_batch_num}]" + f"[gather_size={list(gathered_size)}]", + main_process_only=True, + ) + eval_log_dict = { + "Loss/valid": eval_loss, + "Perplexity/valid": perplexity, + "Epochs": round(completed_steps / self.num_update_steps_per_epoch, 2), + } + for i in range(len(ID2TASK)): + eval_log_dict[f"{ID2TASK[i]}_loss/valid"] = eval_task_loss[i] + + if self.accelerator.is_main_process: + write_tensorboard(self.summary_writer, eval_log_dict, completed_steps) + + return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step + + def accelerate_train(self): + # Train! + if self.args.seed is not None: + set_seed(self.args.seed) + + global_batch_size = ( + self.args.per_device_train_batch_size + * self.accelerator.num_processes + * self.args.gradient_accumulation_steps + ) + logger.info("************************************** Running training ****************************************") + logger.info(f" Num examples = {self.total_train_dataset_size}") + logger.info(f" Num Epochs = {self.args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") + logger.info(f" Total global train batch size (w. 
parallel, distributed & accumulation) = {global_batch_size}") + logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") + logger.info(f" Total optimization(update/completed) steps = {self.args.max_train_steps}") + logger.info(f" Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_train_epochs}") + logger.info("************************************************************************************************") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process) + + # set starting_epoch, completed_steps and resume_step of train_dataloader + completed_steps = 0 + starting_epoch = 0 + resume_step = None + + if self.args.resume_from_checkpoint: + self.accelerator.load_state(self.args.resume_from_checkpoint) + self.accelerator.print(f"Resumed from checkpoint: {self.args.resume_from_checkpoint}") + path = os.path.basename(self.args.resume_from_checkpoint) + starting_epoch, completed_steps, resume_step = extract_epochs_and_steps( + path, self.num_update_steps_per_epoch, self.args.gradient_accumulation_steps + ) + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + # monitor minimum eval_loss, stalling num, and best_step + min_eval_loss = float("inf") + stall_num = 0 + best_step = None + + # monitor train loss + reduce_loss = torch.tensor(0.0).to(self.model.device) + reduce_aux_loss = torch.tensor(0.0).to(self.model.device) + reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + per_task_weight = self.args.task_weights + + if self.args.weighted_loss_mode == "coba": + self.model.eval() + eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate( + completed_steps, + 0, + min_eval_loss, + stall_num, + best_step, + ) + self.model.train() + coba_status = CoBaStatus( + self.args.coba_warmup_steps, + self.args.coba_history_length, + self.args.coba_tau, + self.args.coba_update_interval, + self.args.coba_sample_valid_num, + self.valid_dataloader, + ) + coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device) + coba_status.sample_valid_batch(self.model, completed_steps) + logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) + else: + coba_status = None + + # Training Loop! + for epoch in range(starting_epoch, self.args.num_train_epochs): + # set_epoch + # self.train_dataloader.set_epoch(epoch) + + # if we early stop by some ckpts not converging + if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: + break + + if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = self.accelerator.skip_first_batches(self.train_dataloader, resume_step) + else: + active_dataloader = self.train_dataloader + tail_num = len(active_dataloader) - len(active_dataloader) % self.args.gradient_accumulation_steps + print(f"length of dataloader: {len(active_dataloader)}") + + self.model.train() + # Inner Loop! 
+ for step, batch in enumerate(active_dataloader): + if step == tail_num: + break + with self.accelerator.accumulate(self.model): + if step == 0: + self.touch(batch, num_tokens=10) + # forward + outputs = self.model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + position_ids=batch["position_ids"], + return_dict=True, + ) + + if ( + self.args.weighted_loss_mode == "coba" + and self.accelerator.sync_gradients + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps>= self.args.coba_warmup_steps + ): + with torch.no_grad(): + per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps) + coba_status.log_per_task_weight += per_task_weight + # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True) + + # loss + loss, task_loss, _ = loss_func_mft( + outputs=outputs, + labels=batch["labels"], + task_mask=batch["task_mask"], + task_id=batch["task_id"], + weighted_loss_mode=self.args.weighted_loss_mode, + loss_mask=batch["loss_mask"], + task_weights=per_task_weight, + ) + + # accelerator.print(len(outputs.router_logits), outputs.router_logits[0], outputs.router_logits[-1]) + # accelerator.print(batch['attention_mask'].shape, batch['attention_mask']) + aux_loss = None + if hasattr(self.model_config, "output_router_logits") and self.model_config.output_router_logits: + if hasattr(self.model_config, "num_local_experts"): + num_experts = self.model_config.num_local_experts + elif hasattr(self.model_config, "num_experts"): + num_experts = self.model_config.num_experts + else: + raise ValueError("model has no attribute num_local_experts or num_experts") + aux_loss = load_balancing_loss_func( + outputs.router_logits, + num_experts, + self.model_config.num_experts_per_tok, + batch["attention_mask"], + ) + aux_loss = self.model_config.router_aux_loss_coef * aux_loss.to(loss.device) + loss += aux_loss # make sure to reside in the same device + + # backward + self.accelerator.backward(loss) + # print(self.lr_scheduler.state_dict(), self.accelerator.process_index) + # update(sync_gradients) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + # support args.min_lr + if self.optimizer.param_groups[0]["lr"] <= self.args.min_lr: + self.optimizer.param_groups[0]["lr"] = self.args.min_lr + + # accumulate resuce_loss and reduce_task_loss in a log_interval + if not torch.isnan(loss): + reduce_loss += loss.detach().float() + if aux_loss and not torch.isnan(aux_loss): + reduce_aux_loss += aux_loss.detach().float() + # self.print("task loss devices: ", reduce_task_loss.device, task_loss.device) + reduce_task_loss += task_loss.detach().float() + reduce_task_exist += (task_loss != 0).detach().float() + + # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done. 
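+                # Note: with gradient_accumulation_steps = N, sync_gradients is True only once every N
+                # micro-batches, so completed_steps counts optimizer updates rather than micro-batches.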
+ if self.accelerator.sync_gradients: + if ( + self.args.weighted_loss_mode == "coba" + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps>= 1 + ): + coba_status.sample_valid_batch(self.model, completed_steps) + # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) + + # progress_bar.update(1) + completed_steps += 1 + # monitoring training process and logging and tensorboarding + if completed_steps % self.args.log_interval == 0: + progress_bar.update(self.args.log_interval) + if reduce_aux_loss> 0.0: + self.print(f"[INFO] aux_loss: {reduce_aux_loss/self.args.log_interval}") + self.accelerate_monitor( + reduce_loss, + reduce_task_loss, + reduce_task_exist, + completed_steps, + coba_status, + ) + # reset reduce_loss + reduce_loss = torch.tensor(0.0).to(self.model.device) + reduce_aux_loss = torch.tensor(0.0).to(self.model.device) + reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + + # steps checkpointing + if self.args.checkpointing_steps and completed_steps % self.args.checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if self.args.output_dir is not None: + output_dir = os.path.join(self.args.output_dir, output_dir) + self.accelerate_saving_states(output_dir, completed_steps) + + # steps evaluation + if completed_steps % self.args.evaluation_steps == 0 and self.valid_dataloader: + self.model.eval() + eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = self.accelerate_evaluate( + completed_steps, + step, + min_eval_loss, + stall_num, + best_step, + ) + self.model.train() + + # delete ckpts over args.saving_limit + if self.accelerator.is_main_process and self.args.saving_limit: + delete_ckpts_over_limits(self.args.output_dir, self.args.saving_limit, best_step) + + # early stoppin when stalling more than args.early_stopping_stall_num + if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: + self.print(f"[WARNING] Early stopping at {completed_steps}") + break + + if completed_steps>= self.args.max_train_steps: + break + self.accelerator.wait_for_everyone() + + # epoch checkpointing + if self.args.epoch_checkpointing: + output_dir = f"epoch_{epoch + 1}" + if self.args.output_dir is not None: + output_dir = os.path.join(self.args.output_dir, output_dir) + self.accelerate_saving_states(output_dir, completed_steps) + + self.summary_writer.close() diff --git a/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py b/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py new file mode 100644 index 0000000..ca4347e --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- + +import argparse +import multiprocessing +import os +import sys +import random +import time +import tqdm +import glob +import json +import numpy as np + + +# 将父目录的父目录加入path +current_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_path)) +grandparent_dir = os.path.dirname(parent_dir) +sys.path.append(grandparent_dir) + +from tokenizer import init_tokenizer +from pack_encoder import PackSSTBinEncoder, load_tokenizer +from data import indexed_dataset + +from threading import Semaphore +from colorama import Fore +import lm_fmt as lmd + + +def yield_from_files(files: list, semaphore): + """ + Iterator over input documents + + :param fnames: list of filenames 
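+    :param semaphore: threading.Semaphore used to keep this reader from running ahead of the downstream encoder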
+ """ + def yielder(fname, semaphore): + with open(fname, 'r') as f: + for line in f: + semaphore.acquire() + yield json.loads(line) + + for fname in files: + semaphore.acquire() + yield from yielder(fname, semaphore) + +def yield_from_files2(fnames: list, semaphore, sample_percent): + """ + Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts / + other compressed formats. Also filters out empty documents. + + :param fnames: list of filenames + """ + def yielder(fname, semaphore): + try: + sample_interval = int(1/sample_percent) + for f in filter(lambda x: x, lmd.Reader(fname).stream_data(key=None)): + rand_value = random.randint(1, sample_interval*100) + if rand_value % sample_interval != 0: + continue + semaphore.acquire() + + #rand_value = random.randint(1, sample_interval*100) + #if rand_value % sample_interval != 0: + # yield None + + yield f + except Exception as e: + print('####Exception:', e.args) + yield None + + for fname in fnames: + semaphore.acquire() + + yield from yielder(fname, semaphore) + + +def print_example_doc(input_ids, tokenizer): + print(Fore.YELLOW + f'INPUT IDS len: {len(input_ids)}') + print(Fore.BLUE + f'INPUT IDS:\n {input_ids}\n\n') + + print(Fore.RED + f'DETOKENIZED INPUT:\n{tokenizer.decode(input_ids)}') + + +def core_process(encoded_docs, semaphore, seq_length, tokenizer, encoder, builder, output_idx_file): + """ + core of Data Pack SFT processing + """ + input_ids_key = 'input_ids' + + proc_start = time.time() + total_bytes_processed = 0 + pbar = tqdm.tqdm() + sentence_droped = 0 + loss_token_cnt = 0 + + print("PRINT BEFORE STREAM PROCESS DATA") + + print_example_count = 0 + for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + + # release semaphore so `yield_from_files` can add another file to the buffer + semaphore.release() + + # add each tokenized document / sentence, + # For sft, each document has only one sample + input_ids_sentence = doc[input_ids_key][0] + if len(input_ids_sentence) < 1: + sentence_droped += 1 + continue + + builder.add_item(np.array(input_ids_sentence, dtype=builder.dtype)) + builder.end_document() + #builder.finalize_without_close(output_idx_file) + #builder.add_item_and_end_document_and_finalize(np.array(input_ids_sentence, dtype=builder.dtype), output_idx_file) + + # print the first packed sample as example + if print_example_count < 1: + print_example_doc(input_ids_sentence, tokenizer) + print_example_count += 1 + + # log progress + if i % 100 == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + pbar.set_description( + f"Processed {i} documents ({i / elapsed} docs/s, {mbs} MB/s)." + ) + if i != 0: + pbar.update(100) + + # 尾部处理 + builder.finalize(output_idx_file) + + print(Fore.RED + "\ndroped docs: {}".format(sentence_droped)) + + +def process_dataset(dataset_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent): + """ + Re-organize samples in the given data path into a Data Pack file. 
+ """ + + # get all jsonl files and corresponding reading handler + files = glob.glob(os.path.join(dataset_path, '**/*.jsonl'), recursive=True) + + # build a semaphore object to stop `yield_from_files` from getting ahead + # of encoder.encode and hence building up memory + semaphore = Semaphore(1000 + parallel_num) + + # build sample iterator + sample_iterator = yield_from_files2(files, semaphore, sample_percent) + + # load tokenizer + # tokenizer = load_tokenizer(model_path, tokenizer_type) + tokenizer = init_tokenizer(model_path) + print('TOKEN of id=2:', tokenizer.convert_ids_to_tokens(2)) + print('ID of :', tokenizer.convert_tokens_to_ids('')) + print('TOKEN of id=0:', tokenizer.convert_ids_to_tokens(0)) + print('ID of :', tokenizer.convert_tokens_to_ids('')) + + # init encoder + encoder = PackSSTBinEncoder(seq_length, model_path) + + # create writer builder + key = "input_ids" + output_prefix = os.path.join(output_path, dataset_name) + output_bin_file = "{}_{}.bin".format( + output_prefix, key + ) + output_idx_file = "{}_{}.idx".format( + output_prefix, key + ) + builder = indexed_dataset.make_builder( + output_bin_file, + impl="mmap", + vocab_size=tokenizer.vocab_size, + ) + + if parallel_num> 1: + pool = multiprocessing.Pool(parallel_num, initializer=encoder.initializer) + encoded_docs = pool.imap(encoder.encode, sample_iterator, chunksize=32) + else: + encoder.initializer() + encoded_docs = (encoder.encode(doc) for doc in sample_iterator) + + if dataset_name is None: + dataset_path = dataset_path[:-1] if dataset_path.endswith(os.path.sep) else dataset_path + dataset_name = dataset_path.split(os.path.sep)[-1] + + core_process(encoded_docs, semaphore, seq_length, tokenizer, encoder, builder, output_idx_file) + + +def main(data_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent): + """ + Entry + """ + + process_dataset(data_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate a packed jsonl file in the Data Pack SFT way.") + parser.add_argument('--model-path', type=str, help='Path of a pretrained model which contains tokenizer-related files.') + parser.add_argument('--parallel', type=int, default=1, help='The num of parallel processing.') + parser.add_argument('--output-path', type=str, help='Path to store the genered result file.') + parser.add_argument('--data-path', type=str, default=None, help='Path of files to be processed') + parser.add_argument('--seq-length', type=int, default=4096, help='The max input length (i.e. the max number of tokens in a sample)') + # parser.add_argument('--eod-token-id', type=int, default=2, help='EOD token id') + # parser.add_argument('--pad-token-id', type=int, default=0, help='PAD token id') + # parser.add_argument('--tokenizer-type', type=str, choices=["LLAMATokenizer", None], default=None, help="What type of tokenizer to use. Default is None.") + parser.add_argument('--dataset-name', type=str, default=None, help='The generated result dataset name. 
The folder name will be token by default.') + parser.add_argument('--sample-percent', type=float, default=1.0, help='Sample percentage') + + args = parser.parse_args() + print('ARGS\n', '\n'.join([str(key) + ':' + str(value) for key,value in vars(args).items()])) + + random.seed(9999) + + main(args.data_path, args.output_path, args.model_path, args.parallel, args.seq_length, args.dataset_name, args.sample_percent) diff --git a/mftcoder_accelerate/src/offline_tokenization/lm_fmt.py b/mftcoder_accelerate/src/offline_tokenization/lm_fmt.py new file mode 100644 index 0000000..c922859 --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/lm_fmt.py @@ -0,0 +1,360 @@ +import os +import zstandard +import ujson as json +import time +import tarfile +import codecs +from functools import reduce +import jsonlines +import io +from zipfile import ZipFile +import gzip +from math import ceil +import mmap +import multiprocessing as mp +from pathlib import Path + +VALID_EXTENSIONS = ['openwebtext.tar.xz', '_data.xz', '.dat.zst', '.jsonl', '.jsonl.zst', '.jsonl.zst.tar', '.json.zst', '.txt', '.zip', '.tar.gz', '.json.gz', '.gz'] + +def has_valid_extension(file): + return any([file.endswith(ext) for ext in VALID_EXTENSIONS]) + +def _listdir_or_file(x): + if isinstance(x, list): + return reduce(lambda x, y: x + y, map(listdir_or_file, sorted(x))) + if os.path.isfile(x): + return [x] + elif os.path.isdir(x): + return [str(Path(x) / fn) for fn in sorted(os.listdir(x))] + else: + raise FileNotFoundError(f"{x} not found") + +def listdir_or_file(x): + return list(filter(has_valid_extension, _listdir_or_file(x))) + +def tarfile_reader(file, streaming=False): + # we need our own tarfile parser because `tarfile` doesn't work well for + # big tarfiles; it seems to be reading the entire file to get a list of + # where all the files are - but we don't need that because we just need + # to see each file once. surprisingly, `tarfile` doesn't expose any + # facilities for this. the only options are 1. load the entire tarfile + # and then query by filename or 2. extract to disk - and neither of + # these is what we want. + + offset = 0 + paxfilesize = None + while True: + hdr = file.read(512) + offset += 512 + + # https://www.gnu.org/software/tar/manual/html_node/Standard.html + # end at 135 not 136 because of 0円 terminator + if hdr[124:135] == b'0円'*11: + # end of record + break + + fname = hdr[:100].split(b'0円')[0] + + # if the file is too big to fit in the size field, tarfiles will actually + # include a PaxHeader with the size in it, applicable to the immediate next file. 
+ if paxfilesize is not None: + size = paxfilesize + paxfilesize = None + else: + size = int(hdr[124:135], 8) + + padded_size = ceil(size / 512) * 512 + + # for handling PaxHeader files (which contain extra metadata about file size) and directories + # https://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_03 + type = chr(hdr[156]) + + if type == 'x': + meta = file.read(padded_size)[:size] + def kv(x): + return x.decode('utf-8').split(' ')[1].split('=') + paxfileattrs = { + kv(x)[0]: kv(x)[1] + for x in meta.split(b'\n') if x + } + paxfilesize = int(paxfileattrs['size']) + + offset += padded_size + continue + elif type != '0' and type != '0円': + if streaming: + file.seek(padded_size, os.SEEK_CUR) + else: + file.read(padded_size) + offset += padded_size + continue + + if streaming: + # skip directory entries + if size != 0: + mmo = mmap.mmap(file.fileno(), length=offset + size, access=mmap.ACCESS_READ) + mmo.seek(offset) + yield mmo + + file.seek(padded_size, os.SEEK_CUR) + else: + yield file.read(padded_size)[:size] + offset += padded_size + +def handle_jsonl(jsonl_reader, get_meta, autojoin_paragraphs, para_joiner, key='text'): + for ob in jsonl_reader: + # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility. + if isinstance(ob, str): + assert not get_meta + yield ob + continue + + if key is None: + yield ob + continue + + text = ob[key] + + if autojoin_paragraphs and isinstance(text, list): + text = para_joiner.join(text) + + if get_meta: + yield text, (ob['meta'] if 'meta' in ob else {}) + else: + yield text + + +class Reader: + def __init__(self, in_path): + self.in_path = in_path + + def stream_data(self, get_meta=False, threaded=False, key=None): + if not threaded: + yield from self._stream_data(get_meta, key=key) + return + + q = mp.Queue(1000) + p = mp.Process(target=self._stream_data_threaded, args=(q, get_meta), kwargs={"key": key}) + p.start() + while p.is_alive(): + res = q.get() + if res is None: break + yield res + + def _stream_data_threaded(self, q, get_meta=False): + for data in self._stream_data(get_meta): + q.put(data) + q.put(None) + + def _stream_data(self, get_meta=False, key="text"): + self.f_name = "" + files = listdir_or_file(self.in_path) + if not files: + raise FileNotFoundError(f"No valid file(s) found in {self.in_path}") + for f in files: + self.f_name = f + if f == 'openwebtext.tar.xz': + assert not get_meta + + yield from self.read_owt(f) + elif 'urlsf_subset' in f and f.endswith('_data.xz'): + assert not get_meta + + yield from self.read_owt_subset(f) + elif f.endswith('.dat.zst'): + assert not get_meta + + yield from self.read_dat(f) + elif f.endswith('.jsonl'): + yield from self.read_jsonl(f, get_meta, key=key) + elif f.endswith('.jsonl.zst'): + yield from self.read_jsonl_zst(f, get_meta, key=key) + elif f.endswith('.jsonl.zst.tar'): + yield from self.read_jsonl_tar(f, get_meta, key=key) + elif f.endswith('.json.zst'): + assert not get_meta + + yield from self.read_json(f) + elif f.endswith('.txt'): + assert not get_meta + + yield from self.read_txt(f) + elif f.endswith('.zip'): + assert not get_meta + + yield from self.read_zip(f) + elif f.endswith('.tar.gz'): + assert not get_meta + + yield from self.read_tgz(f) + elif f.endswith('.json.gz'): + assert not get_meta + + yield from self.read_jsongz(f) + elif f.endswith('.gz'): + assert not get_meta + + yield from self.read_gz(f) + else: + # shouldn't be reached + print(f'Skipping {f} as streaming for that filetype is not implemented') + + 
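For orientation, here is a minimal sketch of how this `Reader` is consumed upstream by `yield_from_files2` in `concat_sst_bin_tokenization.py`. It is illustrative only and not part of the patch; the directory path is a placeholder, and it assumes `lm_fmt.py` is importable from the working directory.

```python
# Illustrative sketch: stream every non-empty JSON object from a directory of
# .jsonl / .jsonl.zst files, the way yield_from_files2 drives lm_fmt.Reader.
import lm_fmt as lmd


def iter_documents(path):
    """Yield each non-empty JSON object found under `path`."""
    # key=None makes handle_jsonl yield the raw object instead of a "text" field.
    for doc in lmd.Reader(path).stream_data(key=None):
        if doc:
            yield doc


if __name__ == "__main__":
    for doc in iter_documents("/path/to/jsonl_dir"):  # placeholder path
        print(type(doc))
```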
def read_txt(self, file): + with open(file, 'r') as fh: + yield fh.read() + + def read_zip(self, file): + archive = ZipFile(file, 'r') + for f in archive.namelist(): + yield archive.read(f).decode('UTF-8') + + def read_tgz(self, file): + gz = gzip.open(file) + yield from (x.decode('utf-8') for x in tarfile_reader(gz, streaming=False)) + + def read_gz(self, file): + with gzip.open(file, 'rb') as f: + for line in f: + yield line.decode('utf-8') + + def read_jsongz(self, file): + for line in self.read_gz(file): + yield json.loads(line) + + def read_json(self, file): + with open(file, 'rb') as fh: + cctx = zstandard.ZstdDecompressor() + reader = cctx.stream_reader(fh) + ob = json.load(reader) + yield from ob + + def read_dat(self, file): + with open(file, 'rb') as fh: + cctx = zstandard.ZstdDecompressor() + reader = cctx.stream_reader(fh) + while True: + ln = reader.read(16).decode('UTF-8') + if not ln: + break + + ln = int(ln) + + yield reader.read(ln).decode('UTF-8') + + def read_jsonl(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n', key='text'): + with jsonlines.open(file) as rdr: + yield from handle_jsonl(rdr, get_meta, autojoin_paragraphs, para_joiner, key) + + def read_jsonl_zst(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n', key='text'): + with open(file, 'rb') as fh: + cctx = zstandard.ZstdDecompressor() + reader = io.BufferedReader(cctx.stream_reader(fh)) + rdr = jsonlines.Reader(reader) + yield from handle_jsonl(rdr, get_meta, autojoin_paragraphs, para_joiner, key) + + def read_jsonl_tar(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n', key='text'): + with open(file, 'rb') as fh: + for f in tarfile_reader(fh, streaming=True): + cctx = zstandard.ZstdDecompressor() + reader = io.BufferedReader(cctx.stream_reader(f)) + rdr = jsonlines.Reader(reader) + yield from handle_jsonl(rdr, get_meta, autojoin_paragraphs, para_joiner, key) + f.close() + + def read_owt(self, file): + tar = tarfile.open(file, encoding='utf-8') + utf8reader = codecs.getreader('utf-8') + + for name in tar.getmembers(): + fp = tar.extractfile(name) + inner_tar = tarfile.open(fileobj=fp, encoding='utf-8') + for inner_name in inner_tar.getmembers(): + inner_fp = utf8reader(inner_tar.extractfile(inner_name)) + contents = inner_fp.read() + yield contents + + def read_owt_subset(self, file): + utf8reader = codecs.getreader('utf-8') + tar = tarfile.open(file, encoding='utf-8') + for name in tar.getmembers(): + fp = utf8reader(tar.extractfile(name)) + contents = fp.read() + yield contents + + +class Archive: + def __init__(self, out_dir, compression_level=3, threads=8): + self.out_dir = out_dir + os.makedirs(out_dir, exist_ok=True) + self.i = 0 + + self.fh = open(self.out_dir + '/current_chunk_incomplete', 'wb') + self.cctx = zstandard.ZstdCompressor(level=compression_level, threads=threads) + self.compressor = self.cctx.stream_writer(self.fh) + + + def add_data(self, data, meta={}): + self.compressor.write(json.dumps({'text': data, 'meta': meta}).encode('UTF-8') + b'\n') + + def commit(self, archive_name='default'): + fname = self.out_dir + '/data_' + str(self.i) + '_time' + str(int(time.time())) + '_' + archive_name + '.jsonl.zst' + self.compressor.flush(zstandard.FLUSH_FRAME) + + self.fh.flush() + self.fh.close() + os.rename(self.out_dir + '/current_chunk_incomplete', fname) + self.fh = open(self.out_dir + '/current_chunk_incomplete', 'wb') + self.compressor = self.cctx.stream_writer(self.fh) + + self.i += 1 + + +class DatArchive: + def 
__init__(self, out_dir): + self.out_dir = out_dir + os.makedirs(out_dir, exist_ok=True) + self.data = [] + self.i = 0 + if os.path.exists(out_dir) and len(os.listdir(out_dir))> 0: + self.i = max(map(lambda x: int(x.split('_')[1].split('.')[0]), os.listdir(out_dir))) + 1 + + def add_data(self, data): + self.data.append(data) + + def commit(self, archive_name=None): + # TODO: streaming + cctx = zstandard.ZstdCompressor(level=3) + + if archive_name is None: + archive_name = str(int(time.time())) + + res = b''.join(map(lambda x: ("%016d" % len(x)).encode('UTF-8') + x, map(lambda x: x.encode('UTF-8'), self.data))) + cdata = cctx.compress(res) + + with open(self.out_dir + '/data_' + str(self.i) + '_' + archive_name + '.dat.zst', 'wb') as fh: + fh.write(cdata) + + self.i += 1 + self.data = [] + +class JSONArchive: + def __init__(self, out_dir): + self.out_dir = out_dir + os.makedirs(out_dir, exist_ok=True) + self.data = [] + self.i = 0 + if os.path.exists(out_dir) and len(os.listdir(out_dir))> 0: + self.i = max(map(lambda x: int(x.split('_')[1].split('.')[0]), os.listdir(out_dir))) + 1 + + def add_data(self, data): + self.data.append(data) + + def commit(self): + cctx = zstandard.ZstdCompressor(level=3) + + cdata = cctx.compress(json.dumps(self.data).encode('UTF-8')) + with open(self.out_dir + '/data_' + str(self.i) + '_' + str(int(time.time())) + '.json.zst', 'wb') as fh: + fh.write(cdata) + + self.i += 1 + self.data = [] diff --git a/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py b/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py new file mode 100644 index 0000000..0678e27 --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py @@ -0,0 +1,335 @@ +from transformers import AutoTokenizer +from tokenizer import init_tokenizer + + +def load_tokenizer(model_path, tokenizer_type=None): + """ + Load tokenizer from the given + """ + + def load_tokenizer_manual(model_path, tokenizer_type): + """ + Load tokenizer by the concrete Tokenizer class instead of AutoTokenizer + """ + try: + if tokenizer_type.lower() == "LlamaTokenizer".lower(): + return LlamaTokenizer.from_pretrained(model_path) + + raise Exception(f"Unsupported tokenizer type {tokenizer_type}") + except: + raise Exception(f"Unable to load tokenizer {tokenizer_type} from the given path: {model_path}") + + def load_tokenizer_auto(model_path): + """ + Load tokenizer from the given path by HuggingFace AutoTokenizer + """ + try: + # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True) # support CodeLlama + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return tokenizer + except: + raise Exception( + f'Unable to load tokenizer from the given path: {model_path} using auto mode.\nPlease specify the tokenizer type with the command argument "--tokenizer-type" and retry.' 
+ ) + + # First, try to load tokenizer by huggingface AutoTokenizer, If fail, try another manual way + try: + return load_tokenizer_auto(model_path) + except Exception as e: + print(str(e)) + if tokenizer_type is not None: + try: + tokenizer = load_tokenizer_manual(model_path, tokenizer_type) + return tokenizer + except Exception as ee: + raise ee + + +class PackPFTEncoder: + """ + A sample of this format will be: + <|role_start|>system<|role_end|> content of system_1 + <|role_start|>human<|role_end|> content of human_1 + <|role_start|>bot<|role_end|> content of bot_1 + <|endoftext|> + <|role_start|>system<|role_end|> content of system_2 + <|role_start|>human<|role_end|> content of human_2 + <|role_start|>bot<|role_end|> content of bot_2 + <|endoftext|> + <|role_start|>human<|role_end|> content of human_3 + <|role_start|>bot<|role_end|> content of bot_3 + <|endoftext|> + .... + <|role_start|>human<|role_end|> content of human_n + <|role_start|>bot<|role_end|> content of bot_n + <|endoftext|> + + <|pad|><|pad|>...<|pad|> + + system part is optional, i.e. '<|role_start|>system<|role_end|> content of system_i' + """ + + def __init__(self, seq_length, eod_token_id, pad_token_id, role_start_tag, role_end_tag, mode="pft"): + self.mode = mode + self.seq_length = seq_length + self.eod_token_id = eod_token_id + self.pad_token_id = pad_token_id + self.role_start_tag = role_start_tag + self.role_end_tag = role_end_tag + + def initializer(self, model_path, tokenizer_type=None): + # Use Encoder class as a container for global data + assert model_path is not None + self.tokenizer = load_tokenizer(model_path, tokenizer_type) + + def encode(self, item): + encode_res = { + "input_ids": [], + } + + item_len = sum([len(x["content"]) for x in item["chat_rounds"]]) + for token_res in self.tokenize_chat_prompt(item): + for k, v in token_res.items(): + encode_res[k].append(v) + return encode_res, item_len + + def tokenize_chat_prompt(self, item): + # role_start_marker = self.tokenizer.encode(self.role_start_tag, add_special_tokens=False) + # role_end_marker = self.tokenizer.encode(self.role_end_tag, add_special_tokens=False) + end_marker = [self.eod_token_id] + + input_ids = [] + raw_input = "" + # loss_mask = [] + for chat_round in item["chat_rounds"]: + role = chat_round["role"].strip() + # skip system prompt + # if role == 'system': + # continue + + content = chat_round["content"] + content = content if content.endswith("\n") else f"{content}\n" + text = f"{self.role_start_tag}{role}{self.role_end_tag}{content}" + chat_input_ids = self.tokenizer.encode(text, add_special_tokens=False) + + if role != "bot": + chat_input_ids = chat_input_ids + else: + chat_input_ids = chat_input_ids + end_marker + + input_ids += chat_input_ids + + # if this sample's length is more than the specified max length, drop it + # here, we don't add padding tokens for a single sample, however, we will append padding tokens for a combinated samaple + if len(input_ids)> self.seq_length: + yield {} + else: + yield {"input_ids": input_ids} + + def padding(self, key, data): + assert len(data) <= self.seq_length, f"padding sequence: {len(data)}> {self.seq_length}" + if key == "input_ids": + return data + [self.pad_token_id] * (self.seq_length - len(data)) + + if key == "loss_mask": + return data + [0] * (self.seq_length - len(data)) + + raise Exception("Should not reach here. 
There must be something wrong.") + + +class PackSFTEncoder: + """ + A sample of this format will be: + <|role_start|>system<|role_end|> content of system_1 + <|role_start|>human<|role_end|> content of human_1 + <|role_start|>bot<|role_end|> content of bot_1 + <|endoftext|> + <|role_start|>system<|role_end|> content of system_2 + <|role_start|>human<|role_end|> content of human_2 + <|role_start|>bot<|role_end|> content of bot_2 + <|endoftext|> + <|role_start|>human<|role_end|> content of human_3 + <|role_start|>bot<|role_end|> content of bot_3 + <|endoftext|> + .... + <|role_start|>human<|role_end|> content of human_n + <|role_start|>bot<|role_end|> content of bot_n + <|endoftext|> + + <|pad|><|pad|>...<|pad|> + + system part is optional, i.e. '<|role_start|>system<|role_end|> content of system_i' + """ + + def __init__(self, seq_length, eod_token, role_start_tag, role_end_tag, mode="sft"): + self.mode = mode + self.seq_length = seq_length + self.eod_token = eod_token + self.role_start_tag = role_start_tag + self.role_end_tag = role_end_tag + + def initializer(self, model_path, tokenizer_type=None): + # Use Encoder class as a container for global data + assert model_path is not None + self.tokenizer = load_tokenizer( + model_path, tokenizer_type + ) # AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + def encode(self, item): + encode_res = {"input_ids": [], "raw_input": []} + + item_len = sum([len(x["content"]) for x in item["chat_rounds"]]) + for token_res in self.tokenize_chat_prompt(item): + for k, v in token_res.items(): + encode_res[k].append(v) + return encode_res, item_len + + def tokenize_chat_prompt(self, item): + role_start_marker = self.tokenizer.encode(self.role_start_tag, add_special_tokens=False) + role_end_marker = self.tokenizer.encode(self.role_end_tag, add_special_tokens=False) + end_marker = [self.tokenizer.convert_tokens_to_ids(self.eod_token)] + + input_ids = [] + raw_input = "" + # loss_mask = [] + for chat_round in item["chat_rounds"]: + role = chat_round["role"] + content = chat_round["content"] + content = content if content.endswith("\n") else f"{content}\n" + chat_input_ids = self.tokenizer.encode(content, add_special_tokens=False) + role_input_ids = self.tokenizer.encode(role, add_special_tokens=False) + role_raw_input = "" + + if role != "bot": + # chat_loss_mask = [0] * len(role_start_marker) + [0] * len(role_input_ids) + [0] * len(role_end_marker) + [0] * len(chat_input_ids) + chat_input_ids = role_start_marker + role_input_ids + role_end_marker + chat_input_ids + role_raw_input = ROLE_START_MARKER + role + ROLE_END_MARKER + content + elif role == "human": + # chat_loss_mask = [0] * len(role_start_marker) + [0] * len(role_input_ids) + [0] * len(role_end_marker) + [1] * len(chat_input_ids) + [1] * len(end_marker) + chat_input_ids = role_start_marker + role_input_ids + role_end_marker + chat_input_ids + end_marker + role_raw_input = ROLE_START_MARKER + role + ROLE_END_MARKER + content + self.eod_token + + input_ids += chat_input_ids + raw_input += role_raw_input + # loss_mask += chat_loss_mask + + # assert len(input_ids) == len(loss_mask) + + # if this sample's length is more than the specified max length, drop it + # here, we don't add padding tokens for a single sample, however, we will append padding tokens for a combinated samaple + if len(input_ids)> self.seq_length: + yield {} + else: + yield { + "input_ids": input_ids, + "raw_input": raw_input, + # "loss_mask": loss_mask + } + + def padding(self, key, data, pad_token_id): + assert 
len(data) <= self.seq_length, f"padding sequence: {len(data)}> {self.seq_length}" + if key == "input_ids": + return data + [pad_token_id] * (self.seq_length - len(data)) + + if key == "loss_mask": + return data + [0] * (self.seq_length - len(data)) + + raise Exception("Should not reach here. There must be something wrong.") + + +class PackSSTBinEncoder: + """ + A sample of this format will be: + content of sample_1 + content of sample_2 + ... + content of sample_n + <|pad|><|pad|>...<|pad|> + """ + + def __init__(self, seq_length, model_path): + self.seq_length = seq_length + self.model_path = model_path + + def initializer(self): + # Use Encoder class as a container for global data + assert self.model_path is not None + # self.tokenizer = load_tokenizer(model_path, tokenizer_type) #AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + # PackSSTBinEncoder.tokenizer = load_tokenizer(self.model_path, self.tokenizer_type) + PackSSTBinEncoder.tokenizer = init_tokenizer(self.model_path) + + def _encode_content(self, item, encode_res): + if "content" in item: + content = item["content"] + else: + content = item["text"] + + item_len = len(content) + + input_ids = self.tokenize_string(content) + encode_res["input_ids"].append(input_ids) + + return encode_res, item_len + + def _encode_chatml(self, item, encode_res): + input_ids = [] + item_len = 0 + one_round_content = "" + for i in range(len(item["chat_rounds"])): + chat_round = item["chat_rounds"][i] + role = chat_round["role"] + content = chat_round["content"] + content = content if content.endswith("\n") else f"{content}\n" + if role.lower() == "system": + continue + if role.lower() == "human": + one_round_content = content + else: + one_round_content += content + input_ids += self.tokenize_string(one_round_content) + item_len += len(one_round_content) + + encode_res["input_ids"].append(input_ids) + + return encode_res, item_len + + def encode(self, item): + encode_res = { + "input_ids": [], + } + + try: + if item is None: + encode_res["input_ids"].append([]) + return encode_res, 0 + + if "content" in item or "text" in item: + return self._encode_content(item, encode_res) + + if "chat_rounds" in item: + return self._encode_chatml(item, encode_res) + except Exception as e: + print("####JSON Exception", e, str(item)) + encode_res["input_ids"].append([]) + return encode_res, 0 + + raise Exception("Unsupported Format!") + + def tokenize_string(self, text): + end_marker = [PackSSTBinEncoder.tokenizer.eos_token_id] + + input_ids = [] + try: + input_ids = PackSSTBinEncoder.tokenizer.encode(text, add_special_tokens=False) + input_ids = input_ids + end_marker + return input_ids + except Exception as e: + print("####Tokenization Exception:", e, text) + return [] + except BaseException as e: + print("####Tokenization BaseException:", e, "Length of text", len(text)) + return [] + + def padding(self, data, pad_token_id): + assert len(data) <= self.seq_length, f"padding sequence: {len(data)}> {self.seq_length}" + return data + [pad_token_id] * (self.seq_length - len(data)) diff --git a/mftcoder_accelerate/src/offline_tokenization/writer.py b/mftcoder_accelerate/src/offline_tokenization/writer.py new file mode 100644 index 0000000..ab526a7 --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/writer.py @@ -0,0 +1,42 @@ + +import threading +import fcntl +import json + +class JSONLWriter(): + """ + A writer used to save jsonl lines into a file. 
+ """ + def __init__(self, output_path, dataset_name): + self.output_path = output_path + self.out_file = open(output_path, 'w') + self.cache = [] + self.cache_size = 4096 + self.dataset_name = dataset_name + self.index = 0 + + def pack_into_jsonl(self, line_text): + new_item = { + "data_name": self.dataset_name, + "id": self.index, + "content": line_text + } + + return new_item + + + def add_item(self, line_text): + if len(self.cache)>= self.cache_size: + self.flush() + + item = self.pack_into_jsonl(line_text) + self.cache.append(json.dumps(item)) + self.index += 1 + + + def flush(self): + content = '\n'.join(self.cache) + fcntl.flock(self.out_file, fcntl.LOCK_EX) + self.out_file.write(f'{content}\n') + fcntl.flock(self.out_file, fcntl.LOCK_UN) + self.cache = [] diff --git a/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py b/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py index 1f2ff59..26f8ec1 100644 --- a/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py +++ b/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py @@ -4,6 +4,7 @@ Merge base and lora adaptor """ + import os import sys import time @@ -22,8 +23,7 @@ sys.path.insert(0, parent_dir) print("In merge_base_and_lora_to_hf.py, sys path:", sys.path) -from pefts.model_mapping import MODEL_SPECIAL_TOKENS -from tokenizer.chat_template import MFTCoder_template +from tokenizer import init_tokenizer def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): @@ -43,7 +43,7 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): print(f"File {filename} does not exist in {mode_path}") -if __name__ == '__main__': +if __name__ == "__main__": # arguments parser = argparse.ArgumentParser() @@ -59,30 +59,21 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): save_path = args.merged_output_path t0 = time.time() - config = {"model_type": model_type} - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - tokenizer.chat_template = MFTCoder_template + + tokenizer = init_tokenizer(args.base_model_or_path) base_model = AutoModelForCausalLM.from_pretrained( model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, + # torch_dtype=torch.float32, return_dict=True, - device_map="auto" + device_map="auto", ) print("--------------------------------------Base Model--------------------------------------------") print(base_model) print("--------------------------------------------------------------------------------------------") - # DEAL with eos_token_id and pad_token_id - eos_token = MODEL_SPECIAL_TOKENS[config['model_type']]['eos_token'] - pad_token = MODEL_SPECIAL_TOKENS[config['model_type']]['pad_token'] - base_model.config.eos_token = eos_token - base_model.config.pad_token = pad_token - base_model.config.eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) - base_model.config.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) - print(f"Finetuned eos_token: {eos_token}, eos_token_id: {tokenizer.convert_tokens_to_ids(eos_token)}") - print(f"Finetuned pad_token: {pad_token}, pad_token_id: {tokenizer.convert_tokens_to_ids(pad_token)}") print("-----------------------------------Base Model Config----------------------------------------") print(base_model.config) print("--------------------------------------------------------------------------------------------") @@ -90,6 +81,8 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): # merge, save model and tokenizer 
model_to_merge = PeftModel.from_pretrained(base_model, lora_adapter) merged_model = model_to_merge.merge_and_unload() + # merged_model.to(torch.bfloat16) + print("---------------------------------Merged Model Config----------------------------------------") print(merged_model.config) print("--------------------------------------------------------------------------------------------") @@ -101,8 +94,8 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): if model_type.lower() == "deepseek": copy_tokenizer_files( model_path, - ["tokenizer.model", "tokenizer.json", "tokenizer_config.json", 'special_tokens_map.json'], - save_path + ["tokenizer.model", "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"], + save_path, ) else: tokenizer.save_pretrained(save_path) diff --git a/mftcoder_accelerate/src/pefts/mft_accelerate.py b/mftcoder_accelerate/src/pefts/mft_accelerate.py index acfb422..0a0d42a 100644 --- a/mftcoder_accelerate/src/pefts/mft_accelerate.py +++ b/mftcoder_accelerate/src/pefts/mft_accelerate.py @@ -1,15 +1,13 @@ """ # @author Chaoyu Chen -# @date 2024年5月20日 +# @date 2024年10月24日 # @module mft_accelerate.py -Accelerate + DeepSpeed zero2/zero3/FSDP + Data Parallelism -QLoRA/LoRA/Full + MFT/MPT, resource and parameters efficient training +Accelerate + DeepSpeed/FSDP + QLoRA/LoRA/Full + Multi-task Finetuning Entry """ - import os import sys import argparse @@ -31,7 +29,6 @@ from transformers import ( AutoModelForCausalLM, AutoTokenizer, - LlamaTokenizer, get_linear_schedule_with_warmup, set_seed, BitsAndBytesConfig, @@ -41,11 +38,14 @@ LoraConfig, TaskType, get_peft_model, - # prepare_model_for_kbit_training, + prepare_model_for_kbit_training, PeftModel, ) from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration from accelerate.logging import get_logger +from datetime import timedelta +from accelerate.utils import InitProcessGroupKwargs +from transformers.optimization import Adafactor # insert src as import path current_path = os.path.abspath(__file__) @@ -58,28 +58,13 @@ from data.data_utils import load_dataset_from_bin from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK -from pefts.trainer import MftTrainer -from pefts.arguments import TrainArgs -from pefts.model_mapping import MODEL_TYPES, FULL_LORA_TARGETING_MODULES, MODEL_SPECIAL_TOKENS, CUSTOMIZE +from pefts.mft_trainer import MftTrainer +from pefts.mft_arguments import MftTrainArgs +from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS logger = get_logger(__name__) -SUPPORT_FA2_IN_TRANSFORMERS = [ - "code_llama", - "llama", - "deepseek", - "mistral", - "mixtral", - "gpt_neox", - "phi", - "starcoder", - "qwen2", - "qwen2_moe", - "gemma", - "starcoder2" -] - def get_task_mask(args, task_id): task_num = len(TASK2ID) @@ -90,7 +75,7 @@ def get_task_mask(args, task_id): def get_attention_mask_and_position_ids(data): - """Build masks and position ids if you need to""" + """Build masks and position id for left to right model.""" # Extract batch size and sequence length. batch_size, seq_length = data.size() @@ -166,48 +151,6 @@ def __call__(self, instances): return result_batch -def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True): - """ - This method wraps the entire protocol for preparing a model before running a training. 
- This includes: - 1- Cast the layernorm in fp32 - 2- making output embedding layer require grads - 3- Add the upcasting of the lm head to fp32 - - Args: - model, (`transformers.PreTrainedModel`): - The loaded model from `transformers` - """ - loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) - - is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq" - for name, param in model.named_parameters(): - # freeze base model's layers - param.requires_grad = False - - if not is_gptq_quantized: - # cast all non INT8 parameters to fp32 - for param in model.parameters(): - if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): - param.data = param.data.to(torch.float32) - - if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing: - # For backward compatibility - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # enable gradient checkpointing for memory efficiency - model.gradient_checkpointing_enable() - - return model - - def pprint_args(args, accelerator): # 计算所有键的最大字符串长度 max_key_length = max(len(str(key)) for key in vars(args).keys()) @@ -239,7 +182,7 @@ def prepare_args(): # parse args from cofig.json # args = argparse.Namespace(**train_config) - args = TrainArgs(**train_config) + args = MftTrainArgs(**train_config) # override args by cli arguments if parsed.data_paths: @@ -260,8 +203,6 @@ def prepare_args(): args.distributed_type = parsed.distributed_type # refactor args - args.eos_token = MODEL_SPECIAL_TOKENS[args.model_type]["eos_token"] - args.pad_token = MODEL_SPECIAL_TOKENS[args.model_type]["pad_token"] if args.peft_type == "qlora": print_rank_0(f"[INFO] args.peft_type is set 'qlora', setting quantization to '4bit'") @@ -276,7 +217,7 @@ def prepare_args(): # generate TASK2ID, ID2TASK generate_task_id(args.data_paths) - if args.weighted_loss_mode == "selfpaced": + if args.weighted_loss_mode == "coba": args.task_weights = [1.0] * len(ID2TASK) elif args.task_weights is not None: args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")] @@ -299,6 +240,8 @@ def main(): set_seed(args.seed) # define accelerator + init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds)) + if args.distributed_type and args.distributed_type.lower() == "fsdp": fsdp_plugin = FullyShardedDataParallelPlugin( # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), @@ -309,15 +252,18 @@ def main(): cpu_offload=False, ) accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, fsdp_plugin=fsdp_plugin, - dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True), + gradient_accumulation_steps=args.gradient_accumulation_steps, + fsdp_plugin=fsdp_plugin, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], ) else: accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, - dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True), + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], ) - + # print key infos accelerator.print("In mft_accelerate.py, sys path:", sys.path) accelerator.print(f"transformers.__version__: 
{transformers.__version__}") @@ -339,7 +285,7 @@ def main(): with open(os.path.join(args.output_dir, "latest"), "r") as fl: latest = json.load(fl) accelerator.print(f"[INFO] Existing latest: {latest}") - + if args.auto_resume and args.resume_from_checkpoint is None and latest: if args.peft_type: args.resume_from_checkpoint = latest["latest_ckpt"] @@ -349,11 +295,6 @@ def main(): args.learning_rate = latest["lr"] elif args.resume_from_checkpoint and (not args.peft_type): args.pretrained_model_path = args.resume_from_checkpoint - - # if latest: - # scheduler_last_ep = latest["scheduler_last_ep"] - # else: - # scheduler_last_ep = -1 # logger logging.basicConfig( @@ -400,11 +341,11 @@ def main(): max_memory = {i: max_memory for i in range(n_gpus)} accelerator.print("max memory: ", max_memory, n_gpus) - # target_modules + # target_modules, default all-linear for all linear layers if args.target_modules: target_modules = args.target_modules else: - target_modules = FULL_LORA_TARGETING_MODULES[args.model_type] + target_modules = "all-linear" # peft config if args.peft_type: @@ -418,53 +359,54 @@ def main(): bias="lora_only", ) - # new special tokens + # # 是否要加入新的special tokens # num_added_toks = tokenizer.tokenizer.add_special_tokens(["", ""]) # accelerator.print("We have added", num_added_toks, "tokens") # accelerator.print(f"role marker tokens {tokenizer.convert_tokens_to_ids('')} {tokenizer.convert_tokens_to_ids('')}, resized tokenizer_size: {len(tokenizer)}") # creating base model ModelClass = MODEL_TYPES[args.model_type] - if args.model_type in SUPPORT_FA2_IN_TRANSFORMERS and not CUSTOMIZE: - accelerator.print(f"[INFO] Model Type {args.model_type} " f"is supported FA2 by Transformers and we use it") + if args.model_type in SUPPORT_IN_TRANSFORMERS: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers") model = ModelClass.from_pretrained( args.pretrained_model_path, attn_implementation=args.attn_implementation, torch_dtype=torch.bfloat16, - quantization_config=BitsAndBytesConfig( - load_in_4bit=(args.quantization == "4bit"), - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_quant_storage=torch.bfloat16, - ) - if args.quantization == "4bit" - else None, + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=(args.quantization == "4bit"), + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_quant_storage=torch.bfloat16, + ) + if args.quantization == "4bit" + else None + ), ) else: - accelerator.print( - f"[INFO] Model Type {args.model_type} " - f"is NOT supported officially by Transformers " - f"and we use published modeling_xxx.py(may be modified by us)" - ) + accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code") model = ModelClass.from_pretrained( args.pretrained_model_path, torch_dtype=torch.bfloat16, - quantization_config=BitsAndBytesConfig( - load_in_4bit=(args.quantization == "4bit"), - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_quant_storage=torch.bfloat16, - ) - if args.quantization == "4bit" - else None, + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=(args.quantization == "4bit"), + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_quant_storage=torch.bfloat16, + ) + if args.quantization == "4bit" + else None + ), ) # build 
a tokenizer for possible resizing or saving tokenizer = build_tokenizer(args) # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, # i.e. the length of the tokenizer. + # 如果新增special tokens, 需要resize input embedding 和output embedding # model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) accelerator.print("Model load_in_4bit: ", args.quantization == "4bit") @@ -500,7 +442,8 @@ def main(): logging.info(f"model loading time: {t2 - t1:.4f}") model.config.use_cache = False # silence the warnings. Please re-enable for inference! - model.config.use_logn_attn = False # special for qwen model + if hasattr(model.config, "use_logn_attn"): + model.config.use_logn_attn = False # special for qwen model # load balance for moe training if hasattr(model.config, "output_router_logits"): model.config.output_router_logits = True @@ -516,13 +459,18 @@ def main(): pin_memory=True, drop_last=True, ) - valid_dataloader = DataLoader( - valid_dataset, - collate_fn=DataCollatorForMFTDataset(args), - batch_size=args.per_device_eval_batch_size, - pin_memory=True, - drop_last=True, - ) + if valid_dataset: + valid_dataloader = DataLoader( + valid_dataset, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_eval_batch_size, + pin_memory=True, + drop_last=True, + ) + else: + valid_dataloader = None + + # optimizer if accelerator.distributed_type == DistributedType.DEEPSPEED: accelerator.print("DISTRIBUTED TRAINING USING DEEPSPEED") # from deepspeed.ops.adam import FusedAdam as Adam @@ -545,8 +493,6 @@ def main(): lr=args.learning_rate, betas=(0.9, 0.999), ) - # for group in optimizer.param_groups: - # group.setdefault("initial_lr", group["lr"]) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -554,7 +500,9 @@ def main(): if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True - + if isinstance(args.num_warmup_steps, float) and args.num_warmup_steps < 1.0: + args.num_warmup_steps = int(args.max_train_steps * args.num_warmup_steps) // accelerator.num_processes + accelerator.print(f"num_warmup_steps: {args.num_warmup_steps}") lr_scheduler = get_scheduler( name=args.lr_scheduler_type, optimizer=optimizer, @@ -564,14 +512,25 @@ def main(): ) # prepare all if accelerator.distributed_type == DistributedType.DEEPSPEED: - (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare( - model, train_dataloader, valid_dataloader, optimizer, lr_scheduler - ) + if valid_dataloader: + (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, valid_dataloader, optimizer, lr_scheduler + ) + else: + (model, train_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, optimizer, lr_scheduler + ) + # prepare all except model, which is prepared before elif accelerator.distributed_type == DistributedType.FSDP: - (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare( - optimizer, train_dataloader, valid_dataloader, lr_scheduler - ) + if valid_dataloader: + (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, valid_dataloader, lr_scheduler + ) + else: + (optimizer, train_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, lr_scheduler + ) print(model.device) accelerator.print(model) # 
accelerator.print(model.config) @@ -590,7 +549,7 @@ def main(): accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}") elif getattr(accelerator.state, "fsdp_plugin", None): accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}") - + trainer = MftTrainer( accelerator=accelerator, model=model, @@ -605,6 +564,7 @@ def main(): args=args, ) trainer.accelerate_train() + logger.info(f"Training Finished!") if __name__ == "__main__": diff --git a/mftcoder_accelerate/src/pefts/arguments.py b/mftcoder_accelerate/src/pefts/mft_arguments.py similarity index 87% rename from mftcoder_accelerate/src/pefts/arguments.py rename to mftcoder_accelerate/src/pefts/mft_arguments.py index 1403c4f..9fee1cd 100644 --- a/mftcoder_accelerate/src/pefts/arguments.py +++ b/mftcoder_accelerate/src/pefts/mft_arguments.py @@ -2,15 +2,15 @@ # @author Chaoyu Chen # @date 2023/10/19 -accelerate + deepspeed zero stage2 + Data Parallelism -MFT Training +training arguments """ + from dataclasses import dataclass, asdict from typing import List, Union @dataclass -class TrainArgs: +class MftTrainArgs: # train data paths on shared FS data_paths: Union[str, List[str]] @@ -47,7 +47,7 @@ class TrainArgs: # sft or sst tokenize_mode: str = "sft" - # case3 or case4 + # mft loss mode weighted_loss_mode: str = "case3" # lora or qlora or None(for full-parameter training) @@ -93,7 +93,7 @@ class TrainArgs: lr_scheduler_type: str = "cosine" # num_warmup_steps - num_warmup_steps: int = 300 + num_warmup_steps: Union[int, float] = 0.05 # num_train_epochs num_train_epochs: int = 4 @@ -141,14 +141,16 @@ class TrainArgs: # if dynamic padding use_dynamic_padding: bool = True - # interval of update per task train weight in selfpaced - selfpaced_interval: int = 1 - # history length of sample valid loss used to fit the slope curve in selfpaced - selfpaced_history_length: int = 100 - # the number of mini valid batches sampled at each interval - selfpaced_sample_valid_num: int = 1 - # scale factor before softmax - selfpaced_scale_factor: int = 50 + # warm-up steps for CoBa, recommand the number of valid batches + coba_warmup_steps: int = 100 + # history length of sample valid loss used to fit the slope curve in CoBa, recommand [2*coba_warmup_steps,5*coba_warmup_steps] + coba_history_length: int = 200 + # temperature for divergence factor in CoBa + coba_tau: int = 5 + # iteration interval of update per task train weight in CoBa + coba_update_interval: int = 1 + # the number of mini valid batches sampled at each updated iteration interval + coba_sample_valid_num: int = 1 # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2} attn_implementation: str = "flash_attention_2" @@ -158,6 +160,9 @@ class TrainArgs: role_markers: Union[None, dict] = None distributed_type: Union[None, str] = None + + init_timeout_seconds: Union[None, int] = 3600 + # legacy, leave them use_xformers: bool = True trust_remote_code: bool = True diff --git a/mftcoder_accelerate/src/pefts/trainer.py b/mftcoder_accelerate/src/pefts/mft_trainer.py similarity index 89% rename from mftcoder_accelerate/src/pefts/trainer.py rename to mftcoder_accelerate/src/pefts/mft_trainer.py index 3cbc25e..a2b00fb 100644 --- a/mftcoder_accelerate/src/pefts/trainer.py +++ b/mftcoder_accelerate/src/pefts/mft_trainer.py @@ -1,10 +1,10 @@ """ -# @author Chaoyu Chen +# @author qumu # @date 2024/4/12 # @module trainer.py Accelerate + DeepSpeed/FSDP -QLoRA/LoRA/Full + SFT/MFT/MPT +QLoRA/LoRA/Full + SFT/MFT Trainer """ @@ -33,7 +33,7 @@ # 
sys.path.append("..") from utils.common_utils import generate_task_id, TASK2ID, ID2TASK -from utils.loss_utils import loss_func_mft, SelfpacedStatus, load_balancing_loss_func +from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func logger = get_logger(__name__) @@ -208,7 +208,7 @@ def accelerate_saving_checkpoint(self, output_dir: str, completed_steps: int): ) else: self.tokenizer.save_pretrained(output_dir) - + sf = os.path.join(output_dir, "model.safetensors") index_file = os.path.join(output_dir, "model.safetensors.index.json") if os.path.isfile(sf) and os.path.isfile(index_file): @@ -219,8 +219,6 @@ def accelerate_saving_checkpoint(self, output_dir: str, completed_steps: int): latest = { "latest_ckpt": output_dir, "lr": self.optimizer.param_groups[0]["lr"], - # 1 step back because ckping is after schuduler.step() - # "scheduler_last_ep": self.lr_scheduler.state_dict().get("last_epoch", 0) - 1, } with open(os.path.join(self.args.output_dir, "latest"), "w") as f: json.dump(latest, f, indent=2) @@ -237,7 +235,7 @@ def accelerate_monitor( reduce_task_loss, reduce_task_exist, completed_steps, - selfpaced_status=None, + coba_status=None, ): """ gather reduce_loss and reduce_task_loss from all N devices. @@ -261,27 +259,27 @@ def accelerate_monitor( f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]", main_process_only=True, ) - if selfpaced_status is not None: - if completed_steps> selfpaced_status.selfpaced_history_length: - selfpaced_status.log_per_task_weight = selfpaced_status.log_per_task_weight / torch.sum( - selfpaced_status.log_per_task_weight + if coba_status is not None: + if completed_steps> coba_status.coba_warmup_steps: + coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum( + coba_status.log_per_task_weight ) else: - selfpaced_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) + coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) logger.info( - f"[TRAIN][per_task_train_weight={selfpaced_status.log_per_task_weight}]", main_process_only=True + f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True ) train_log_dict = {"Loss/train": train_loss} for i in range(len(ID2TASK)): train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i] - if selfpaced_status is not None: - train_log_dict[f"{ID2TASK[i]}_selfpaced_weight/train"] = selfpaced_status.log_per_task_weight[i].item() + if coba_status is not None: + train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item() if self.accelerator.is_main_process: write_tensorboard(self.summary_writer, train_log_dict, completed_steps) - if selfpaced_status is not None: - selfpaced_status.log_per_task_weight = torch.zeros(len(ID2TASK)) + if coba_status is not None: + coba_status.log_per_task_weight = torch.zeros(len(ID2TASK)) def accelerate_evaluate( self, @@ -412,24 +410,35 @@ def accelerate_train(self): reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) per_task_weight = self.args.task_weights - if self.args.weighted_loss_mode == "selfpaced": - selfpaced_status = SelfpacedStatus( - self.args.selfpaced_scale_factor, - self.args.selfpaced_interval, - self.args.selfpaced_history_length, - self.args.selfpaced_sample_valid_num, + if self.args.weighted_loss_mode == "coba": + self.model.eval() + eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate( + completed_steps, + 0, + min_eval_loss, + stall_num, + best_step, + ) + 
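+            # CoBa warm-up: the full evaluation pass above produces the initial
+            # per-task validation losses; they seed valid_task_loss_begining on the
+            # CoBaStatus constructed below, giving CoBa its reference point before
+            # any per-task re-weighting kicks in.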
self.model.train() + coba_status = CoBaStatus( + self.args.coba_warmup_steps, + self.args.coba_history_length, + self.args.coba_tau, + self.args.coba_update_interval, + self.args.coba_sample_valid_num, self.valid_dataloader, ) - selfpaced_status.sample_valid_batch(self.model, completed_steps) - selfpaced_status.valid_iterator = iter(selfpaced_status.valid_dataloader) + coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device) + coba_status.sample_valid_batch(self.model, completed_steps) + logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) else: - selfpaced_status = None + coba_status = None # Training Loop! for epoch in range(starting_epoch, self.args.num_train_epochs): - # set_epoch + # set_epoch # self.train_dataloader.set_epoch(epoch) - + # if we early stop by some ckpts not converging if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: break @@ -459,13 +468,15 @@ def accelerate_train(self): ) if ( - self.args.weighted_loss_mode == "selfpaced" - and step % self.args.gradient_accumulation_steps == 0 - and completed_steps % self.args.selfpaced_interval == 0 - and completed_steps>= self.args.selfpaced_history_length + self.args.weighted_loss_mode == "coba" + and self.accelerator.sync_gradients + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps>= self.args.coba_warmup_steps ): - per_task_weight = selfpaced_status.compute_per_task_weight(completed_steps=completed_steps) - selfpaced_status.log_per_task_weight += per_task_weight + with torch.no_grad(): + per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps) + coba_status.log_per_task_weight += per_task_weight + # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True) # loss loss, task_loss, _ = loss_func_mft( @@ -520,11 +531,12 @@ def accelerate_train(self): # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done. if self.accelerator.sync_gradients: if ( - self.args.weighted_loss_mode == "selfpaced" - and completed_steps % self.args.selfpaced_interval == 0 + self.args.weighted_loss_mode == "coba" + and completed_steps % self.args.coba_update_interval == 0 and completed_steps>= 1 ): - selfpaced_status.sample_valid_batch(self.model, completed_steps) + coba_status.sample_valid_batch(self.model, completed_steps) + # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) # progress_bar.update(1) completed_steps += 1 @@ -538,7 +550,7 @@ def accelerate_train(self): reduce_task_loss, reduce_task_exist, completed_steps, - selfpaced_status, + coba_status, ) # reset reduce_loss reduce_loss = torch.tensor(0.0).to(self.model.device) @@ -554,7 +566,7 @@ def accelerate_train(self): self.accelerate_saving_checkpoint(output_dir, completed_steps) # steps evaluation - if completed_steps % self.args.evaluation_steps == 0: + if completed_steps % self.args.evaluation_steps == 0 and self.valid_dataloader: self.model.eval() eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = self.accelerate_evaluate( completed_steps, diff --git a/mftcoder_accelerate/src/pefts/model_mapping.py b/mftcoder_accelerate/src/pefts/model_mapping.py deleted file mode 100644 index 7824e8f..0000000 --- a/mftcoder_accelerate/src/pefts/model_mapping.py +++ /dev/null @@ -1,153 +0,0 @@ -""" - # @author Chaoyu Chen - # @date 2024年5月20日 - - Manage supported models and their special token used in training. 
- Default targeting modules for LoRA/QLora - 4.40 is stable now -""" - -# Models that have both cutomized modeling and Transformers modeling - -CUSTOMIZE = False -if CUSTOMIZE: - from model.code_llama.modeling_llama import LlamaForCausalLM - from model.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM - from model.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM -else: - from transformers import ( - GPTNeoXForCausalLM, - GPTBigCodeForCausalLM, - LlamaForCausalLM, - ) - -# Models that Transformers support Code and FA2 when flash_attn>=2.1.0 -from transformers import ( - MistralForCausalLM, - MixtralForCausalLM, - PhiForCausalLM, - GemmaForCausalLM, - Qwen2ForCausalLM, - Qwen2MoeForCausalLM, - Starcoder2ForCausalLM, -) -# Models that Code from "remote_code" -from model.aquila2.modeling_aquila import AquilaForCausalLM -from model.baichuan2.modeling_baichuan import BaichuanForCausalLM -from model.qwen.modeling_qwen import QWenLMHeadModel -from model.chatglm2.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration2 -from model.chatglm3.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration3 -# from model.phi.modeling_mixformer_sequential import MixFormerSequentialForCausalLM - - -MODEL_TYPES = { - "aquila2": AquilaForCausalLM, - "baichuan": BaichuanForCausalLM, - "chatglm2": ChatGLMForConditionalGeneration2, - "chatglm3": ChatGLMForConditionalGeneration3, - "code_llama": LlamaForCausalLM, - "deepseek": LlamaForCausalLM, - "gpt_neox": GPTNeoXForCausalLM, - "llama": LlamaForCausalLM, - "mistral": MistralForCausalLM, - "mixtral": MixtralForCausalLM, - "phi": PhiForCausalLM, - "qwen": QWenLMHeadModel, - "starcoder": GPTBigCodeForCausalLM, - "qwen2": Qwen2ForCausalLM, - "gemma": GemmaForCausalLM, - "qwen2_moe": Qwen2MoeForCausalLM, - "starcoder2": Starcoder2ForCausalLM, -} - -FULL_LORA_TARGETING_MODULES = { - "aquila": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "baichuan": ["W_pack", "o_proj", "gate_proj", "down_proj", "up_proj"], - "chatglm2": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], - "chatglm3": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], - "deepseek": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "code_llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "gpt_neox": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], - "llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "mixtral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate", "w1", "w2", "w3"], - "phi": ["query_key_value", "dense", "fc1", "fc2"], - "qwen": ["c_proj", "c_attn", "w1", "w2"], - "starcoder": ["c_proj", "c_attn", "q_attn", "c_fc"], - "gemma": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "qwen2": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "qwen2_moe": "all-linear", - "starcoder2": "all-linear", -} - - -MODEL_SPECIAL_TOKENS = { - "gpt_neox": { - "eos_token": "<|endoftext|>", - "pad_token": "<|pad|>", - }, - "llama": { - "eos_token": "", - "pad_token": "", - }, - "code_llama": { - "eos_token": "", - "pad_token": "", - }, - "baichuan": { - "eos_token": "", - "pad_token": "", - }, - "starcoder": { - "eos_token": "<|endoftext|>", - "pad_token": "", - }, - "qwen": { - "eos_token": 
"<|endoftext|>", - "pad_token": "<|extra_1|>", - }, - "chatglm2": { - "eos_token": "", - "pad_token": "", - }, - "chatglm3": { - "eos_token": "", - "pad_token": "", - }, - "phi": { - "eos_token": "<|endoftext|>", - "pad_token": "<|endoftext|>", - }, - "aquila": { - "eos_token": "", - "pad_token": "<|endoftext|>", - }, - "deepseek": { - "eos_token": "<|end▁of▁sentence|>", - "pad_token": "<|end▁of▁sentence|>", - }, - "mixtral": { - "eos_token": "", - "pad_token": "", - }, - "mistral": { - "eos_token": "", - "pad_token": "", - }, - "qwen2": { - "eos_token": "<|endoftext|>", - "pad_token": "<|endoftext|>", - }, - "gemma": { - "eos_token": "", - "pad_token": "", - }, - "qwen2_moe": { - "eos_token": "<|endoftext|>", - "pad_token": "<|endoftext|>", - }, - "starcoder2": { - "eos_token": "<|endoftext|>", - "pad_token": "<|endoftext|>", - }, -} diff --git a/mftcoder_accelerate/src/run_offline_tokenization.sh b/mftcoder_accelerate/src/run_offline_tokenization.sh new file mode 100644 index 0000000..ed916da --- /dev/null +++ b/mftcoder_accelerate/src/run_offline_tokenization.sh @@ -0,0 +1,13 @@ +MODEL_PATH= +DATA_PATH= +DATASET_NAME= +OUTPUT_PATH= + +python offline_tokenization/concat_sst_bin_tokenization.py \ +--model-path ${MODEL_PATH} \ +--data-path ${DATA_PATH} \ +--dataset-name ${DATASET_NAME} \ +--output-path ${OUTPUT_PATH} \ +--parallel 16 \ +--seq-length 4096 \ +--sample-percent 1.0 diff --git a/mftcoder_accelerate/src/tokenizer/__init__.py b/mftcoder_accelerate/src/tokenizer/__init__.py index 12ec210..20e88bb 100644 --- a/mftcoder_accelerate/src/tokenizer/__init__.py +++ b/mftcoder_accelerate/src/tokenizer/__init__.py @@ -1 +1,3 @@ from .tokenizer import build_tokenizer +from .tokenizer import init_tokenizer +from .chat_template import MFTCoder_template \ No newline at end of file diff --git a/mftcoder_accelerate/src/tokenizer/tokenizer.py b/mftcoder_accelerate/src/tokenizer/tokenizer.py index cacd712..bc3ab56 100644 --- a/mftcoder_accelerate/src/tokenizer/tokenizer.py +++ b/mftcoder_accelerate/src/tokenizer/tokenizer.py @@ -3,33 +3,71 @@ # @date 2023年6月19日 """ - import numpy as np from typing import List, Union from utils.common_utils import print_rank_0 -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoConfig from tokenizer.chat_template import MFTCoder_template +def init_tokenizer(path): + """ + Init a Huggingface tokenizer, parsing eos_token from the tokenizer_config then config. + Set pad_token same as eos_token for easy life. 
+ :param path: model path or tokenizer path + :return: Tokenizer (TokenizerFast is preferred) + """ + # tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False, legacy=False) + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + config, unused_kwargs = AutoConfig.from_pretrained(path, trust_remote_code=True, return_unused_kwargs=True) + + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id: + print(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer") + eos_token_id = tokenizer.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(eos_token_id) + elif hasattr(tokenizer, "eos_token") and tokenizer.eos_token: + print(f"Initial eos_token {tokenizer.eos_token} from tokenizer") + eos_token = tokenizer.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) + elif hasattr(config, "eos_token_id") and config.eos_token_id: + print(f"Initial eos_token_id {config.eos_token_id} from config.json") + eos_token_id = config.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(config.eos_token_id) + elif hasattr(config, "eos_token") and config.eos_token: + print(f"Initial eos_token {config.eos_token} from config.json") + eos_token = config.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(config.eos_token) + else: + raise ValueError( + "No available eos_token or eos_token_id, please provide eos_token by params or eos_token_id by config.json" + ) + try: + tokenizer.eos_token = eos_token + tokenizer.eos_token_id = eos_token_id + # set pad_token to be same as eos_token, it is ok because is will be masked out. + tokenizer.pad_token = eos_token + tokenizer.pad_token_id = eos_token_id + except: + print(f"[WARNING]Cannot set tokenizer.eos_token") + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + tokenizer.chat_template = MFTCoder_template + print_rank_0(f"Tokenizer: {type(tokenizer)}") + print_rank_0(f"Length of tokenizer: {len(tokenizer)}") + print_rank_0(f"build_tokenizer pad_token_id: {tokenizer.pad_token_id}, eos_token_id: {tokenizer.eos_token_id}") + print_rank_0(f"build_tokenizer pad_token : {tokenizer.pad_token}, eos_token: {tokenizer.eos_token}") + + return tokenizer + + def build_tokenizer(args): """Initialize tokenizer.""" print_rank_0(f"> building {args.tokenizer_type} tokenizer ...") # Select and instantiate the tokenizer. 
if args.tokenizer_type.lower() == "AutoTokenizer".lower(): assert args.pretrained_model_path is not None - # tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, trust_remote_code=True, use_fast=False, legacy=False) - tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, trust_remote_code=True) - tokenizer.eod_id = tokenizer.convert_tokens_to_ids(args.eos_token) - tokenizer.pad_id = tokenizer.convert_tokens_to_ids(args.pad_token) - try: - tokenizer.eos_token = args.eos_token - tokenizer.pad_token = args.pad_token - except: - print(f"[WARNING]Cannot set tokenizer.eos_token") - print_rank_0(f"Tokenizer: {type(tokenizer)}") - print_rank_0(f"Length of tokenizer: {len(tokenizer)}") - print_rank_0(f"build_tokenizer PAD id: {tokenizer.pad_id}, EOD id: {tokenizer.eod_id}") - print_rank_0(f"build_tokenizer PAD token : {args.pad_token}, EOD token: {args.eos_token}") + tokenizer = init_tokenizer(args.pretrained_model_path) else: raise NotImplementedError(f"{args.tokenizer_type} tokenizer is not implemented.") diff --git a/mftcoder_accelerate/src/utils/loss_utils.py b/mftcoder_accelerate/src/utils/loss_utils.py index deb59d0..5ca7c73 100644 --- a/mftcoder_accelerate/src/utils/loss_utils.py +++ b/mftcoder_accelerate/src/utils/loss_utils.py @@ -67,14 +67,14 @@ def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_ ) # [B, L] task_mask_trans = torch.transpose(task_mask, 0, 1) unique_id = torch.unique(task_id) - if weighted_loss_mode == "case3" or weighted_loss_mode == "case4" or weighted_loss_mode == "selfpaced": + if weighted_loss_mode == "case3" or weighted_loss_mode == "case4" or weighted_loss_mode == "coba": loss = 0.0 weights_sum = 0.0 for i, w in enumerate(unique_id): row_idx = torch.squeeze(task_id) == w.item() task_weight = float(task_weights[w.item()]) weights_sum += task_weight - if weighted_loss_mode == "case3" or weighted_loss_mode == "selfpaced": + if weighted_loss_mode == "case3" or weighted_loss_mode == "coba": if loss_mask is None: loss += ( torch.sum(losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) * task_weight @@ -104,12 +104,12 @@ def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_ elif weighted_loss_mode == "case1": # flatten losses & loss_mask tensor if loss_mask is None: - losses = losses.view(-1) - loss = torch.sum(losses) / effective_tokens + # losses = losses.view(-1) + loss = torch.sum(losses.view(-1)) / effective_tokens else: - loss_mask = loss_mask.view(-1) - losses = losses.view(-1) - loss = torch.sum(losses * loss_mask) / loss_mask.sum() + # loss_mask = loss_mask.view(-1) + # losses = losses.view(-1) + loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.view(-1).sum() # fix task order task_loss = torch.zeros(len(ID2TASK)).to(device=task_id.device) @@ -206,29 +206,35 @@ def __init__(self): super(MFTLossStatus, self).__init__() -class SelfpacedStatus(MFTLossStatus): +class CoBaStatus(MFTLossStatus): def __init__( self, - selfpaced_scale_factor=50, - selfpaced_interval=1, - selfpaced_history_length=100, - selfpaced_sample_valid_num=1, + coba_warmup_steps=100, + coba_history_length=200, + coba_tau=5, + coba_update_interval=1, + coba_sample_valid_num=1, valid_dataloader=None, ): - super(SelfpacedStatus, self).__init__() - self.selfpaced_scale_factor = selfpaced_scale_factor - self.selfpaced_interval = selfpaced_interval - self.selfpaced_history_length = selfpaced_history_length - self.selfpaced_sample_valid_num = selfpaced_sample_valid_num + 
super(CoBaStatus, self).__init__() + self.coba_warmup_steps = coba_warmup_steps + self.coba_history_length = coba_history_length + self.coba_tau = coba_tau + self.coba_update_interval = coba_update_interval + self.coba_sample_valid_num = coba_sample_valid_num self.valid_dataloader = valid_dataloader self.valid_dataloader_length = len(valid_dataloader) self.valid_iterator = iter(valid_dataloader) self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK)) - self.history_task_valid_loss = torch.zeros((selfpaced_history_length, len(ID2TASK))) + self.history_task_valid_loss = None + self.per_task_slope_list = None + self.total_slope_list = None + self.minimum_weight = 1 / (len(ID2TASK) * 10) + self.valid_task_loss_begining = torch.ones(len(ID2TASK), dtype=torch.float64) self.log_per_task_weight = torch.zeros(len(ID2TASK)) - def selfpaced_evaluate(self, model, v_batch, per_task_weight=None, selfpaced_status=None): + def coba_evaluate(self, model, v_batch, per_task_weight=None, coba_status=None): model.eval() with torch.no_grad(): valid_outputs = model( @@ -242,57 +248,93 @@ def selfpaced_evaluate(self, model, v_batch, per_task_weight=None, selfpaced_sta labels=v_batch["labels"], task_mask=v_batch["task_mask"], task_id=v_batch["task_id"], - weighted_loss_mode="selfpaced", + weighted_loss_mode="coba", loss_mask=v_batch["loss_mask"], task_weights=None, ) + task_exist = (valid_task_loss != 0.0).float() torch.distributed.all_reduce(valid_task_loss, op=torch.distributed.ReduceOp.SUM) - valid_task_loss /= torch.distributed.get_world_size() + torch.distributed.all_reduce(task_exist, op=torch.distributed.ReduceOp.SUM) + valid_task_loss /= task_exist.clamp_(1.0) + valid_task_loss /= self.valid_task_loss_begining model.train() return valid_task_loss def compute_per_task_weight(self, completed_steps=None): - task_slope_fitting = torch.ones(len(ID2TASK)) - history_steps = torch.arange( - completed_steps - self.selfpaced_history_length, completed_steps, 1 - ) # DEBUG: step < 0 - transpose_history_task_valid_loss = self.history_task_valid_loss.transpose(0, 1) - for i in range(len(ID2TASK)): - per_history_task_valid_loss = transpose_history_task_valid_loss[i] - task_slope_fitting[i] = self.fit_window_point( - history_steps, per_history_task_valid_loss, history=self.selfpaced_history_length, method="slope" + task_num = len(ID2TASK) + task_slope_fitting = torch.ones(task_num, dtype=torch.float64) + start_step = max(0, completed_steps // self.coba_update_interval - self.coba_history_length) + history_steps = torch.arange(start_step, completed_steps, 1) + for i in range(task_num): + per_task_history_valid_loss = self.history_task_valid_loss[i][-len(history_steps):] + task_slope_fitting[i] = self.fit_window_slope( + history_steps, per_task_history_valid_loss, type="slope" ) - slope_sum_abs = torch.sum(torch.abs(task_slope_fitting)) - - if slope_sum_abs == 0: - per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) + history_total_valid_loss, index = torch.max(self.history_task_valid_loss[:, -len(history_steps):], dim=0) + total_slope_fitting = self.fit_window_slope( + history_steps, history_total_valid_loss, type="slope" + ) + if completed_steps == self.coba_warmup_steps: + self.per_task_slope_list = task_slope_fitting.unsqueeze(1) + self.total_slope_list = total_slope_fitting.unsqueeze(0) else: - # print_rank_0(f"[step={completed_steps}][slope sum abs={slope_sum_abs}]") - normalize_slope = len(ID2TASK) * task_slope_fitting / slope_sum_abs - print_rank_0(f"normalize_slope: {normalize_slope}") - score = 
F.softmax(normalize_slope, dim=-1) * (-1 * normalize_slope) - print_rank_0(f"score: {score}") - per_task_weight = F.softmax(self.selfpaced_scale_factor * score, dim=-1) - print_rank_0(f"per_task_weight: {per_task_weight}") + self.per_task_slope_list = torch.cat((self.per_task_slope_list, task_slope_fitting.unsqueeze(1)), dim=-1) + self.total_slope_list = torch.cat((self.total_slope_list, total_slope_fitting.unsqueeze(0)), dim=0) + + # Relative Convergence Score + normalize_task_slope = task_num * task_slope_fitting / task_slope_fitting.abs().sum() + rcs = F.softmax(normalize_task_slope, dim=-1) + + # Absolute Convergence Score + history_per_task_slope_list = self.per_task_slope_list[:, start_step:] + reverse_normailize_iter_slope = -len(history_per_task_slope_list[0]) * history_per_task_slope_list \ + / history_per_task_slope_list.abs().sum(dim=-1, keepdim=True) + + flatten_rn_iter_slope = reverse_normailize_iter_slope.T.reshape(-1) + current_step_rn_slope = flatten_rn_iter_slope[-task_num:] + acs = F.softmax(current_step_rn_slope, dim=-1) + + # Divergence Factor + normalize_total_iter_slope = - len(self.total_slope_list) * self.total_slope_list \ + / self.total_slope_list.abs().sum() + divergence_factor = F.softmax(normalize_total_iter_slope * self.coba_tau, dim=-1)[-1] \ + * len(self.total_slope_list) + + weight_logits = divergence_factor * rcs + (1 - divergence_factor) * acs + per_task_weight = F.softmax(weight_logits * task_num, dim=-1) + + if len((per_task_weight < self.minimum_weight).nonzero().squeeze(0))> 0: + per_task_weight = per_task_weight * (1 - self.minimum_weight * task_num) + per_task_weight += self.minimum_weight return per_task_weight + + def fit_window_slope(self, x, y, type="slope"): - def fit_window_point(self, x, y, history=10, method="slope"): - + y = y[y != 0] + x = x[:len(y)] + nonzero_index = torch.squeeze(torch.nonzero(y), dim=1) y = torch.index_select(y, 0, nonzero_index) x = torch.index_select(x, 0, nonzero_index) ws = torch.flip(1 ** torch.arange(len(y)), dims=[0]) - ws = ws.float() + ws = ws.double() if len(y)>= 2: - if method == "slope": - X = torch.stack((x, torch.ones_like(x))).T - X = X.float() + if type == "slope": + X = torch.stack((x, torch.ones_like(x, dtype=torch.float64))).T + X = X.double() else: - X = torch.stack((x**2, x, torch.ones_like(x))).T + X = torch.stack((x ** 2, x, torch.ones_like(x, dtype=torch.float64))).T + + # implementation for numpy + # X_np = X.T @ (ws[:, None] * X) + # Y_np = X.T @ (ws * y) + # w = torch.from_numpy(np.linalg.solve(X_np.numpy(), Y_np.numpy())) + + # implementation for torch w = torch.linalg.solve(X.T @ (ws[:, None] * X), X.T @ (ws * y)) result = w[0] @@ -302,22 +344,22 @@ def fit_window_point(self, x, y, history=10, method="slope"): return result def sample_valid_batch(self, model, completed_steps): - self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK)) - for i in range(self.selfpaced_sample_valid_num): + self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK), dtype=torch.float64) + for i in range(self.coba_sample_valid_num): if ( - self.selfpaced_sample_valid_num * completed_steps // self.selfpaced_interval + i + self.coba_sample_valid_num * completed_steps // self.coba_update_interval + i ) % self.valid_dataloader_length == 0: self.valid_iterator = iter(self.valid_dataloader) - - v_batch = next(self.valid_iterator) - valid_task_loss = self.selfpaced_evaluate(model, v_batch) - self.valid_task_loss_accumulated += valid_task_loss.detach().cpu() - - self.valid_task_loss_accumulated /= 
self.selfpaced_sample_valid_num - self.history_task_valid_loss = torch.cat( - (self.history_task_valid_loss, torch.unsqueeze(self.valid_task_loss_accumulated, dim=0)) - ) - if len(self.history_task_valid_loss)> self.selfpaced_history_length: - self.history_task_valid_loss = self.history_task_valid_loss[ - len(self.history_task_valid_loss) - self.selfpaced_history_length : - ] + v_batch = next(self.valid_iterator) + else: + v_batch = next(self.valid_iterator) + valid_task_loss = self.coba_evaluate(model, v_batch) + self.valid_task_loss_accumulated += valid_task_loss.detach().cpu().double() + + self.valid_task_loss_accumulated /= self.coba_sample_valid_num + if self.history_task_valid_loss is None and completed_steps>= 1: + self.history_task_valid_loss = self.valid_task_loss_accumulated.unsqueeze(1) + elif self.history_task_valid_loss is not None: + self.history_task_valid_loss = torch.cat( + (self.history_task_valid_loss, self.valid_task_loss_accumulated.unsqueeze(1)), dim=-1 + ) diff --git a/mftcoder_accelerate/src/utils/model_mapping.py b/mftcoder_accelerate/src/utils/model_mapping.py new file mode 100644 index 0000000..8592e86 --- /dev/null +++ b/mftcoder_accelerate/src/utils/model_mapping.py @@ -0,0 +1,67 @@ +""" + @author qumu + transformers==4.40 is stable now +""" + +# Models that Transformers support Code and FA2 when flash_attn>=2.1.0 +from transformers import ( + GPTNeoXForCausalLM, + GPTBigCodeForCausalLM, + LlamaForCausalLM, + MistralForCausalLM, + MixtralForCausalLM, + PhiForCausalLM, + GemmaForCausalLM, + Qwen2ForCausalLM, + Qwen2MoeForCausalLM, + Starcoder2ForCausalLM, +) + +# model in local model dir and support transformers FA2 +from model.deepseek_v2.modeling_deepseek import DeepseekV2ForCausalLM + +# model in local model and self-contained +from model.aquila2.modeling_aquila import AquilaForCausalLM +from model.baichuan2.modeling_baichuan import BaichuanForCausalLM +from model.qwen.modeling_qwen import QWenLMHeadModel +from model.chatglm2.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration2 +from model.chatglm3.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration3 + +# from model.phi.modeling_mixformer_sequential import MixFormerSequentialForCausalLM + +MODEL_TYPES = { + "aquila2": AquilaForCausalLM, + "baichuan": BaichuanForCausalLM, + "chatglm2": ChatGLMForConditionalGeneration2, + "chatglm3": ChatGLMForConditionalGeneration3, + "code_llama": LlamaForCausalLM, + "deepseek": LlamaForCausalLM, + "gpt_neox": GPTNeoXForCausalLM, + "llama": LlamaForCausalLM, + "mistral": MistralForCausalLM, + "mixtral": MixtralForCausalLM, + "phi": PhiForCausalLM, + "qwen": QWenLMHeadModel, + "starcoder": GPTBigCodeForCausalLM, + "qwen2": Qwen2ForCausalLM, + "gemma": GemmaForCausalLM, + "qwen2_moe": Qwen2MoeForCausalLM, + "starcoder2": Starcoder2ForCausalLM, + "deepseek_v2": DeepseekV2ForCausalLM, +} + +SUPPORT_IN_TRANSFORMERS = [ + "code_llama", + "llama", + "deepseek", + "mistral", + "mixtral", + "gpt_neox", + "phi", + "starcoder", + "qwen2", + "qwen2_moe", + "gemma", + "starcoder2", + "deepseek_v2", +] diff --git a/mftcoder_accelerate/src/xxpo/custom_callbacks.py b/mftcoder_accelerate/src/xxpo/custom_callbacks.py new file mode 100644 index 0000000..f38fa70 --- /dev/null +++ b/mftcoder_accelerate/src/xxpo/custom_callbacks.py @@ -0,0 +1,99 @@ +""" +Customized Callbacks to use with the Trainer class and customize the training loop. 
+""" + +import copy +import dataclasses +import json +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import numpy as np +from tqdm.auto import tqdm + +from transformers.trainer_utils import IntervalStrategy, has_length +from transformers.training_args import TrainingArguments +from transformers.utils import logging +from transformers import TrainerCallback + +logger = logging.get_logger(__name__) + + +class CustomProgressCallback(TrainerCallback): + """ + A [`TrainerCallback`] that displays the progress of training or evaluation. + """ + + def __init__(self): + self.training_bar = None + self.prediction_bar = None + + def on_train_begin(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True) + self.current_step = 0 + + def on_step_end(self, args, state, control, **kwargs): + if state.is_world_process_zero and state.global_step % args.logging_steps == 0: + self.training_bar.update(args.logging_steps) + self.current_step = state.global_step + # pass + + def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + # if state.is_world_process_zero and has_length(eval_dataloader): + # if self.prediction_bar is None: + # self.prediction_bar = tqdm( + # total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True + # ) + # self.prediction_bar.update(1) + pass + + def on_evaluate(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_predict(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_world_process_zero and self.training_bar is not None: + # avoid modifying the logs object as it is shared between callbacks + logs = copy.deepcopy(logs) + # _ = logs.pop("total_flos", None) + # round numbers so that it looks better in console + if "epoch" in logs: + logs["epoch"] = round(logs["epoch"], 2) + # self.training_bar.write(str(logs)) + logger.info(logs) + + def on_train_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.close() + self.training_bar = None + + +class PrinterCallback(TrainerCallback): + """ + A bare [`TrainerCallback`] that just prints the logs. + """ + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + + +class LogCallback(TrainerCallback): + """ + A bare [`TrainerCallback`] that just prints the logs. 
+ """ + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + logger.info(logs) \ No newline at end of file diff --git a/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py b/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py new file mode 100644 index 0000000..4c93520 --- /dev/null +++ b/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py @@ -0,0 +1,484 @@ +""" +# @author qumu +# @date 2023年12月11日 +# @module mft_accelerate.py + +Accelerate + DeepSpeed/FSDP + QLoRA/LoRA/Full + DPO/RPO/ORPO + +Entry +""" + +import os +import sys +import argparse +import math +import logging +import json +import time +from datetime import timedelta +from tqdm.auto import tqdm +from dataclasses import dataclass +from typing import Dict, Optional, Union, List + +import datasets +from datasets import Dataset, load_dataset, concatenate_datasets + +import torch +from torch.utils.data import DataLoader +from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig + +import transformers +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + get_linear_schedule_with_warmup, + set_seed, + BitsAndBytesConfig, + get_scheduler, +) +from peft import ( + LoraConfig, + TaskType, + get_peft_model, + prepare_model_for_kbit_training, + PeftModel, +) +from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration +from accelerate.logging import get_logger +from accelerate.utils import InitProcessGroupKwargs + +# insert src as import path +current_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_path)) +sys.path.insert(0, parent_dir) + +from tokenizer import build_tokenizer + +from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK +from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS + +logger = get_logger(__name__) + + +from trl import ( + DPOConfig, + DPOTrainer, + ORPOConfig, + ORPOTrainer, + ModelConfig, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from transformers.trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) + +from xxpo.xxpo_arguments import XXPOTrainArgs +from xxpo.custom_callbacks import CustomProgressCallback +from xxpo.custom_callbacks import LogCallback + + +def pprint_args(args, accelerator): + # 计算所有键的最大字符串长度 + max_key_length = max(len(str(key)) for key in vars(args).keys()) + + message = "" + message += "====" * 60 + "\n" + message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n" + message += "====" * 60 + "\n" + accelerator.print(message) + accelerator.print("GPU: {}".format(torch.cuda.current_device())) + + +def prepare_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_config", type=str, default=None) + + parser.add_argument("--data_paths", type=str, default=None) + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument("--tb_dir", type=str, default=None) + parser.add_argument("--pretrained_model_path", type=str, default=None) + parser.add_argument("--micro_batch_size", type=int, default=None) + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--distributed_type", type=str, default="deepspeed") + + parsed = parser.parse_args() + # get json configs + with open(parsed.train_config, "r") as f: + 
train_config = json.load(f) + + # parse args from cofig.json + args = XXPOTrainArgs(**train_config) + + # override args by cli arguments + if parsed.data_paths: + args.data_paths = parsed.data_paths + if parsed.output_dir: + args.output_dir = parsed.output_dir + if parsed.tb_dir: + args.tb_dir = parsed.tb_dir + if parsed.pretrained_model_path: + args.pretrained_model_path = parsed.pretrained_model_path + args.vocab_file = parsed.pretrained_model_path + if parsed.micro_batch_size: + args.per_device_train_batch_size = parsed.micro_batch_size + args.per_device_eval_batch_size = parsed.micro_batch_size + if parsed.model_type: + args.model_type = parsed.model_type + + args.distributed_type = parsed.distributed_type + + # refactor args + + if args.peft_type == "qlora": + print_rank_0(f"[INFO] args.peft_type is set 'qlora', setting quantization to '4bit'") + args.quantization = "4bit" + else: + args.quantization = None + + args.vocab_file = args.pretrained_model_path + + return args + + +def get_model(args, accelerator): + ModelClass = MODEL_TYPES[args.model_type] + if args.model_type in SUPPORT_IN_TRANSFORMERS: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers") + model = ModelClass.from_pretrained( + args.pretrained_model_path, + attn_implementation=args.attn_implementation, + torch_dtype=torch.bfloat16, + # device_map=get_kbit_device_map() if args.quantization == "4bit" else None, + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=(args.quantization == "4bit"), + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_quant_storage=torch.bfloat16, + ) + if args.quantization == "4bit" + else None + ), + ) + else: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code") + model = ModelClass.from_pretrained( + args.pretrained_model_path, + torch_dtype=torch.bfloat16, + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=(args.quantization == "4bit"), + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_quant_storage=torch.bfloat16, + ) + if args.quantization == "4bit" + else None + ), + ) + + model.config.use_cache = False # silence the warnings. Please re-enable for inference! + + return model + + +def chatml_to_dpo_format( + data_file: str, + tokenizer, + sanity_check: bool = False, + cache_dir: Optional[str] = None, + num_proc=16, +) -> Dataset: + """Load the standard-paired dataset from Hugging Face and convert it to the necessary format. 
+ + The dataset is converted to a dictionary with the following structure: + { + 'chosen': List[dict], chatml + 'rejected': List[dict], chatml + } + """ + + dataset = load_dataset( + "json", + split="train", + data_files=data_file, + cache_dir=cache_dir, + verification_mode="no_checks", + ) + original_columns = dataset.column_names + + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 100))) + + def process(samples): + samples["prompt"] = [ + tokenizer.apply_chat_template(chosen[:-1], tokenize=False, add_generation_prompt=True) + for chosen in samples["chosen"] + ] + samples["chosen"] = [chosen[-1]["content"] + tokenizer.eos_token for chosen in samples["chosen"]] + samples["rejected"] = [rejected[-1]["content"] + tokenizer.eos_token for rejected in samples["rejected"]] + return samples + + return dataset.map( + process, + batched=True, + num_proc=num_proc, + # remove_columns=original_columns, + ) + + +def main(): + t0 = time.time() + # os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["HF_HUB_OFFLINE"] = "false" + # get input args, set TASK2ID, ID2TASK, refactor args + args = prepare_args() + + # fix randomness + if args.seed is not None: + set_seed(args.seed) + + # define accelerator + init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds)) + + if args.distributed_type and args.distributed_type.lower() == "fsdp": + fsdp_plugin = FullyShardedDataParallelPlugin( + # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + limit_all_gathers=True, + sync_module_states=True, + use_orig_params=True, + cpu_offload=False, + ) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + fsdp_plugin=fsdp_plugin, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + else: + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + + # print key infos + accelerator.print("In dpo_accelerate.py, sys path:", sys.path) + accelerator.print(f"transformers.__version__: {transformers.__version__}") + + # get world_size + args.world_size = accelerator.num_processes + + # backup args + pprint_args(args, accelerator) + if accelerator.is_main_process: + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "args.json"), "w") as f: + json.dump(args.dict(), f, indent=2) + + # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest + + # logger + logging.basicConfig( + format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # get global_rank and local rank for current process + global_rank = accelerator.process_index + local_rank = accelerator.local_process_index + print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}") + + # 1. 
dataset + + # build tokenizer + tokenizer = build_tokenizer(args) + # tokenizer.chat_template = MFTCoder_template + + # Load the dpo dataset + all_datasets = [] + # print(args.data_paths, type(args.data_paths)) + if isinstance(args.data_paths, str): + args.data_paths = list(args.data_paths[1:-1].split(",")) + # print(f"DATA_PATHS: {args.data_paths}") + for data_file in args.data_paths: + ds = chatml_to_dpo_format(data_file=data_file, tokenizer=tokenizer, sanity_check=args.sanity_check) + all_datasets.append(ds) + + all_dataset = concatenate_datasets(all_datasets) + # all_dataset = all_dataset.filter( + # lambda x: len(x["prompt"]) + len(x["chosen"]) <= args.max_length + # and len(x["prompt"]) + len(x["rejected"]) <= args.max_length + # ) + accelerator.print(f"Length of all_dataset: {len(all_dataset)}") + + # split train/eval dataset + splits = [float(s) for s in args.data_split.split(",")][:2] + print(f"data splits: {splits}") + + all_dataset = all_dataset.train_test_split(test_size=splits[1] / sum(splits), shuffle=True, seed=args.seed) + all_dataset.flatten_indices() + + train_dataset, eval_dataset = all_dataset["train"], all_dataset["test"] + accelerator.print(f"Length of train_dataset: {len(train_dataset)}\nLength of eval_dataset: {len(eval_dataset)}") + print(eval_dataset[0]) + t1 = time.time() + logger.info(f"dataset loading time: {t1 - t0:.4f}") + + # cuda memory + free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3) + max_memory = f"{free_in_GB - 2}GB" + n_gpus = torch.cuda.device_count() + max_memory = {i: max_memory for i in range(n_gpus)} + accelerator.print("max memory: ", max_memory, n_gpus) + + # target_modules, default all-linear for all linear layers + if args.target_modules: + target_modules = args.target_modules + else: + target_modules = "all-linear" + + # peft config + if args.peft_type: + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=target_modules, + bias="lora_only", + ) + else: + peft_config = None + + # creating base model + model = get_model(args, accelerator) + if args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + accelerator.print("Model load_in_4bit: ", args.quantization == "4bit") + + model.config.use_cache = False # silence the warnings. Please re-enable for inference! + if hasattr(model.config, "use_logn_attn"): + model.config.use_logn_attn = False # special for qwen model + # load balance for moe training + if hasattr(model.config, "output_router_logits"): + model.config.output_router_logits = True + model_config = model.config + accelerator.print(model.config) + + t2 = time.time() + if accelerator.is_main_process: + logging.info(f"model loading time: {t2 - t1:.4f}") + + # 4. 
initialize training arguments: + if args.xxpo == "dpo": + ConfigClass = DPOConfig + elif args.xxpo == "orpo": + ConfigClass = ORPOConfig + logging.info(f"{args.xxpo} Used.") + + training_args = ConfigClass( + beta=args.beta, + rpo_alpha=args.rpo_alpha, + per_device_train_batch_size=args.per_device_train_batch_size, + per_device_eval_batch_size=args.per_device_eval_batch_size, + max_steps=args.max_steps, + num_train_epochs=args.num_train_epochs, + logging_steps=args.logging_steps, + save_strategy="steps", + eval_strategy="steps", + save_steps=args.save_steps, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_checkpointing=args.gradient_checkpointing, + learning_rate=args.learning_rate, + eval_steps=args.eval_steps, + output_dir=args.output_dir, + report_to="tensorboard", + logging_dir=args.tb_dir, + max_prompt_length=args.max_prompt_length, + max_length=args.max_length, + lr_scheduler_type=args.lr_scheduler_type, + warmup_steps=args.warmup_steps, + optim=args.optimizer_type, + bf16=True, + remove_unused_columns=False, + run_name="", + gradient_checkpointing_kwargs=dict(use_reentrant=args.gradient_checkpointing_use_reentrant), + seed=args.seed, + dataset_num_proc=args.dataset_num_proc, + disable_tqdm=args.disable_tqdm, + save_only_model=args.save_only_model, + save_total_limit=args.saving_limit, + ) + + # 5. initialize the DPO trainer + if not args.peft_type and args.xxpo == "dpo": + model_ref = get_model(args, accelerator) + model_ref.config.use_cache = False # silence the warnings. Please re-enable for inference! + else: + model_ref = None + + if args.xxpo == "dpo": + xxpo_trainer = DPOTrainer( + model, + ref_model=model_ref, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + ) + elif args.xxpo == "orpo": + xxpo_trainer = ORPOTrainer( + model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + ) + + # callbacks + if args.disable_tqdm: + xxpo_trainer.remove_callback(PrinterCallback) + xxpo_trainer.add_callback(LogCallback) + else: + xxpo_trainer.remove_callback(ProgressCallback) + xxpo_trainer.add_callback(CustomProgressCallback) + + # 6. train + xxpo_trainer.train() + + # 7. 
save + output_dir = os.path.join(args.output_dir, "epoch_final") + xxpo_trainer.save_model(output_dir) + # dpo_trainer.model.save_pretrained(output_dir) + logger.info(f"Training Finished!") + + +if __name__ == "__main__": + main() diff --git a/mftcoder_accelerate/src/xxpo/xxpo_arguments.py b/mftcoder_accelerate/src/xxpo/xxpo_arguments.py new file mode 100644 index 0000000..2b4c876 --- /dev/null +++ b/mftcoder_accelerate/src/xxpo/xxpo_arguments.py @@ -0,0 +1,170 @@ +""" +# @author Chaoyu Chen +# @date 2023/10/19 + +training arguments +""" + +from dataclasses import dataclass, asdict +from typing import List, Union + + +@dataclass +class XXPOTrainArgs: + # train data paths on shared FS + data_paths: Union[str, List[str]] + + # output dir for saving adaptors in peft or full ckpts in full-parameter training + output_dir: str + + # tensorboard dir for saving tensorboard logs + tb_dir: str + + # pretrained_model_path, on which is the model you want to train + pretrained_model_path: str + + # model type of pretrained_model_path, support llama|qwen|starcoder|baichuan|chatglm2 + model_type: str + + # train/valid/test split + data_split: str = "98,2,0" + + # lora or qlora or None(for full-parameter training) + peft_type: Union[None, str] = "qlora" + + # if qlora, 4bit will be set, else None + quantization: Union[None, str] = "4bit" + + # lora rank, the bigger, the more trainalbe parameters + lora_rank: int = 96 + + # lora alpha + lora_alpha: int = 32 + + # lora dropout + lora_dropout: float = 0.05 + + # lora targeting modules + target_modules: Union[None, str, List[str]] = None + + # dpo or orpo + xxpo: str = "dpo" + + # dpo/orpo beta + beta: float = 0.1 + + rpo_alpha: Union[None, float] = None + + # mircro train batch size + per_device_train_batch_size: int = 8 + + # micro eval batch size, always same as micro train batch size + per_device_eval_batch_size: int = 8 + + # HF AutoTokenizer is supported, maybe more types + tokenizer_type: str = "AutoTokenizer" + + # initial lr + learning_rate: float = 5e-5 + + # minimum lr + min_lr: float = 5e-6 + + # weight decay + weight_decay: float = 0.01 + + # gradient_accumulation_steps + gradient_accumulation_steps: int = 1 + + # lr_scheduler_type + lr_scheduler_type: str = "cosine" + + # optimizer_type + optimizer_type: str = "adamw_torch" + # optimizer_type: str = "paged_adamw_32bit" + + # gradient_checkpointing + gradient_checkpointing: bool = True + gradient_checkpointing_use_reentrant: bool = False + + # num of warmup_steps + warmup_steps: Union[int, float] = 0.05 + + # num_train_epochs + num_train_epochs: int = 4 + + # seed for reproducing + seed: int = 1234 + + # seq_length, context length + seq_length: int = 4096 + + save_only_model: bool = True + + # path of adaptor which is resumed from, None for not resuming training + resume_from_checkpoint: Union[None, str] = None + + # auto resume from latest ckpt if job restarted + auto_resume: bool = True + + # num of steps for logging training loss + logging_steps: int = 10 + + # num of steps for saving ckpt + save_steps: int = 100 + + # num of steps for evaluation(eval_loss), better same as checkpointing steps + eval_steps: int = 100 + + # max train steps, if None, depends on num_train_epochs + max_steps: int = -1 + + # if checkpointing every epoch, maybe True in sst + epoch_checkpointing: bool = False + + # shuffle before train/valid split + shuffle_before_split: bool = True + + # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point + early_stopping: bool = True 
+ early_stopping_stall_num: int = 5 + + # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota. + saving_limit: Union[None, int] = None + + # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2} + attn_implementation: str = "flash_attention_2" + + # tokenizer chat template, if None, will use MFTCoder template + chat_template: Union[None, str] = None + + distributed_type: Union[None, str] = None + + init_timeout_seconds: Union[None, int] = 3600 + + make_vocab_size_divisible_by: int = 32 + model_parallel_size: int = 1 + use_slow_tokenizer: bool = False + world_size: int = 8 + + # max prompt string length and whole str length + max_prompt_length: Union[None, int] = 2048 + max_length: Union[None, int] = 4096 + + # num of process processing dataset + dataset_num_proc: int = 1 + + # model_dtype[float16, bfloat16, float] for loading + dtype: str = "bfloat16" + + # instrumentation + disable_tqdm: bool = False + sanity_check: bool = False + + # debug argument for distributed training + # "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + # "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + ignore_bias_buffers: bool = True + + def dict(self): + return {k: str(v) for k, v in asdict(self).items()} diff --git a/requirements.txt b/requirements.txt index c6430fc..189518b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,15 @@ numpy==1.23.5 -pandas==1.5.3 +pandas==2.2.1 torch==2.1.0 tensorboard==2.11.0 deepspeed==0.14.0 -transformers==4.40.2 -accelerate==0.28.0 +transformers==4.44.2 +accelerate==0.31.0 peft==0.10.0 BitsAndBytes==0.43.0 xformers==0.0.22.post7 +datasets +ftfy packaging einops sentencepiece
