diff --git a/.gitignore b/.gitignore index 68729d6..2e34447 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ .idea/ -.DS_Store \ No newline at end of file +.DS_Store +*.log +*/__pycache__/ +*.pyc \ No newline at end of file diff --git a/README.md b/README.md index 94aa7c9..468cc88 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning +# MFTCoder: High Accuracy and Efficiency Multi-task Fine-Tuning Framework

@@ -21,7 +21,11 @@ Open Issues

- +

+ <a href="https://huggingface.co/codefuse-ai">🤗 HuggingFace</a> + • <a href="https://modelscope.cn/organization/codefuse-ai">🤖 ModelScope</a> + +

[[中文]](README_cn.md) [**English**] @@ -38,111 +42,143 @@ - [Models](#Models) - [Datasets](#Datasets) - [Star History](#Star-History) +- [Join Us](#Join-Us) ## News +🔥🔥🔥 [2024/10/31] We released **MFTCoder v0.5** mainly for MFTCoder-accelerate, which is now supporting preference alignment methods like **DPO/RPO/ORPO** in the new **xxpo** module, adding full-parameter continue-training in the additional **mpt** module along with its **offline_tokenization** module, updating selfpaced method to new convergence balance(CoBa) method for MFT in the original **pefts** module. + +🔥🔥🔥 [2024/10/31] Our paper [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) has been accepted by EMNLP-2024, which achieves balanced convergence across various tasks. + +🔥🔥🔥 [2024/05/20] We released **MFTCoder v0.4**, mainly for MFTCoder-accelerate. It supports **QLoRA + DeepSpeed Zero3** and **QLoRA + FSDP** as options allowing you training very large models. It now supports new models like Qwen2, Qwen2-MoE, Starcoder2, Gemma, etc. + +🔥🔥🔥 [2024/05/20] Our paper [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) has been accepted by KDD2024. + +🔥🔥🔥 [2024/05/20] [CodeFuse-StarCoder2-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) has been released, achieving a pass@1 (greedy decoding) score of 73.2% on HumanEval. + +🔥🔥 [2024/01/30] The model [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) fine-tuned with MFTCoder ranks first in HuggingFace [Big Code Models LeaderBoard](https://huggingface.co/spaces/bigcode/bigcode-models-leaderboard) + +🔥🔥 [2024/01/17] We released MFTCoder v0.3.0, mainly for MFTCoder-accelerate. It now supports new models like Mixtral(MoE), DeepSeek-coder, chatglm3. It supports FSDP as an option. It also supports Self-paced Loss as a solution for convergence balance in Multitask Fine-tuning. + +🔥🔥 [2024/01/17] [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) has been released, achieving a pass@1 (greedy decoding) score of 78.7% on HumanEval. It lists as top-1 LLM on Bigcode Leardboard in terms of win-rate, the official result is going to be published later. + +🔥🔥 [2024/01/17] [CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8X7B) has been released, achieving a pass@1 (greedy decoding) score of 56.1% on HumanEval. + 🔥🔥 [2023/11/07] [MFTCoder Paper](https://arxiv.org/abs/2311.02303) has been released on Arxiv, which discloses technique details of multi-task-fine-tuning. 🔥🔥 [2023/10/20] [CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) has been released, achieving a pass@1 (greedy decoding) score of 48.8% on HumanEval, which gains 16% absolute improvement over the base model [Qwen-14b](https://huggingface.co/Qwen/Qwen-14B) 🔥🔥 [2023/09/27] [CodeFuse-StarCoder-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) has been released, achieving a pass@1 (greedy decoding) score of 54.9% on HumanEval. -🔥🔥🔥 [2023/09/26]We are pleased to announce the release of the [4-bit quantized version of CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits). Despite the quantization process, the model still achieves a remarkable 73.8% accuracy (greedy decoding) on the HumanEval pass@1 metric. 
+🔥🔥 [2023/09/26]We are pleased to announce the release of the [4-bit quantized version of CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits). Despite the quantization process, the model still achieves a remarkable 73.8% accuracy (greedy decoding) on the HumanEval pass@1 metric. -🔥🔥🔥 [2023/09/07]We released **CodeFuse-CodeLlama-34B**, which achieves the **74.4% Python Pass@1** (greedy decoding) and surpasses GPT4 (2023/03/15) and ChatGPT-3.5 on the [HumanEval Benchmarks](https://github.com/openai/human-eval). +🔥🔥 [2023/09/07]We released [**CodeFuse-CodeLlama-34B**](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits), which achieves the **74.4% Python Pass@1** (greedy decoding) and surpasses GPT4 (2023/03/15) and ChatGPT-3.5 on the [HumanEval Benchmarks](https://github.com/openai/human-eval). -🔥🔥 [2023/08/26]We released MFTCoder which supports finetuning Code Llama, Llama, Llama2, StarCoder, ChatGLM2, CodeGeeX2, Qwen, and GPT-NeoX models with LoRA/QLoRA. +🔥🔥 [2023/08/26]We released MFTCoder-v0.1.0 which supports finetuning Code Llama, Llama, Llama2, StarCoder, ChatGLM2, CodeGeeX2, Qwen, and GPT-NeoX models with LoRA/QLoRA. ### HumanEval Performance | Model | HumanEval(Pass@1) | Date | |:----------------------------|:-----------------:|:-------:| -| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | -|**CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | -| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | -| GPT-4(zero-shot) | 67.0% | 2023/03 | -| PanGu-Coder2 15B | 61.6% | 2023/08 | -| **CodeFuse-StarCoder-15B** | **54.9%** | 2023/08 | -| CodeLlama-34b-Python | 53.7% | 2023/08 | -| **CodeFuse-QWen-14B** | **48.8%** | 2023/10 | -| CodeLlama-34b | 48.8% | 2023/08 | -| GPT-3.5(zero-shot) | 48.1% | 2022/11 | -| OctoCoder | 46.2% | 2023/08 | -| StarCoder-15B | 33.6% | 2023/05 | -| QWen-14B | 32.3% | 2023/10 | +| **CodeFuse-DeepSeek-33B** | **78.7%** | 2024/01 | +| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | +| **CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | +| **CodeFuse-StarCoder2-15B** | **73.2%** | 2023/05 | +| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | +| GPT-4(zero-shot) | 67.0% | 2023/03 | +| PanGu-Coder2 15B | 61.6% | 2023/08 | +| **CodeFuse-Mixtral-8x7B** | **56.1%** | 2024/01 | +| **CodeFuse-StarCoder-15B** | **54.9%** | 2023/08 | +| CodeLlama-34b-Python | 53.7% | 2023/08 | +| **CodeFuse-QWen-14B** | **48.8%** | 2023/10 | +| CodeLlama-34b | 48.8% | 2023/08 | +| GPT-3.5(zero-shot) | 48.1% | 2022/11 | +| OctoCoder | 46.2% | 2023/08 | +| StarCoder-15B | 33.6% | 2023/05 | +| QWen-14B | 32.3% | 2023/10 | ## Articles -[MFT Arxiv paper](https://arxiv.org/abs/2311.02303) +[MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning (KDD2024)](https://arxiv.org/abs/2311.02303) ## Introduction -**High Accuracy and efficiency multi-task fine-tuning framework for Code LLMs.** +**High Accuracy and efficiency Multi-task Fine-tuning framework for Code LLMs.** + +**MFTCoder** is an open-source project of CodeFuse for accurate and efficient Multi-task Fine-tuning(MFT) on Large Language Models(LLMs), especially on Code-LLMs(large language model for code tasks). +Moreover, we open source Code LLM models and code-related datasets along with the MFTCoder framework. -**CodeFuse-MFTCoder** is an open-source project of CodeFuse for multitasking Code-LLMs(large language model for code tasks), which includes models, datasets, training codebases and inference guides. 
In MFTCoder, we released two codebases for finetuning Large Language Models: -- ```mft_peft_hf``` is based on the HuggingFace Accelerate and deepspeed framework. -- ```mft_atorch``` is based on the [ATorch frameworks](https://github.com/intelligent-machine-learning/dlrover), which is a fast distributed training framework of LLM. +- **```MFTCoder-accelerate```** is a framework with accelerate and DeepSpeed/FSDP. All tech-stacks are open-source and vibrant. We highly recommend you try this framework and make your fintuning accurate and efficient. +- ```MFTCoder-atorch``` is based on the [ATorch frameworks](https://github.com/intelligent-machine-learning/dlrover), which is a fast distributed training framework of LLM. The aim of this project is to foster collaboration and share advancements in large language models, particularly within the domain of code development. ### Frameworks -![img.png](./assets/img.png) +![img.jpg](./assets/img.jpg) ### Highlights :white_check_mark: **Multi-task**: Train models on multiple tasks while maintaining a balance between them. The models can even generalize to new, previously unseen tasks. :white_check_mark: **Multi-model**: It integrates state-of-the-art open-source models such as gpt-neox, llama, llama-2, baichuan, Qwen, chatglm2, and more. (These finetuned models will be released in the near future.) -:white_check_mark: **Multi-framework**: It provides support for both HuggingFace Accelerate (with deepspeed) and [ATorch](https://github.com/intelligent-machine-learning/dlrover). +:white_check_mark: **Multi-framework**: It provides support for both Accelerate (with Deepspeed and FSDP) and ATorch -:white_check_mark: **Efficient fine-tuning**: It supports LoRA and QLoRA, enabling fine-tuning of large models with minimal resources. The training speed meets the demands of almost all fine-tuning scenarios. +:white_check_mark: **Efficient fine-tuning**: It supports LoRA, QLoRA as well as Full-parameters training, enabling fine-tuning of large models with minimal resources. The training speed meets the demands of almost all fine-tuning scenarios. The main components of this project include: - Support for both SFT (Supervised FineTuning) and MFT (Multi-task FineTuning). The current MFTCoder achieves data balance among multiple tasks, and future releases will achieve a balance between task difficulty and convergence speed during training. -- Support for QLoRA instruction fine-tuning, as well as LoRA fine-tuning. -- Support for most mainstream open-source large models, particularly those relevant to Code-LLMs, such as Code-LLaMA, Starcoder, Codegeex2, Qwen, GPT-Neox, and more. +- Support for QLoRA instruction fine-tuning, LoRA fine-tuning as well as Full-parameters fine-tuning. +- Support for most mainstream open-source large models, particularly those relevant to Code-LLMs, such as DeepSeek-coder, Mistral, Mixtral, Chatglm3, Code-LLaMA, Starcoder, Codegeex2, Qwen, GPT-Neox, and more. - Support for weight merging between the LoRA adaptor and base models, simplifying the inference process. - Release of 2 high-quality code-related instruction fine-tuning datasets: [Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k) and [CodeExercise-Python-27k](https://huggingface.co/datasets/codefuse-ai/CodeExercise-Python-27k). -- Release of 2 models: [CodeFuse-13B](https://huggingface.co/codefuse-ai/CodeFuse-13B) and [CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B). 
+- Release of many Code LLMs, please refer to organizations: [codefuse-ai on huggingface](https://huggingface.co/codefuse-ai) or [codefuse-ai on modelscope](https://modelscope.cn/organization/codefuse-ai). ## Requirements -To begin, ensure that you have successfully installed CUDA (version >= 11.4, preferably 11.7) along with the necessary drivers. Additionally, make sure you have installed torch (version 2.0.1). +To begin, ensure that you have successfully installed CUDA (version >= 11.4, preferably 12.1) along with the necessary drivers. Additionally, make sure you have installed torch (version >= 2.1.0). Next, we have provided an init_env.sh script to simplify the installation of required packages. Execute the following command to run the script: ```bash sh init_env.sh ``` -If you require flash attention, please refer to the following link for installation instructions: https://github.com/Dao-AILab/flash-attention +We highly recommend training with flash attention(version >= 2.3.0), please refer to the following link for installation instructions: https://github.com/Dao-AILab/flash-attention ## Training -🚀 [Huggingface accelerate + deepspeed Codebase for MFT(Multi-task Finetuning)](./mft_peft_hf/README.md) +As mentioned above, we open source two training frameworks. You could refer to their own READMEs for more details as followed. -🚀 [Atorch Codebase for MFT(Multi-task Finetuning)](./mft_atorch/README.md) +If you are familiar with open source ```transformers```, ```DeepSpeed``` or ```FSDP```, we highly recommend you try: +🚀🚀 [**MFTCoder-accelerate: Accelerate + Deepspeed/FSDP Codebase for MFT(Multi-task Finetuning)**](mftcoder_accelerate/README.md) -## Models -We are excited to release the following two CodeLLMs trained by MFTCoder, now available on Hugging Face: +If you want to explore some new framework like atorch, you could check: + +🚀 [MFTCoder-atorch: Atorch Codebase for MFT(Multi-task Finetuning)](mftcoder_atorch/README.md) -| Model | Base Model | Num of examples trained | Batch Size | Seq Length | -|--------------------------------------------------------------------------------------------|--------------------|-------------------------|------------|------------| -| [🔥🔥🔥 CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 600k | 80 | 4096 | -| [🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python| | | 4096 | -| [🔥🔥🔥 CodeFuse-StarCoder-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | Starcoder | 600k | 256 | 4096 | -| [🔥🔥🔥 CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 1100k | 256 | 4096 | -| [🔥 CodeFuse-13B](https://huggingface.co/codefuse-ai/CodeFuse-13B) | CodeFuse-13B | 66k | 64 | 4096 | +## Models +We are excited to release the following two CodeLLMs trained by MFTCoder, now available on both HuggingFace and ModelScope: +| Model | HuggingFace Links | ModelScope Links | Base Model | Num of examples trained | Batch Size | Seq Length | +|----------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|----------------------|-------------------------|------------|------------| +| 🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 600K | 
80 | 4096 | +| 🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 600K | 80 | 4096 | +| 🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 600K | 80 | 4096 | +| 🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | +| 🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 600K | 80 | 4096 | +| 🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 1.1 Million | 256 | 4096 | +| 🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 1.1 Million | 256 | 4096 | +| 🔥 CodeFuse-StarCoder2-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder2-15B) | Starcoder2-15B | 700K | 128 | 4096 | + ## Datasets We are also pleased to release two code-related instruction datasets, meticulously selected from a range of datasets to facilitate multitask training. Moving forward, we are committed to releasing additional instruction datasets covering various code-related tasks. | Dataset | Description | |-----------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [⭐ Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k) | Based on open-evol-instruction-80k, filter out low-quality, repeated, and similar instructions to HumanEval, thus get high-quality code instruction dataset. | +| [⭐ Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k) | Based on open-evol-instruction-80k, filter out low-quality, repeated, and similar instructions to HumanEval, thus get high-quality code instruction dataset. | | [⭐ CodeExercise-Python-27k](https://huggingface.co/datasets/codefuse-ai/CodeExercise-Python-27k) | python code exercise instruction dataset | @@ -171,3 +207,13 @@ If you find our work useful or helpful for your R&D works, please feel free to c +## Join-US + +We are the AI Native team within the Platform Technology Business Group at Ant Group, dedicated to the intelligentization of Ant Group's platform engineering. Established for over three years, our team has played a pivotal role in supporting the intelligent operation and maintenance of Ant Group's cloud computing infrastructure. Our mission is to build algorithm services and platforms with a wide user base through world-class technological innovation and impact, supporting the implementation of internal and external products and businesses. +Embracing an innovation-driven ethos, our team not only supports business implementation but also propels technological influence. 
Over the past three years, we have published more than 20 papers at top conferences like ICLR, NeurIPS, KDD, and ACL. Our innovative business outcomes have earned us two Ant Technology's highest T-Star awards and one SuperMA award from Ant Group. Our open-source project CodeFuse has received 4K stars as of February 2024, and our models have been downloaded over 1.5 million times on Huggingface and Modelscope. + +**We are on the lookout for top talents to join our vibrant team! If you're eager to develop your career in an environment filled with energy, innovation, and a culture of excellence, we welcome you to explore our career opportunities for both campus and experienced hires. Join us and be a part of creating the next milestone in the industry.** + +**Campus Recruitment**: https://hrrecommend.antgroup.com/guide.html?code=8uoP5mlus5DqQYbE_EnqcE2FD5JZH21MwvMUIb9mb6X3osXPuBraG54SyM8GLn_7 + +**Experienced Hires**: https://talent.antgroup.com/off-campus-position?positionId=1933830 diff --git a/README_cn.md b/README_cn.md index 04a8aff..3102d9f 100644 --- a/README_cn.md +++ b/README_cn.md @@ -1,4 +1,4 @@ -# MFTCoder: 多任务大模型代码能力微调框架 +# MFTCoder: 高效准确的多任务大模型微调框架

@@ -22,6 +22,11 @@

+

+ <a href="https://huggingface.co/codefuse-ai">🤗 HuggingFace</a> + • <a href="https://modelscope.cn/organization/codefuse-ai">🤖 魔搭</a> +

+ [**中文**] [[English]](README.md) @@ -36,9 +41,25 @@ - [训练](#训练) - [模型](#模型) - [数据集](#数据集) +- [加入我们](#加入我们) ## 新闻 +🔥🔥🔥 [2024/10/31] **MFTCoder-v0.5**发布,新增**xxpo**模块支持偏好对齐DPO/RPO/ORPO;新增**mpt**和**offline_tokenization**模块支持全量参数的加训;在原本的**pefts**模块(MFT)更新selfpaced收敛均衡技术并更名CoBa。 + +🔥🔥🔥 [2024/10/31] 我们的论文 [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) 已被 EMNLP 2024 接收,可以实现多任务收敛均衡。 + +🔥🔥🔥 [2024/05/20] **MFTCoder-v0.4**发布。新增支持**QLoRA+ DeepSpeed Zero3**, **QLoRA + FSDP**训练模式,可以更好的支持微调更大的模型,比如Qwen1.5-70B等。新增对Qwen2, Qwen2-MoE, Starcoder2, Gemma等模型的支持。 + +🔥🔥🔥 [2024/05/20] 我们的论文 [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) 已被 KDD 2024 接收. + +🔥🔥🔥 开源了[CodeFuse-StarCoder2-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B)模型,在HumanEval上可以达到73.2%,多代码语言能力均衡. + +🔥🔥 [2024/01/17] **MFTCoder-v0.3.0**发布。新增对Mixtral(MoE), DeepSeek等模型的支持;新增支持FSDP(Fully Sharded Data Parallel);新增Self-paced Loss, 支持多任务收敛均衡。 感兴趣详见微信公众号CodeFuse的文章[MFTCoder 重磅升级v0.3.0发布](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg) + +🔥🔥 [2024/01/17] 开源了[CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B)模型,在HumanEval pass@1(greedy decoding)上可以达到78.7%。该模型在Big Code榜单的结果近期发布,请关注公众号获取最新信息。 + +🔥🔥 [2024/01/17] 开源了[CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B)模型,在HumanEval pass@1(greedy decoding)上可以达到56.1%。感兴趣详见微信公众号CodeFuse的文章[MFTCoder提升Mixtral-8x7B混合专家模型的代码能力实践](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg) 🔥🔥 [2023/11/07] [MFTCoder论文](https://arxiv.org/abs/2311.02303)在Arxiv公布,介绍了多任务微调的技术细节。 @@ -46,28 +67,31 @@ 🔥🔥 [2023/09/27] 开源了[CodeFuse-StarCoder-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B)模型,在HumanEval pass@1(greedy decoding)上可以达到54.9%。 -🔥🔥🔥 [2023/09/26] [CodeFuse-CodeLlama-34B-4bits](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits)量化版本发布,量化后模型在HumanEval pass@1指标为73.8% (贪婪解码)。 +🔥🔥 [2023/09/26] [CodeFuse-CodeLlama-34B-4bits](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits)量化版本发布,量化后模型在HumanEval pass@1指标为73.8% (贪婪解码)。 -🔥🔥🔥 [2023/09/07]MFTCoder微调的模型**CodeFuse-CodeLlama-34B**在[HumanEval Benchmarks](https://github.com/openai/human-eval)的Python **Pass@1** 取得了**74.4%**(greedy decoding)的开源SOTA成绩。 +🔥🔥 [2023/09/07]MFTCoder微调的模型**CodeFuse-CodeLlama-34B**在[HumanEval Benchmarks](https://github.com/openai/human-eval)的Python **Pass@1** 取得了**74.4%**(greedy decoding)的开源SOTA成绩。 -🔥 [2023/08/26]MFTCoder支持使用LoRA/QLoRA对Code Llama、Llama、Llama2、StarCoder、ChatGLM2、CodeGeeX2、Qwen和GPT-NeoX模型进行微调。 +🔥🔥 [2023/08/26]MFTCoder-v0.1.0 支持使用LoRA/QLoRA对Code Llama、Llama、Llama2、StarCoder、ChatGLM2、CodeGeeX2、Qwen和GPT-NeoX模型进行微调。 ### HumanEval表现 -| 模型 | HumanEval(Pass@1) | 日期 | -|:----------------------------|:-----------------:|:-------:| -| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | -|**CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | -| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | -| GPT-4(zero-shot) | 67.0% | 2023/03 | -| PanGu-Coder2 15B | 61.6% | 2023/08 | -| **CodeFuse-StarCoder-15B** | **54.9%** | 2023/08 | -| CodeLlama-34b-Python | 53.7% | 2023/08 | -| **CodeFuse-QWen-14B** | **48.8%** | 2023/10 | -| CodeLlama-34b | 48.8% | 2023/08 | -| GPT-3.5(zero-shot) | 48.1% | 2022/11 | -| OctoCoder | 46.2% | 2023/08 | -| StarCoder-15B | 33.6% | 2023/05 | -| QWen-14B | 32.3% | 2023/10 | +| 模型 | HumanEval(Pass@1) | 日期 | +|:---------------------------------|:-----------------:|:-------:| +| **CodeFuse-DeepSeek-33B** | **78.7%** | 
2024/01 | +| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | +| **CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | +| **CodeFuse-StarCoder2-15B** | **73.2%** | 2023/05 | +| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | +| GPT-4(zero-shot) | 67.0% | 2023/03 | +| PanGu-Coder2 15B | 61.6% | 2023/08 | +| **CodeFuse-Mixtral-8x7B** | **56.1%** | 2024/01 | +| **CodeFuse-StarCoder-15B** | **54.9%** | 2023/08 | +| CodeLlama-34b-Python | 53.7% | 2023/08 | +| **CodeFuse-QWen-14B** | **48.8%** | 2023/10 | +| CodeLlama-34b | 48.8% | 2023/08 | +| GPT-3.5(zero-shot) | 48.1% | 2022/11 | +| OctoCoder | 46.2% | 2023/08 | +| StarCoder-15B | 33.6% | 2023/05 | +| QWen-14B | 32.3% | 2023/10 | ## 文章 @@ -82,53 +106,61 @@ **Codefuse-MFTCoder** 是一个开源的多任务代码大语言模型项目,包含代码大模型的模型、数据、训练等。我们希望通过开源,分享交流大语言模型在代码领域的进步。 ### 项目框架 -![img_1.png](./assets/img_1.png) +![img_1.jpg](./assets/img_1.jpg) ### 项目优势 :white_check_mark: **多任务**:一个模型同时支持多个任务,会保证多个任务之间的平衡,甚至可以泛化到新的没有见过的任务上去; :white_check_mark: **多模型**:支持最新的多个开源模型,包括gpt-neox,llama,llama-2,baichuan,Qwen,chatglm2等; -:white_check_mark: **多框架**:同时支持HuggingFace 和 [ATorch 框架](https://github.com/intelligent-machine-learning/dlrover); +:white_check_mark: **多框架**:既支持主流开源的Accelerate+DeepSpeed/FSDP,也支持新开源的[ATorch 框架](https://github.com/intelligent-machine-learning/dlrover); :white_check_mark: **高效微调**:支持LoRA和QLoRA,可以用很少的资源去微调很大的模型,且训练速度能满足几乎所有微调场景; 本项目主要内容如下: - 同时支持单任务SFT(Supervised FineTuning)和MFT(Multi-task FineTuning), 当前开源支持数据均衡,未来将持续开源难易均衡, 收敛均衡等 -- 支持QLoRA低成本高效指令微调、LoRA高效指令微调。 -- 支持绝大部分主流的开源大模型,重点关注代码能力优秀的开源大模型,如Qwen, GPT-Neox, Starcoder, Codegeex2, Code-LLaMA等。 +- 支持QLoRA低成本高效指令微调、LoRA高效指令微调、全量参数高精度微调。 +- 支持绝大部分主流的开源大模型,重点关注代码能力优秀的开源大模型,如DeepSeek-coder, Mistral, Mistral(MoE), Chatglm3, Qwen, GPT-Neox, Starcoder, Codegeex2, Code-LLaMA等。 - 支持lora与base model进行权重合并,推理更便捷。 - 整理并开源2个指令微调数据集:[Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k)和[CodeExercise-Python-27k](https://huggingface.co/datasets/codefuse-ai/CodeExercise-Python-27k)。 -- 开源2个[Codefuse系列指令微调模型权重]:[CodeFuse-13B](https://huggingface.co/codefuse-ai/CodeFuse-13B)和[CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B)。 +- 开源多个[Codefuse系列指令微调模型权重],具体参见我们的huggingface组织和modelscope组织下的模型:[codefuse-ai huggingface](https://huggingface.co/codefuse-ai) or [codefuse-ai 魔搭](https://modelscope.cn/organization/codefuse-ai)。 ## 环境 -首先, 你需要将CUDA(>=11.4, 推荐11.7)及其相关驱动安装成功,并确保其工作正常, 并且安装基本的torch(>=2.0.0) +首先, 你需要将CUDA(>=11.4, 推荐12.1)及其相关驱动安装成功,并确保其工作正常, 并且安装基本的torch(>=2.1.0) 在requirements.txt下固定了几个主要的python包的版本,执行如下脚本即可: ```bash sh init_env.sh ``` -如果希望使用flash attention, 安装请参考 https://github.com/Dao-AILab/flash-attention +我们强烈建议您安装flash attention(>=2.3.0), 安装请参考 https://github.com/Dao-AILab/flash-attention ## 训练 -🚀 [Huggingface accelerate + deepspeed Codebase for MFT(Multi-task Finetuning)](./mft_peft_hf/README.md) +如果你熟悉大模型训练的各种主流开源资源,例如 ```transformers```, ```DeepSpeed```, ```FSDP```等, 为了用开源项目快速上手高性能微调,我们建议您尝试: + +🚀🚀 [MFTCoder-accelerate: Accelerate + DeepSpeed/FSDP Codebase for MFT(Multi-task Finetuning)](mftcoder_accelerate/README.md) + -🚀 [Atorch Codebase for MFT(Multi-task Finetuning)](./mft_atorch/README.md) +如果你想探索一些新兴的训练框架,可以尝试: + +🚀 [MFTCoder-atorch: Atorch Codebase for MFT(Multi-task Finetuning)](mftcoder_atorch/README.md) ## 模型 -使用本项目的训练代码,以及上述训练数据,我们训练并在huggingface开源了以下模型。 +使用本项目的训练代码,以及上述训练数据,我们训练并在huggingface, modelscope开源了以下模型。 -| 模型 | 基座模型 | 训练数据 | Batch Size | Seq Length | 
-|---------------------------------------------------------------|----------------------|------|------------|------------| -| [🔥🔥🔥 CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 60万 | 80 | 4096 | -| [🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | -| [🔥🔥🔥 CodeFuse-StarCoder-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | CodeLlama-34b-Python | 60万 | 80 | 4096 | -| [🔥🔥🔥 CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 | -| [🔥 CodeFuse-13B](https://huggingface.co/codefuse-ai/CodeFuse-13B) | CodeFuse-13B-Base | 6.6万 | 64 | 4096 | +| 模型 | HuggingFace链接 | 魔搭 链接 | 基座模型 | 训练数据 | Batch Size | Seq Length | +|--------------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|----------------------|------|------------|------------| +| 🔥🔥🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 60万 | 80 | 4096 | +| 🔥🔥🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 60万 | 80 | 4096 | +| 🔥🔥🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 60万 | 80 | 4096 | +| 🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | +| 🔥🔥🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 60万 | 80 | 4096 | +| 🔥🔥🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 | +| 🔥🔥🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 110万 | 256 | 4096 | +| 🔥🔥🔥 CodeFuse-StarCoder2-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder2-15B) | Starcoder2-15B | 70万 | 128 | 4096 | @@ -153,3 +185,15 @@ sh init_env.sh } ``` +## 加入我们 + +我们是平台技术事业群AI Native团队,负责蚂蚁蚂蚁集团平台工程的智能化,团队成立3年多以来,支持了蚂蚁集团云计算基础设施智能化运维的升级改造。团队的Mission是,通过世界级的技术创新和影响,构建有广泛用户的算法服务和平台,支撑内外部产品和业务落地。团队秉承创新基因,在支撑业务落地的同时,推动技术影响。3年以来在ICLR、NeurIPS、KDD、ACL等顶会发表论文20余篇,创新业务结果获得两次蚂蚁技术最高奖T-Star,1次蚂蚁集团最高奖SuperMA。开源项目CodeFuse获得4K点赞(2024年2月),Huggingface和modelscope上模型累积下载量超过150万次。 + +**我们正在寻找行业中的佼佼者加入我们的团队!如果您希望在一个充满活力、创新和卓越文化的环境中发展您的职业生涯,欢迎您查看我们的社招&校招机会,加入我们,一起创造下一个行业里程碑。** + +**校招**:https://hrrecommend.antgroup.com/guide.html?code=8uoP5mlus5DqQYbE_EnqcE2FD5JZH21MwvMUIb9mb6X3osXPuBraG54SyM8GLn_7 + +**社招**:https://talent.antgroup.com/off-campus-position?positionId=1933830 + +## 联系我们 +![img_wx.png](./assets/CodeFuse-AI群.png) diff --git "a/assets/CodeFuse-AI\347\276\244.png" 
"b/assets/CodeFuse-AI\347\276\244.png" new file mode 100644 index 0000000..4e0c0a1 Binary files /dev/null and "b/assets/CodeFuse-AI\347\276\244.png" differ diff --git a/assets/img.jpg b/assets/img.jpg new file mode 100644 index 0000000..199cc8e Binary files /dev/null and b/assets/img.jpg differ diff --git a/assets/img.png b/assets/img.png deleted file mode 100644 index 0614694..0000000 Binary files a/assets/img.png and /dev/null differ diff --git a/assets/img_1.jpg b/assets/img_1.jpg new file mode 100644 index 0000000..bde7dac Binary files /dev/null and b/assets/img_1.jpg differ diff --git a/assets/img_1.png b/assets/img_1.png deleted file mode 100644 index 69ec9fd..0000000 Binary files a/assets/img_1.png and /dev/null differ diff --git a/init_env.sh b/init_env.sh index 7fdebf7..834b38d 100644 --- a/init_env.sh +++ b/init_env.sh @@ -1,4 +1,4 @@ -pip install torch==2.0.1 && \ -pip install tensorboard && \ +pip install torch==2.1.0 && \ +pip install tensorboard==2.11.0 && \ pip install packaging && \ -pip install -r requirements.txt \ No newline at end of file +pip install -r requirements.txt diff --git a/mft_peft_hf/README.md b/mft_peft_hf/README.md deleted file mode 100644 index 0d6705a..0000000 --- a/mft_peft_hf/README.md +++ /dev/null @@ -1,247 +0,0 @@ -# MFTCoder Training: Huggingface accelerate + DeepSpeed Framework -[![Generic badge](https://img.shields.io/badge/🤗-Huggingface%20Repo-green.svg)](https://huggingface.co/codefuse-ai) - - GitHub - - -[[中文]](README_cn.md) [**English**] - -## 1. Updates - -🔥 MFTCoder supports QLoRA/LoRA using Huggingface accelerate + DeepSpeed Framework; - -🔥 MFTCoder supports Multiple Task Finetuning, which is able to balance diffenrent tasks in data level. - -🔥 MFTCoder supports finetuning multiple mainstream open-source base models: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen. - -## 2. Data Format -### 2.1 Training Data Format -The training data is in a uniformed JSONL format, in which each line of data has the following JSON format. The "chat_rounds" field is required, and other fields can be added or removed based on specific needs. - -```json -{ - "id":0, - "data_name":"code-helper", - "chat_rounds":[ - { - "role": "system", - "content": "You are a expert in coding and help answer code questions", - "chat_round_id": 0 - }, - { - "role": "human", - "content": "Write a python function of quick sort", - "chat_round_id": 1 - }, - { - "role": "bot", - "content": "Below is the function of quick sort: ...", - "chat_round_id": 1 - }, - { - "role": "human", - "content": "Explain the code", - "chat_round_id": 2 - }, - { - "role": "bot", - "content": "OK, this code ...", - "chat_round_id": 2 - } - ] -} -``` - -### 2.2 Inference Data Format -The inference data contains strings concatenated by conversation data(system, human and bot contents) in the training data format. -It is used as the data "seen"(before tokenization) by the model in training process. -It is used as input during the inference process as well. -Here is an example format of the concatenated string: - -```python -""" -<|role_start|>system<|role_end|>System instruction -<|role_start|>human<|role_end|>Human 1st round input -<|role_start|>bot<|role_end|>Bot 1st round output -<|role_start|>human<|role_end|>Human 2nd round input -<|role_start|>bot<|role_end|>Bot 2nd round output -... -... -... 
-<|role_start|>human<|role_end|>Human nth round input -<|role_start|>bot<|role_end|>{Bot output to be genreated} -""" -``` -When applying inference, you always make your input string end with "<|role_start|>bot<|role_end|>" to request the model generating answers. - - -## 3. Model Training -Currently, the "MFTCoder/mft_peft_hf" codebase supports QLoRA instruction fine-tuning, and LoRA instruction fine-tuning. -In theory, this project can be used to train any publicly available model in the HuggingFace Format. - -Here are some excellent pre-trained models weights available on Huggingface that can be finetuned with this codebase: - -🤗 [Latest code pre-trained SOTA, CodeLlama-34b-Python](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) : code-llama-34b, code-llama-34b-python, a new SOTA base model. - -🤗 [Best 10B level pre-trained Code LLM, Starcoder:](https://huggingface.co/bigcode/starcoder) wizardCoder-15B, PanGu-coder2, and other previous SOTA were trained on it. - -🤗 [Multilingual powerhouse, Qwen-7b](https://huggingface.co/Qwen/Qwen-7B): Suitable for multilingual tasks, including Chinese tasks, for instruction fine-tuning. - -You can find the implementations in the ```mft_peft_hf/src``` directory. The entry directory for fine-tuning training is ```mft_peft_hf/src/pefts```, and the entry file for training is ```mft_peft_hf/src/pefts/mft_accelerate.py```. -Configurations are stored in the ```mft_peft_hf/src/pefts/configs``` directory for easy management and modification. - -### 3.1 Tokenization -During training, we concatenate multi-turn dialogues into the following format (also known as the inference data format mentioned earlier) and then tokenize it. In this format, <|role_start|>human<|role_end|> represents the human input (i.e., prompt), <|role_start|>bot<|role_end|> represents the bot output, and represents the eos_token. -You can modify and `````` replace the eos_token based on different models' requirements. - -Here is an example of the concatenated format with prompts: -``` -"<|role_start|>human<|role_end|>input1<|role_start|>bot<|role_end|>target1<|role_start|>human<|role_end|>input2<|role_start|>bot<|role_end|>target2... -``` -During the calculation of loss, we use a ```loss mask``` to ensure that the loss from the input part does not contribute to parameter updates. Only the loss from the ```target``` part is used for updating parameters. -This approach takes full advantage of the benefits of model parallelism, making training more efficient. It also leverages the characteristic of decoder-only models with left-to-right attention. -By including all target parts from multiple turns in a single training iteration, the training process becomes more efficient. - - -### 3.2 LoRA/QLoRA -You can refer to the Lora paper for details about LoRA:[LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2106.09685.pdf) -You can refer to the Qlora paper for details about QLoRA:[QLORA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/pdf/2305.14314.pdf) - -QLoRA (Quantized LoRA) is a method that combines 4-bit nf4 quantization and additional adapters to achieve a balance between reducing GPU memory consumption and approaching the performance of full-parameter fine-tuning. - -According to the QLoRA paper, this method enables fine-tuning of a 33B model on a single V100 GPU while achieving performance close to that of full-parameter fine-tuning. 
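As a rough, standalone illustration of the QLoRA recipe just described (4-bit NF4 quantization of the base weights plus small trainable LoRA adapters), a setup with `transformers`, `peft` and `bitsandbytes` might look like the sketch below. The base model name, target modules and hyperparameters are placeholders for illustration only; this is not the project's `mft_accelerate.py` entry point, whose actual knobs (`peft_type`, `quantization`, `lora_rank`, `lora_alpha`, `lora_dropout`) are listed after the launch command below.

```python
# Minimal QLoRA sketch (illustrative only): 4-bit NF4 quantization + LoRA adapters.
# Model name, target modules and hyperparameters are placeholders, not project defaults.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

base_model = "bigcode/starcoder"  # placeholder base model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # QLoRA: quantize base weights to 4 bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4, as described in the QLoRA paper
    bnb_4bit_compute_dtype=torch.bfloat16,  # do the matmuls in bf16
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model, quantization_config=bnb_config, trust_remote_code=True
)
model = prepare_model_for_kbit_training(model)  # cast norms, enable input grads for k-bit training

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,   # typical LoRA hyperparameters (placeholders)
    target_modules=["c_attn", "c_proj"],      # attention projections; module names vary by model
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the small adapter matrices require gradients
```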
- -To perform LoRA/QLoRA fine-tuning, you can execute the following command: -```bash -cd mft_peft_hf/src/pefts - -accelerate launch --config_file accelerate_ds_config.yaml mft_accelerate.py --train_config configs/starcoder_train_config.json -``` -The main parameter explanations for the ```configs/*_train_config``` are as follows. You can modify these parameters according to your needs: - -- **load_raw_dataset**: Need to be true at present. Only JSONL format is supported. - -- **data_paths**: Input data paths in a String of list format, e.g., "[path1,path2,path3]". Each path represents a task directory and each task directory contains one or more JSONL data files. - -- **output_dir**: Training output directory to store checkpoints, Lora adapter, etc. - -- **tb_dir**: TensorBoard directory to store logs, metrics, etc. - -- **model_type**: Type of the model to train, e.g., "llama | starcoder | chatglm2 | qwen | gpt_neox". - -- **peft_type**: either "lora" or "qlora". - -- **lora_rank**: Rank value for Lora. - -- **lora_alpha**: Alpha value for Lora. - -- **lora_dropout**: Dropout rate for Lora. - -- **quantization**: Whether to use quantization."4bit" or "8bit", or null. For QLoRA, it is recommended to use 4-bit quantization. - -- **pretrained_model_path**: Local/Shared disk path or model name on HuggingFace for the pre-trained model. - -- **weighted_loss_mode**: Loss weighting method for multitask training. "case3" is recommended at present. - -- **padding_mode**: The way tokenized data is set. "padding" means padding for each sample to seq_length, "pack" means putting samples into seq_length as many as possible. - -- **num_train_epochs**: Number of training epochs. - -- **per_device_train_batch_size**: Batch size per GPU for training. - -- **per_device_eval_batch_size**: Batch size per GPU for evaluation. - -- **gradient_accumulation_steps**: Number of gradient accumulation steps. Global batch size is calculated as num_gpus * per_device_train_batch_size * gradient_accumulation_steps. - -- **learning_rate**: Initial Learning rate. For full-parameter fine-tuning, it is recommended to use a smaller value such as 1e-5 or 5e-6. For QLoRA, a larger learning rate is generally used, such as 1e-4 or 2e-4. - -- **min_lr**: Minimum learning rate. Usually set to one-tenth of the learning rate. - -- **seq_length**: Maximum input sequence length during training. - -- **log_interval**: Log training loss every ```log_interval``` steps. - -- **checkpointing_steps**: Save a checkpoint every ```checkpointing_steps``` steps. - -- **evaluation_steps**: Evaluate on the validation set every ```evaluation_steps``` steps. - -- **early_stopping**: Enable early stopping or not. - -- **early_stopping_stall_num**: Number of evaluation points without improvement which triggers early stopping. - -- **lr_scheduler_type**: Type of learning rate scheduler. "cosine" is a good choice already. - -- **num_warmup_steps**: Number of warm-up steps to gradually increase the learning rate. - -- **seed**: Random seed for reproducibility. - - -## 4. Model Usage - -### 4.1 Merge Adaptor weights -Using LoRA or QLoRA for training, this project only saves the weights and configuration files of the adapters. 
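Conceptually, merging a trained LoRA/QLoRA adapter back into the base model means loading the adapter on top of the base weights with PEFT and folding its low-rank deltas in. The sketch below is only an illustration under that assumption, with placeholder paths; it is not the project's merge script, which is referenced next.

```python
# Illustrative adapter-merge sketch; paths are placeholders.
# See the project's merge_base_and_lora_to_hf.py for the actual implementation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_path = "/path/to/base_model"       # placeholder
lora_adapter_path = "/path/to/lora_adapter"   # placeholder (training output_dir)
merged_output_path = "/path/to/merged_model"  # placeholder

base = AutoModelForCausalLM.from_pretrained(
    base_model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)
model = PeftModel.from_pretrained(base, lora_adapter_path)
model = model.merge_and_unload()  # fold the LoRA deltas into the base weights

tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
model.save_pretrained(merged_output_path)      # plain HuggingFace checkpoint, no adapter needed
tokenizer.save_pretrained(merged_output_path)  # keep tokenizer alongside the merged model
```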
-To merge the adapter weights with the base model, see ```src/pefts/merge_base_and_lora_to_hf.py``` - -### 4.2 Inference demo -Here is the script for inference on our trained models, which is compatible with most HuggingFace models: -```python -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, -) -tokenizer = AutoTokenizer.from_pretrained(mode_name_or_path, trust_remote_code=True, use_fast=False, legacy=False) -tokenizer.padding_side = "left" -tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("") -tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("") -model = AutoModelForCausalLM.from_pretrained(mode_name_or_path, trust_remote_code=True) - -HUMAN_ROLE_START_TAG = "<|role_start|>human<|role_end|>" -BOT_ROLE_START_TAG = "<|role_start|>bot<|role_end|>" -texts = ["write a python function of quick sort."] -texts = [f"{HUMAN_ROLE_START_TAG}{text}{BOT_ROLE_START_TAG}" for text in texts] - -inputs = tokenizer(texts, return_tensors='pt', padding=True, add_special_tokens=False).to("cuda") -outputs = model.generate( - inputs=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - max_new_tokens=512, - top_p=0.95, - temperature=0.1, - do_sample=True, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id - ) -gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True) -print(gen_text) -``` - - -Indeed, the parameters top_p, temperature, repetition_penalty, do_sample, etc., have a significant impact on the model's generation output. -You can modify these parameters based on your specific use case. - -In code generation scenarios, if you are using the sampling mode (do_sample=True), the following parameter settings can yield good results for the Pass@1 metric: - -top_p: Set a higher value, such as 0.95, to retain highly probable generated words. This helps ensure more accurate and fluent generation results. - -temperature: Set a lower value, such as 0.1, to reduce randomness. Lower temperature values make the generation output more deterministic. - -These parameter combinations can control the diversity of the generated outputs while maintaining naturalness. Additionally, you can adjust other related parameters, such as repetition_penalty, to reduce repetition in the generated results. - -If you choose the non-sampling mode (do_sample=False), you can consider the following parameter settings: - -beam_num: Set a smaller value such as 1 or 3. ```beam_num=1``` represents greedy decoding, which selects the most probable single generated word. ```beam_num=3``` represents beam search mode, which considers multiple potential generation paths and chooses the best path among them. - -## 5. FAQ -#### Q1:What should I do when cuda OOM happens? -If OOM happened,you can reduce parameters such as per_device_train_batch_size and seq_length. Since you are dealing with large models (6B, 13B, 34B, 70B, etc.), you are already using gradient checkpointing technology by default, which significantly reduces GPU memory consumption. -However, this may slightly slow down the training speed. - -#### Q2:install packages -Please refer to init_env.sh and requirements.txt - - -#### Q3:How should I specify the GPUs for training? 
-You can specify the visiable GPUs as below: -```bash -CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file accelerate_ds_config.yaml mft_accelerate.py --train_config configs/starcoder_train_config.json -``` diff --git a/mft_peft_hf/README_cn.md b/mft_peft_hf/README_cn.md deleted file mode 100644 index 74960d2..0000000 --- a/mft_peft_hf/README_cn.md +++ /dev/null @@ -1,195 +0,0 @@ -# MFTCoder训练: Huggingface accelerate + DeepSpeed框架篇 -[![Generic badge](https://img.shields.io/badge/🤗-Huggingface%20Repo-green.svg)](https://huggingface.co/codefuse-ai) - - GitHub - - -[**中文**] [[English]](README.md) - -## 1. 更新 - -🔥 MFTCoder在Huggingface accelerate + DeepSpeed框架下支持QLoRA/LoRA微调; - -🔥 MFTCoder在训练中支持了多任务微调, 可以同时平衡多个任务的训练,训练的模型支持多任务推理; - -🔥 MFTCoder在训练中支持多种模型基座: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen等 - -## 2. 数据格式 -### 2.1 训练数据格式 -训练数据为jsonl格式,每一行的数据格式如下,其中chat_rounds字段是必需的,可以根据实际需求添加或删除其他字段。 -可以参考项目中的xxx.jsonl文件。 -```json -{ - "id":0, - "data_name":"code-helper", - "chat_rounds":[ - { - "role": "system", - "content": "你是一个智能代码助手,可以回复用户与代码相关的问题", - "chat_round_id": 0 - }, - { - "role": "human", - "content": "写一个快速排序", - "chat_round_id": 1 - }, - { - "role": "bot", - "content": "以下是一个快速排序算法xxxxxx", - "chat_round_id": 1 - }, - { - "role": "human", - "content": "解释一下这段代码", - "chat_round_id": 2 - }, - { - "role": "bot", - "content": "好的,这段代码xxx", - "chat_round_id": 2 - } - ] -} -``` - -### 2.2 推理数据格式 -推理数据格式为模型在训练数据格式下拼接的字符串形式,它也是推理时输入prompt拼接的方式: -```python -""" -<|role_start|>system<|role_end|>这是System指令 -<|role_start|>human<|role_end|>这是第1轮用户输入的问题 -<|role_start|>bot<|role_end|>这是第1轮模型生成的内容 -<|role_start|>human<|role_end|>这是第2轮用户输入的问题 -<|role_start|>bot<|role_end|>这是第2轮模型生成的内容 -... -... -... -<|role_start|>human<|role_end|>这是第n轮用户输入的问题 -<|role_start|>bot<|role_end|>{模型现在要生成的内容} -""" -``` - - -## 3. 模型训练 -目前支持全量参数指令微调、QLoRA指令微调,LoRA指令微调。 -一些优秀的代码预训练模型权重,理论上,HuggingFace上开源的模型,均可使用本项目进行训练: - -🤗 [最新代码预训练SOTA,CodeLlama](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) :code-llama-34b, code-llama-34b-python, 新的SOTA基座。 - -🤗 [10B级别最佳代码预训练模型Starcoder](https://huggingface.co/bigcode/starcoder) wizardCoder-15B, PanGu-coder2等前SOTA的基座模型。 - -🤗 [多语言能手Qwen-7b](https://huggingface.co/Qwen/Qwen-7B) :适用于多语言任务,也适用中文任务。进行指令微调时。 - -我们将训练中使用的各种组件抽取出来,以便后续的扩展和优化,详见src目录下的实现。微调训练的入口目录是```src/pefts```, 训练入口文件是```src/pefts/mft_accelerate.py```, 参数配置存储在```src/pefts/configs```目录下,方便统一管理和更改。 - -### 3.1 数据tokenization -训练时,我们将多轮对话拼接成如下格式(也是上文中的推理string格式),然后进行tokenize。其中<|role_start|>human<|role_end|>表示human输入提示符,<|role_start|>bot<|role_end|>表示bot输出提示符,`````````` 表示eos_token。 -其中eos_token可以根据不同模型修改替换。 -``` -"<|role_start|>human<|role_end|>input1target1input2target2... 
-``` -在计算loss时,我们通过loss mask的方式,input部分的loss不参与参数更新,只有“target”部分的loss参与参数更新。 -这种方式充分利用了模型并行计算的优势,训练更加高效,同时也充分利用了decoder-only模型从左到右attention的特性,一次性将多轮对话中的每个target部分都参与了训练,训练更充分高效。 - -### 3.2 LoRA/QLoRA微调 -关于LoRA的详细介绍可参考论文:[LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2106.09685.pdf) -关于QLoRA的详细介绍可参考论文:[QLORA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/pdf/2305.14314.pdf) - -QLoRA通过4-bit的nf4量化,且加入更多adapter,在大幅减少显存消耗的同时,尽可能逼近全量参数微调的效果。 -QLoRA论文指出,该方法可以在一张V100上对33B的模型进行微调,并且性能逼近全量参数微调。 - -执行如下命令即可进行Lora/QLora微调: -```bash -accelerate launch --config_file accelerate_ds_config.yaml mft_accelerate.py --train_config configs/starcoder_train_config.json -``` - -```configs/*_train_config```中的主要参数说明如下,以下参数可以根据需求进行修改,其他参数建议不做修改: -- load_raw_dataset : 需要保持true,后续会支持其它模式数据,当前仅支持jsonl输入 -- data_paths: "[path1,path2,path3]" 输入数据地址,字符串,开头结尾用[],中间用```,```间隔不同path,每个path是一个目录,目录的最后一级名字作为任务名称,下面包含1到多个jsonl数据 -- output_dir:训练输出目录,存储checkpoint、lora_adaptor等 -- tb_dir: 存储tensorboard等 -- model_type: "llama|starcoder|chatglm2|qwen|gpt_nex" -- peft_type: lora或者qlora -- lora_rank: lora rank -- lora_alpha: lora alpha -- lora_dropout: lora dropout -- quantization: 是否量化,"4bit", "8bit" 或者null, qlora推荐4bit量化 -- pretrained_model_path:预训练模型的本地目录,或者在huggingface上的模型名称。 -- **weighted_loss_mode**: 多任务loss加权模式, "case3"是当前推荐。 -- **padding_mode**: 数据的样本组织方式, "padding"是将每个原始样本填充到seq_length, "pack"是将尽量多的样本打包到每个seq_length的序列中。 -- num_train_epochs:训练的轮次。如果数据量足够大,一般建议只训1-2个epoch。 -- per_device_train_batch_size:每张显卡train的batch size。 -- per_device_eval_batch_size:每张显卡eval的batch size。 -- gradient_accumulation_steps:梯度累计步数。global batch=num_gpus * per_device_train_batch_size * gradient_accumulation_steps。 -- learning_rate:学习率。全量参数微调的时候,建议小一些,1e-5或5e-6。qlora中的学习率设置更大一些,一般为1e-4、2e-4。 -- min_lr: 最低学习率, 一般是learning_rate的十分之一 -- seq_length:训练时的最大长度。按照自己的设备进行设置,越长需要占用越多显存。 -- log_interval:每隔多少步统计一次train loss。 -- checkpointing_steps:每隔多少步保存一个模型。 -- evalation_steps:每隔多少步在验证集上evaluate一次。 -- early_stopping : 是否执行early_stop -- early_stopping_stall_num: 多少个eval point不继续收敛,则停止训练 -- lr_scheduler_type:学习率变化策略。常用"cosine" -- warmup_steps:warm up步数。学习率经过多少步,增长到指定的数值。 -- seed:随机种子,用于复现实验结果。 - -## 4. 
模型使用 - -### 4.1 权重合并 -如果使用LoRA或者QLoRA进行训练,本项目仅保存adapter的权重和配置文件,需要将adapter权重与base model进行合并。脚本见```src/pefts/merge_base_and_lora_to_hf.py``` - -### 4.2 模型推理 -我们提供了单轮对话和多轮对话的如下脚本,该脚本可同时兼容大部分huggingface格式的模型。 -```python -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, -) -tokenizer = AutoTokenizer.from_pretrained(mode_name_or_path, trust_remote_code=True, use_fast=False, legacy=False) -tokenizer.padding_side = "left" -tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("") -tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("") -model = AutoModelForCausalLM.from_pretrained(mode_name_or_path, trust_remote_code=True) - -HUMAN_ROLE_START_TAG = "<|role_start|>human<|role_end|>" -BOT_ROLE_START_TAG = "<|role_start|>bot<|role_end|>" -texts = ["write a python function of quick sort."] -texts = [f"{HUMAN_ROLE_START_TAG}{text}{BOT_ROLE_START_TAG}" for text in texts] - -inputs = tokenizer(texts, return_tensors='pt', padding=True, add_special_tokens=False).to("cuda") -outputs = model.generate( - inputs=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - max_new_tokens=512, - top_p=0.95, - temperature=0.1, - do_sample=True, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id - ) -gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True) -print(gen_text) -``` - - -生成脚本中的top_p、temperature、repetition_penalty、do_sample等参数对模型的生成效果影响较大,可按照自己的使用场景进行调试修改。 -实践中,在代码生成场景中,如果采样模式,do_sample=True, top_p=0.95, temperature=0.1是pass@1指标的不错选择; -如果非采样模式, do_sample=False, beam_num=1或者3是不错的选择,其中beam_num=1即为greedy decoding。 - -## 5. FAQ -#### 问题1:OOM如何解决? -如果发生OOM,可以缩小per_device_train_batch_size、seq_length等参数来缓解。由于面对的模型普遍较大(6b, 13b, 34b, 70b等)我们已经默认使用gradient_checkpointing技术,可以大幅降低显存占用,但训练速度会稍慢一些。 - -#### 问题2:安装包错误 -参考init_env.sh和requirements.txt - -#### 问题3:如何指定使用某些卡训练? -通过如下方式,即可指定使用0和1号卡进行训练: -```bash -CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file accelerate_ds_config.yaml mft_accelerate.py --train_config configs/starcoder_train_config.json -``` - - - - - diff --git a/mft_peft_hf/src/data/tokenization/preprocess_data.py b/mft_peft_hf/src/data/tokenization/preprocess_data.py deleted file mode 100644 index 841bcb6..0000000 --- a/mft_peft_hf/src/data/tokenization/preprocess_data.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -# @author Chaoyu Chen -# @date 2023/6/19 -Preprocessing data and tokenization. 
-""" - -import argparse -import multiprocessing -import os -import sys -import numpy as np -import random - -# add src root path -current_path = os.path.abspath(__file__) -parent_dir = os.path.dirname(os.path.dirname(current_path)) -grandparent_dir = os.path.dirname(parent_dir) -sys.path.append(grandparent_dir) -# print(grandparent_dir) - -import data.tokenization.lm_dataformat as lmd - -import time -import tqdm -import torch -import ftfy -import glob - -from tokenizer import build_tokenizer -from threading import Semaphore - -table = {ord(f): ord(t) for f, t in zip( - u',。!?:【】()%#@&1234567890', - u',.!?:[]()%#@&1234567890')} - - -def punctuation_format(text: str): - # Replace non-breaking space with space - # text = text.strip() + '\n' - text = text.replace('\u202f', ' ').replace('\xa0', ' ') - # change chinese punctuation to english ones - text = text.translate(table) - return text - - -def is_prompt_answer_format(data): - if "prompt" in data and "answer" in data: - return True - else: - return False - - -def is_chatml_format(data): - if "chat_rounds" in data and len(data["chat_rounds"]) > 0: - return True - else: - return False - - -def is_text_format(data): - if "text" in data: - return True - else: - return False - - -class Encoder(object): - def __init__(self, args): - self.args = args - - def initializer(self): - # Use Encoder class as a container for global data - # Encoder.tokenizer = build_tokenizer(self.args) - self.tokenizer = build_tokenizer(self.args) - - def encode(self, text): - if self.args.ftfy: - text = ftfy.fix_text(text) - ids = {} - for key in self.args.jsonl_keys: - doc_ids = [] - text_ids = self.tokenizer.encode(text, add_special_tokens=False) - if len(text_ids) > 0: - doc_ids.append(text_ids) - if self.args.append_eod: - doc_ids[-1].append(self.tokenizer.eod_id) - ids[key] = doc_ids - return ids, len(text) - - -class UniformEncoder(Encoder): - def __init__(self, args, mode='sft'): - super().__init__(args) - self.mode = mode - # seq_length + 1 for shifting - if args.load_raw_dataset: - self.seq_length = args.seq_length + 1 - self.stride = args.seq_length - else: - self.seq_length = args.seq_length - - self.remain_input_ids = [] - self.remain_loss_mask = [] - - def encode(self, data): - - encode_res = { - "input_ids": [], - "loss_mask": [] - } - - if is_prompt_answer_format(data): - data_type = 'prompt_answer' - elif is_chatml_format(data): - data_type = 'chatML' - elif is_text_format(data): - data_type = 'text' - else: - raise ValueError("data format not supported, please use prompt/answer, or chatML or pretrain text") - - for token_res in self._tokenize_fields(data, data_type=data_type): - for k, v in token_res.items(): - encode_res[k].append(v) - - length = 0 - if data_type == 'prompt_answer': - length = len(data['prompt']) + len(data['answer']) - elif data_type == 'chatML': - for chat in data['chat_rounds']: - length += len(chat['content']) - elif data_type == 'text': - length += len(data['text']) - - return encode_res, length - - def _tokenize_fields(self, data, data_type): - - CHAT_COL = 'chat_rounds' - ROLE_COL = 'role' - CONTENT_COL = 'content' - - PROMPT_COL = 'prompt' - ANSWER_COL = 'answer' - SYSTEM_COL = 'system' - - TEXT_COL = 'text' - - if self.mode == 'sft': - HUMAN = 'human' - BOT = 'bot' - SYSTEM = 'system' - ROLE_START_MARKER = '<|role_start|>' - ROLE_END_MARKER = '<|role_end|>' - elif self.mode == 'pretrain' or data_type == 'text': - HUMAN = '' - BOT = '' - SYSTEM = '' - ROLE_START_MARKER = '' - ROLE_END_MARKER = '' - else: - raise 
ValueError(f"tokenize_mode does not support {self.mode}, please use sft or pretrain") - - human_marker_ids = self.tokenizer.encode(f"{ROLE_START_MARKER}{HUMAN}{ROLE_END_MARKER}", add_special_tokens=False) - bot_marker_ids = self.tokenizer.encode(f"{ROLE_START_MARKER}{BOT}{ROLE_END_MARKER}", add_special_tokens=False) - system_marker_ids = self.tokenizer.encode(f"{ROLE_START_MARKER}{SYSTEM}{ROLE_END_MARKER}", add_special_tokens=False) - sft_end_marker_ids = [self.tokenizer.eod_id] - - # uniform SST,SFT,MFT - - input_ids = [] - loss_mask = [] - - if data_type == "prompt_answer": - system = data.get(SYSTEM_COL, '') - prompt = data[PROMPT_COL] - answer = data[ANSWER_COL] - system = punctuation_format(system) - prompt = punctuation_format(prompt) - answer = punctuation_format(answer) - system_ids = system_marker_ids + self.tokenizer.encode(system, add_special_tokens=False) if system else [] - prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) - answer_ids = self.tokenizer.encode(answer, add_special_tokens=False) + sft_end_marker_ids - input_ids += system_ids + human_marker_ids + prompt_ids + bot_marker_ids + answer_ids - loss_mask += [0] * len(system_ids) + [0] * len(human_marker_ids) + [0] * len(prompt_ids) + \ - [0] * len(bot_marker_ids) + [1] * len(answer_ids) - elif data_type == 'chatML': - chat = data[CHAT_COL] - for r in chat: - role = r[ROLE_COL] - content = r[CONTENT_COL] - content = punctuation_format(content) - if role == HUMAN: - role_marker_ids = human_marker_ids - content_ids = self.tokenizer.encode(content, add_special_tokens=False) - elif role == BOT: - # compute loss for eos token after bot's content - role_marker_ids = bot_marker_ids - content_ids = self.tokenizer.encode(content, add_special_tokens=False) + sft_end_marker_ids - elif role == SYSTEM: - role_marker_ids = system_marker_ids - content_ids = self.tokenizer.encode(content, add_special_tokens=False) - else: - raise ValueError(f"Role {role} not supported.") - - input_ids += role_marker_ids + content_ids - masklet = [1] if role == BOT else [0] - loss_mask += [0] * len(role_marker_ids) + masklet * len(content_ids) - elif data_type == "text": - text = data[TEXT_COL] - text = punctuation_format(text) - text_ids = self.tokenizer.encode(text, add_special_tokens=False) + sft_end_marker_ids - input_ids += text_ids - loss_mask += [1] * len(text_ids) - else: - raise ValueError( - f"data_type does not support {self.args.data_type}, please use chatML or prompt_answer or text(for pretrain)") - - # print(self.mode) - if self.mode == 'pretrain': - # change loss mask to all 1s - input_ids = input_ids - loss_mask = [1] * len(loss_mask) - elif self.mode == 'sft': - # do nothing - input_ids = input_ids - loss_mask = loss_mask - - assert len(input_ids) == len(loss_mask) - if self.args.padding_mode == 'padding': - if len(input_ids) <= self.seq_length: - yield self.padding(input_ids, loss_mask) - - # drop if too long - else: - yield {} - elif self.args.padding_mode == 'concat': - input_ids = self.remain_input_ids + input_ids - loss_mask = self.remain_loss_mask + loss_mask - if len(input_ids) < self.seq_length: - self.remain_input_ids = input_ids - self.remain_loss_mask = loss_mask - assert len(self.remain_input_ids) == len(self.remain_loss_mask) - yield {} - else: - cursor = 0 - while cursor + self.seq_length <= len(input_ids): - yield { - "input_ids": input_ids[cursor: cursor + self.seq_length], - "loss_mask": loss_mask[cursor: cursor + self.seq_length] - } - cursor = cursor + self.stride - self.remain_input_ids = 
input_ids[cursor:] - self.remain_loss_mask = loss_mask[cursor:] - assert len(self.remain_input_ids) == len(self.remain_loss_mask) - yield {} - elif self.args.padding_mode == 'pack': - if len(input_ids) > self.seq_length: - yield {} - elif len(self.remain_input_ids) + len(input_ids) > self.seq_length: - input_ids, self.remain_input_ids = self.remain_input_ids, input_ids - loss_mask, self.remain_loss_mask = self.remain_loss_mask, loss_mask - assert len(input_ids) == len(loss_mask) - yield self.padding(input_ids, loss_mask) - else: - self.remain_input_ids = self.remain_input_ids + input_ids - self.remain_loss_mask = self.remain_loss_mask + loss_mask - assert len(self.remain_input_ids) == len(self.remain_loss_mask) - yield {} - - def padding(self, input_ids, loss_mask): - pad_id = self.tokenizer.pad_id - assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)} > {self.seq_length}" - input_ids += [pad_id] * (self.seq_length - len(input_ids)) - loss_mask += [0] * (self.seq_length - len(loss_mask)) - return { - "input_ids": input_ids, - "loss_mask": loss_mask - } - - -def find_jsonl_fnames(inputs): - fnames = [] - for p in inputs.split(","): - if not os.path.isdir(p): - if p.endswith(".jsonl"): - print(f"loading from {p}") - fnames.append(p) - else: - p_list = glob.glob(p + "/*") - for p_ in p_list: - if p_.endswith(".jsonl"): - print(f"loading from {p_}") - fnames.append(p_) - return fnames - - -def yield_from_files(fnames: list, semaphore): - """ - Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts / - other compressed formats. Also filters out empty documents. - - :param fnames: list of filenames - """ - - def yielder(fname, semaphore): - for f in filter(lambda x: x, lmd.Reader(fname).stream_data( - key=['task', 'src_language', 'src_code', 'tgt_language', 'tgt_code', 'sql', 'prompt', 'answer', - 'bad_answer'])): - semaphore.acquire() - yield f - - for fname in fnames: - semaphore.acquire() - - yield from yielder(fname, semaphore) diff --git a/mft_peft_hf/src/model/baichuan/config.json b/mft_peft_hf/src/model/baichuan/config.json deleted file mode 100644 index 3ed1760..0000000 --- a/mft_peft_hf/src/model/baichuan/config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "_from_model_config": true, - "architectures": [ - "BaichuanForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_baichuan.BaichuanConfig", - "AutoModelForCausalLM": "modeling_baichuan.BaichuanForCausalLM" - }, - "bos_token_id": 1, - "eos_token_id": 2, - "gradient_checkpointing": false, - "hidden_act": "silu", - "hidden_size": 5120, - "initializer_range": 0.02, - "intermediate_size": 13696, - "model_max_length": 4096, - "model_type": "baichuan", - "num_attention_heads": 40, - "num_hidden_layers": 40, - "pad_token_id": 0, - "rms_norm_eps": 1e-06, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.29.2", - "use_cache": true, - "vocab_size": 64000 -} diff --git a/mft_peft_hf/src/model/baichuan/modeling_baichuan.py b/mft_peft_hf/src/model/baichuan/modeling_baichuan.py deleted file mode 100644 index abb34d0..0000000 --- a/mft_peft_hf/src/model/baichuan/modeling_baichuan.py +++ /dev/null @@ -1,554 +0,0 @@ -# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. 
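For context, the deleted preprocessing script above converts a `{"prompt": ..., "answer": ...}` record into `input_ids` plus a parallel `loss_mask`: role markers and the prompt get 0, the answer and the trailing EOD token get 1, so only answer tokens contribute to the SFT loss. Below is a minimal sketch of that layout; `ToyTokenizer` is a hypothetical whitespace stand-in for the tokenizer returned by `build_tokenizer`, not part of MFTCoder.

```python
# Minimal sketch of the prompt/answer -> (input_ids, loss_mask) layout built above.


class ToyTokenizer:
    """Whitespace 'tokenizer' used only to illustrate the masking logic."""

    def __init__(self):
        self.vocab = {}
        self.eod_id = 0  # stand-in end-of-document id

    def encode(self, text, add_special_tokens=False):
        return [self.vocab.setdefault(tok, len(self.vocab) + 1) for tok in text.split()]


def encode_prompt_answer(tokenizer, prompt, answer):
    human_marker = tokenizer.encode("<|role_start|>human<|role_end|>")
    bot_marker = tokenizer.encode("<|role_start|>bot<|role_end|>")
    prompt_ids = tokenizer.encode(prompt)
    answer_ids = tokenizer.encode(answer) + [tokenizer.eod_id]  # loss also covers the EOD token

    input_ids = human_marker + prompt_ids + bot_marker + answer_ids
    # 0 = ignored by the loss (markers + prompt), 1 = trained on (answer + EOD)
    loss_mask = [0] * (len(human_marker) + len(prompt_ids) + len(bot_marker)) \
                + [1] * len(answer_ids)
    assert len(input_ids) == len(loss_mask)
    return input_ids, loss_mask


if __name__ == "__main__":
    tok = ToyTokenizer()
    ids, mask = encode_prompt_answer(tok, "write hello world in python", "print('hello world')")
    print(list(zip(ids, mask)))
```

In `padding` mode the resulting pair is right-padded to `seq_length` (or dropped if longer); `concat` strides across the concatenated sample stream, and `pack` fills each window with whole samples without splitting them, as in `_tokenize_fields` above.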
- -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss -from transformers import PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.utils import logging -from transformers.generation.utils import GenerationConfig - -from .configuration_baichuan import BaichuanConfig - -logger = logging.get_logger(__name__) - -def _get_interleave(n): - def _get_interleave_power_of_2(n): - start = (2 ** (-2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio ** i for i in range(n)] - - if math.log2(n).is_integer(): - return _get_interleave_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return _get_interleave_power_of_2(closest_power_of_2) + \ - _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2] - -def _fill_with_neg_inf(t): - """FP16-compatible function that fills a tensor with -inf.""" - return t.float().fill_(float("-inf")).type_as(t) - -def _gen_alibi_mask(n_head, max_pos): - slopes = torch.Tensor(_get_interleave(n_head)) - alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand( - n_head, -1, -1) - alibi = alibi.view(n_head, 1, max_pos) - alibi_mask = torch.triu( - _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1 - ) - alibi_mask = alibi_mask.unsqueeze(0) + alibi - return alibi_mask - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, epsilon=1e-6): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(hidden_size)) - self.epsilon = epsilon - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) - - # convert into half-precision - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class MLP(torch.nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False) - self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class BaichuanAttention(torch.nn.Module): - - def __init__(self, config: BaichuanConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.model_max_length - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}" - ) - self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) - self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - 
past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - bsz, q_len, _ = hidden_states.size() - - proj = self.W_pack(hidden_states) - proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) - query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - if attn_weights.size(-2) == 1: - attention_mask = attention_mask[:, -1:, :] - attn_weights = attn_weights + attention_mask.unsqueeze(0) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class BaichuanLayer(torch.nn.Module): - def __init__(self, config: BaichuanConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = BaichuanAttention(config=config) - self.mlp = MLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class BaichuanPreTrainedModel(PreTrainedModel): - config_class = BaichuanConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BaichuanLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - 
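The deleted attention above relies on ALiBi rather than rotary embeddings: each head adds a slope-scaled token-distance bias plus a causal `-inf` triangle to the raw `QK^T / sqrt(d)` scores before the softmax. The following self-contained sketch covers the power-of-two head count case only; the interpolation fallback in `_get_interleave` for other head counts is omitted.

```python
import math
import torch


def alibi_slopes_power_of_2(n_head):
    # Same geometric slope schedule as _get_interleave_power_of_2 above.
    start = 2 ** (-2 ** -(math.log2(n_head) - 3))
    return torch.tensor([start ** (i + 1) for i in range(n_head)])


def alibi_bias(n_head, max_pos):
    slopes = alibi_slopes_power_of_2(n_head)              # (n_head,)
    distance = torch.arange(max_pos).view(1, 1, max_pos)  # (1, 1, T)
    alibi = slopes.view(n_head, 1, 1) * distance          # (n_head, 1, T)
    causal = torch.triu(torch.full((max_pos, max_pos), float("-inf")), diagonal=1)
    return causal.unsqueeze(0) + alibi                    # (n_head, T, T)


def attention_with_alibi(q, k, v, bias):
    # q, k, v: (batch, n_head, T, head_dim); bias: (n_head, T, T)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    scores = scores + bias.unsqueeze(0)                   # broadcast over the batch dim
    scores = torch.clamp(scores, min=torch.finfo(scores.dtype).min)
    return torch.softmax(scores, dim=-1) @ v


if __name__ == "__main__":
    B, H, T, D = 2, 8, 16, 32
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    print(attention_with_alibi(q, k, v, alibi_bias(H, T)).shape)  # torch.Size([2, 8, 16, 32])
```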
def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, torch.nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, torch.nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BaichuanModel): - module.gradient_checkpointing = value - - - -class BaichuanModel(BaichuanPreTrainedModel): - def __init__(self, config: BaichuanConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.n_head = config.num_attention_heads - self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) - - self.gradient_checkpointing = config.gradient_checkpointing - self.post_init() - self.max_cache_pos = config.model_max_length - self.first_run = True - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def get_alibi_mask(self, tensor, seq_length_with_past): - if self.first_run: - self.first_run = False - self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False) - if seq_length_with_past > self.max_cache_pos: - self.max_cache_pos = seq_length_with_past - self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False) - mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past] - return mask - - def forward( - self, - input_ids: torch.LongTensor = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = False, - return_dict: Optional[bool] = True, - ) -> Union[Tuple, BaseModelOutputWithPast]: - - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You need to provide input_ids or inputs_embeds") - - seq_length_with_past = seq_length - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # embed positions - attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class BaichuanForCausalLM(BaichuanPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = BaichuanModel(config) - self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = False, - return_dict: Optional[bool] = True, - **kwargs - ) -> Union[Tuple, CausalLMOutputWithPast]: - - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = 
shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - return tuple( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) - for layer_past in past_key_values - ) - - - def quantize(self, bits: int): - try: - from .quantizer import QLinear - except ImportError: - raise ImportError( - f"Needs QLinear to run quantize." - ) - - for layer in self.model.layers: - layer.self_attn.W_pack = QLinear( - bits=bits, - weight=layer.self_attn.W_pack.weight, - bias = None, - ) - layer.self_attn.o_proj = QLinear( - bits=bits, - weight=layer.self_attn.o_proj.weight, - bias = None, - ) - layer.mlp.gate_proj = QLinear( - bits=bits, - weight=layer.mlp.gate_proj.weight, - bias = None, - ) - layer.mlp.down_proj = QLinear( - bits=bits, - weight=layer.mlp.down_proj.weight, - bias = None, - ) - layer.mlp.up_proj = QLinear( - bits=bits, - weight=layer.mlp.up_proj.weight, - bias = None, - ) - return self - - def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0): - max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens - max_input_tokens = self.config.model_max_length - max_new_tokens - max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens) - total_input, round_input = [], [] - for i, message in enumerate(messages[::-1]): - content_tokens = tokenizer.encode(message['content']) - if message['role'] == 'user': - round_input = [self.generation_config.user_token_id] + content_tokens + round_input - if total_input and len(total_input) + len(round_input) > max_input_tokens: - break - else: - total_input = round_input + total_input - if len(total_input) >= max_input_tokens: - break - else: - round_input = [] - elif message['role'] == 'assistant': - round_input = [ - self.generation_config.assistant_token_id - ] + content_tokens + [ - self.generation_config.eos_token_id - ] + round_input - else: - raise ValueError(f"message role not supported yet: {message['role']}") - total_input = total_input[-max_input_tokens:] # truncate left - total_input.append(self.generation_config.assistant_token_id) - total_input = torch.LongTensor([total_input]).to(self.device) - return total_input - - @torch.no_grad() - def chat(self, tokenizer, messages: List[dict], stream=False, - generation_config: Optional[GenerationConfig]=None): - generation_config = generation_config or self.generation_config - input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens) - if stream: - from 
transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig - self.__class__.generate = NewGenerationMixin.generate - self.__class__.sample_stream = NewGenerationMixin.sample_stream - stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) - - def stream_generator(): - outputs = [] - for token in self.generate(input_ids, generation_config=stream_config): - outputs.append(token.item()) - yield tokenizer.decode(outputs, skip_special_tokens=True) - - return stream_generator() - else: - self.__class__.generate = PreTrainedModel.generate # disable stream - outputs = self.generate(input_ids, generation_config=generation_config) - response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True) - return response diff --git a/mft_peft_hf/src/model/baichuan/modeling_baichuan_flash.py b/mft_peft_hf/src/model/baichuan/modeling_baichuan_flash.py deleted file mode 100644 index 792999c..0000000 --- a/mft_peft_hf/src/model/baichuan/modeling_baichuan_flash.py +++ /dev/null @@ -1,605 +0,0 @@ -# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. - -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss -from transformers import PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.utils import logging -from transformers.generation.utils import GenerationConfig - -from .configuration_baichuan import BaichuanConfig -import xformers.ops - -logger = logging.get_logger(__name__) - -def _get_interleave(n): - def _get_interleave_power_of_2(n): - start = (2 ** (-2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio ** i for i in range(n)] - - if math.log2(n).is_integer(): - return _get_interleave_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return _get_interleave_power_of_2(closest_power_of_2) + \ - _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2] - -def _fill_with_neg_inf(t): - """FP16-compatible function that fills a tensor with -inf.""" - return t.float().fill_(float("-inf")).type_as(t) - -def _gen_alibi_mask(n_head, max_pos): - slopes = torch.Tensor(_get_interleave(n_head)) - alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand( - n_head, -1, -1) - alibi = alibi.view(n_head, 1, max_pos) - alibi_mask = torch.triu( - _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1 - ) - alibi_mask = alibi_mask.unsqueeze(0) + alibi - return alibi_mask - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, epsilon=1e-6): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(hidden_size)) - self.epsilon = epsilon - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) - - # convert into half-precision - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class MLP(torch.nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False) - 
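`BaichuanForCausalLM.forward` above computes the LM loss in the usual causal fashion: the logit at position `t` is matched against the token at `t+1`, so the last logit and the first label are dropped before a flat cross-entropy. A toy illustration follows; marking masked positions with `-100` (the default `ignore_index`) is the common Hugging Face convention for applying a loss mask and is an assumption here, not something shown in this diff.

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 6, 100
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[:, :3] = -100  # e.g. prompt positions excluded from the loss (assumed convention)

# Position t predicts token t+1: drop the last logit, drop the first label.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss = CrossEntropyLoss()(            # ignore_index defaults to -100
    shift_logits.view(-1, vocab),
    shift_labels.view(-1),
)
print(loss)
```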
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class BaichuanAttention(torch.nn.Module): - - def __init__(self, config: BaichuanConfig): - super().__init__() - self.config = config - self.use_xformers = False # config.use_xformers - self.use_triton = False # config.use_triton - self.use_mqa = False # config.use_mqa - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.model_max_length - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}" - ) - if not self.use_mqa: - self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) - else: - self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) - self.k_proj = nn.Linear(self.hidden_size, self.head_dim) - self.v_proj = nn.Linear(self.hidden_size, self.head_dim) - self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - bsz, q_len, _ = hidden_states.size() - - if not self.use_mqa: - proj = self.W_pack(hidden_states) - proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) - query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - else: - # MQA - query_states = q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = k_proj(hidden_states).view(bsz, q_len, 1, self.head_dim).transpose(1, 2) - value_states = v_proj(hidden_states).view(bsz, q_len, 1, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # xformers - if self.use_xformers: - # https://facebookresearch.github.io/xformers/components/ops.html - attn_weights = None - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - print(f"shape of attention mask before: {attention_mask.shape}") - # [query_states.shape[0], query_states.shape[2], query_states.shape[1], key_states.shape[1]] - attention_mask = attention_mask.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz, self.num_heads, q_len, kv_seq_len).contiguous() - print(f"shape of attention mask after: {attention_mask.shape}") - attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, 
value_states, - attn_bias=xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias(attention_mask), - ) - attn_output = attn_output.contiguous().view(bsz, q_len, -1) - elif self.use_triton: - # Import the triton implementation (torch.nn.functional version only) - from flash_attn.flash_attn_triton import flash_attn_func - """ - q: (batch_size, seqlen_q, nheads, headdim) - k, v: (batch_size, seqlen_k, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). - ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) - """ - attention_mask = attention_mask.unsqueeze(0)[:, :, 0:1, :] - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if self.use_mqa: - key_states = key_states.repeat(1, 1, self.num_heads, 1) - value_states = value_states.repeat(1, 1, self.nnum_heads, 1) - attn_output = flash_attn_func(query_states, key_states, value_states, attention_mask, True) - attn_weights = None - else: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - if attn_weights.size(-2) == 1: - attention_mask = attention_mask[:, -1:, :] - attn_weights = attn_weights + attention_mask.unsqueeze(0) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class BaichuanLayer(torch.nn.Module): - def __init__(self, config: BaichuanConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = BaichuanAttention(config=config) - self.mlp = MLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class BaichuanPreTrainedModel(PreTrainedModel): - config_class = BaichuanConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - 
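Both Baichuan implementations wire their sub-layers as pre-norm residual blocks: each sub-layer sees an RMS-normalized input, and its output is added back onto the un-normalized residual stream. Here is a skeleton of `BaichuanLayer`'s control flow, with `LayerNorm` standing in for the file's own `RMSNorm` and the attention/MLP internals elided.

```python
import torch


class PreNormBlock(torch.nn.Module):
    """Residual wiring of BaichuanLayer; sub-layer internals are placeholders."""

    def __init__(self, hidden_size, attn, mlp):
        super().__init__()
        # LayerNorm stands in for the RMSNorm defined in the deleted file.
        self.input_layernorm = torch.nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = torch.nn.LayerNorm(hidden_size)
        self.self_attn = attn
        self.mlp = mlp

    def forward(self, hidden_states):
        # Attention sub-layer: normalize, transform, add back the residual.
        residual = hidden_states
        hidden_states = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = residual + hidden_states

        # MLP sub-layer: same pattern.
        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + hidden_states


if __name__ == "__main__":
    block = PreNormBlock(64, attn=torch.nn.Identity(), mlp=torch.nn.Linear(64, 64))
    print(block(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```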
_no_split_modules = ["BaichuanLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, torch.nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, torch.nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BaichuanModel): - module.gradient_checkpointing = value - - - -class BaichuanModel(BaichuanPreTrainedModel): - def __init__(self, config: BaichuanConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.n_head = config.num_attention_heads - self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) - - self.gradient_checkpointing = config.gradient_checkpointing - self.post_init() - self.max_cache_pos = config.model_max_length - self.first_run = True - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def get_alibi_mask(self, tensor, seq_length_with_past): - if self.first_run: - self.first_run = False - self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False) - if seq_length_with_past > self.max_cache_pos: - self.max_cache_pos = seq_length_with_past - self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False) - mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past] - return mask - - def forward( - self, - input_ids: torch.LongTensor = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = False, - return_dict: Optional[bool] = True, - ) -> Union[Tuple, BaseModelOutputWithPast]: - - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You need to provide input_ids or inputs_embeds") - - seq_length_with_past = seq_length - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # embed positions - attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class BaichuanForCausalLM(BaichuanPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = BaichuanModel(config) - self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = False, - return_dict: Optional[bool] = True, - **kwargs - ) -> Union[Tuple, CausalLMOutputWithPast]: - - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = 
shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - return tuple( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) - for layer_past in past_key_values - ) - - - def quantize(self, bits: int): - try: - from .quantizer import QLinear - except ImportError: - raise ImportError( - f"Needs QLinear to run quantize." - ) - - for layer in self.model.layers: - layer.self_attn.W_pack = QLinear( - bits=bits, - weight=layer.self_attn.W_pack.weight, - bias = None, - ) - layer.self_attn.o_proj = QLinear( - bits=bits, - weight=layer.self_attn.o_proj.weight, - bias = None, - ) - layer.mlp.gate_proj = QLinear( - bits=bits, - weight=layer.mlp.gate_proj.weight, - bias = None, - ) - layer.mlp.down_proj = QLinear( - bits=bits, - weight=layer.mlp.down_proj.weight, - bias = None, - ) - layer.mlp.up_proj = QLinear( - bits=bits, - weight=layer.mlp.up_proj.weight, - bias = None, - ) - return self - - def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0): - max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens - max_input_tokens = self.config.model_max_length - max_new_tokens - max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens) - total_input, round_input = [], [] - for i, message in enumerate(messages[::-1]): - content_tokens = tokenizer.encode(message['content']) - if message['role'] == 'user': - round_input = [self.generation_config.user_token_id] + content_tokens + round_input - if total_input and len(total_input) + len(round_input) > max_input_tokens: - break - else: - total_input = round_input + total_input - if len(total_input) >= max_input_tokens: - break - else: - round_input = [] - elif message['role'] == 'assistant': - round_input = [ - self.generation_config.assistant_token_id - ] + content_tokens + [ - self.generation_config.eos_token_id - ] + round_input - else: - raise ValueError(f"message role not supported yet: {message['role']}") - total_input = total_input[-max_input_tokens:] # truncate left - total_input.append(self.generation_config.assistant_token_id) - total_input = torch.LongTensor([total_input]).to(self.device) - return total_input - - @torch.no_grad() - def chat(self, tokenizer, messages: List[dict], stream=False, - generation_config: Optional[GenerationConfig]=None): - generation_config = generation_config or self.generation_config - input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens) - if stream: - from 
transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig - self.__class__.generate = NewGenerationMixin.generate - self.__class__.sample_stream = NewGenerationMixin.sample_stream - stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) - - def stream_generator(): - outputs = [] - for token in self.generate(input_ids, generation_config=stream_config): - outputs.append(token.item()) - yield tokenizer.decode(outputs, skip_special_tokens=True) - - return stream_generator() - else: - self.__class__.generate = PreTrainedModel.generate # disable stream - outputs = self.generate(input_ids, generation_config=generation_config) - response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True) - return response diff --git a/mft_peft_hf/src/model/llama2/__init__.py b/mft_peft_hf/src/model/llama2/__init__.py deleted file mode 100644 index cfeaa6b..0000000 --- a/mft_peft_hf/src/model/llama2/__init__.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING - -from transformers.utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_sentencepiece_available, - is_tokenizers_available, - is_torch_available, -) - - -_import_structure = { - "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"], -} - -try: - if not is_sentencepiece_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["tokenization_llama"] = ["LlamaTokenizer"] - -try: - if not is_tokenizers_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"] - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_llama"] = [ - "LlamaForCausalLM", - "LlamaModel", - "LlamaPreTrainedModel", - "LlamaForSequenceClassification", - ] - - -if TYPE_CHECKING: - from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig - - try: - if not is_sentencepiece_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .tokenization_llama import LlamaTokenizer - - try: - if not is_tokenizers_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .tokenization_llama_fast import LlamaTokenizerFast - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel - - -else: - import sys - - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff 
--git a/mft_peft_hf/src/model/llama2/convert_llama_weights_to_hf.py b/mft_peft_hf/src/model/llama2/convert_llama_weights_to_hf.py deleted file mode 100644 index 75760f4..0000000 --- a/mft_peft_hf/src/model/llama2/convert_llama_weights_to_hf.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import gc -import json -import os -import shutil -import warnings - -import torch - -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - - -try: - from transformers import LlamaTokenizerFast -except ImportError as e: - warnings.warn(e) - warnings.warn( - "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" - ) - LlamaTokenizerFast = None - -""" -Sample usage: - -``` -python src/transformers/models/llama/convert_llama_weights_to_hf.py \ - --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path -``` - -Thereafter, models can be loaded via: - -```py -from transformers import LlamaForCausalLM, LlamaTokenizer - -model = LlamaForCausalLM.from_pretrained("/output/path") -tokenizer = LlamaTokenizer.from_pretrained("/output/path") -``` - -Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions -come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
-""" - -INTERMEDIATE_SIZE_MAP = { - "7B": 11008, - "13B": 13824, - "30B": 17920, - "65B": 22016, - "70B": 28672, -} -NUM_SHARDS = { - "7B": 1, - "7Bf": 1, - "13B": 2, - "13Bf": 2, - "30B": 4, - "65B": 8, - "70B": 8, - "70Bf": 8, -} - - -def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): - return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) - - -def read_json(path): - with open(path, "r") as f: - return json.load(f) - - -def write_json(text, path): - with open(path, "w") as f: - json.dump(text, f) - - -def write_model(model_path, input_base_path, model_size, safe_serialization=True): - os.makedirs(model_path, exist_ok=True) - tmp_model_path = os.path.join(model_path, "tmp") - os.makedirs(tmp_model_path, exist_ok=True) - - params = read_json(os.path.join(input_base_path, "params.json")) - num_shards = NUM_SHARDS[model_size] - n_layers = params["n_layers"] - n_heads = params["n_heads"] - n_heads_per_shard = n_heads // num_shards - dim = params["dim"] - dims_per_head = dim // n_heads - base = 10000.0 - inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) - - if "n_kv_heads" in params: - num_key_value_heads = params["n_kv_heads"] # for GQA / MQA - num_local_key_value_heads = n_heads_per_shard // num_key_value_heads - key_value_dim = dim // num_key_value_heads - else: # compatibility with other checkpoints - num_key_value_heads = n_heads - num_local_key_value_heads = n_heads_per_shard - key_value_dim = dim - - # permute for sliced rotary - def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): - return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) - - print(f"Fetching all parameters from the checkpoint at {input_base_path}.") - # Load weights - if model_size == "7B": - # Not sharded - # (The sharded implementation would also work, but this is simpler.) 
- loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") - else: - # Sharded - loaded = [ - torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") - for i in range(num_shards) - ] - param_count = 0 - index_dict = {"weight_map": {}} - for layer_i in range(n_layers): - filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" - if model_size == "7B": - # Unsharded - state_dict = { - f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( - loaded[f"layers.{layer_i}.attention.wq.weight"] - ), - f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( - loaded[f"layers.{layer_i}.attention.wk.weight"] - ), - f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], - f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], - f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], - f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], - f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], - f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"], - f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"], - } - else: - # Sharded - # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share - # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is - # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. - - state_dict = { - f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ - f"layers.{layer_i}.attention_norm.weight" - ].clone(), - f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ - f"layers.{layer_i}.ffn_norm.weight" - ].clone(), - } - state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( - torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) - for i in range(num_shards) - ], - dim=0, - ).reshape(dim, dim) - ) - state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( - torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( - num_local_key_value_heads, dims_per_head, dim - ) - for i in range(num_shards) - ], - dim=0, - ).reshape(key_value_dim, dim), - num_key_value_heads, - key_value_dim, - dim, - ) - state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( - num_local_key_value_heads, dims_per_head, dim - ) - for i in range(num_shards) - ], - dim=0, - ).reshape(key_value_dim, dim) - - state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 - ) - state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 - ) - state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 - ) - state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 - ) - - 
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq - for k, v in state_dict.items(): - index_dict["weight_map"][k] = filename - param_count += v.numel() - torch.save(state_dict, os.path.join(tmp_model_path, filename)) - - filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" - if model_size == "7B": - # Unsharded - state_dict = { - "model.embed_tokens.weight": loaded["tok_embeddings.weight"], - "model.norm.weight": loaded["norm.weight"], - "lm_head.weight": loaded["output.weight"], - } - else: - state_dict = { - "model.norm.weight": loaded[0]["norm.weight"], - "model.embed_tokens.weight": torch.cat( - [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 - ), - "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), - } - - for k, v in state_dict.items(): - index_dict["weight_map"][k] = filename - param_count += v.numel() - torch.save(state_dict, os.path.join(tmp_model_path, filename)) - - # Write configs - index_dict["metadata"] = {"total_size": param_count * 2} - write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) - ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 - multiple_of = params["multiple_of"] if "multiple_of" in params else 256 - config = LlamaConfig( - hidden_size=dim, - intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), - num_attention_heads=params["n_heads"], - num_hidden_layers=params["n_layers"], - rms_norm_eps=params["norm_eps"], - num_key_value_heads=num_key_value_heads, - ) - config.save_pretrained(tmp_model_path) - - # Make space so we can load the model properly now. - del state_dict - del loaded - gc.collect() - - print("Loading the checkpoint in a Llama model.") - model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) - # Avoid saving this as part of the config. 
- del model.config._name_or_path - - print("Saving in the Transformers format.") - model.save_pretrained(model_path, safe_serialization=safe_serialization) - shutil.rmtree(tmp_model_path) - - -def write_tokenizer(tokenizer_path, input_tokenizer_path): - # Initialize the tokenizer based on the `spm` model - tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast - print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") - tokenizer = tokenizer_class(input_tokenizer_path) - tokenizer.save_pretrained(tokenizer_path) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--input_dir", - help="Location of LLaMA weights, which contains tokenizer.model and model folders", - ) - parser.add_argument( - "--model_size", - choices=["7B", "7Bf", "13B", "13Bf", "30B", "65B", "70B", "70Bf", "tokenizer_only"], - ) - parser.add_argument( - "--output_dir", - help="Location to write HF model and tokenizer", - ) - parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") - args = parser.parse_args() - if args.model_size != "tokenizer_only": - write_model( - model_path=args.output_dir, - input_base_path=os.path.join(args.input_dir, args.model_size), - model_size=args.model_size, - safe_serialization=args.safe_serialization, - ) - spm_path = os.path.join(args.input_dir, "tokenizer.model") - write_tokenizer(args.output_dir, spm_path) - - -if __name__ == "__main__": - main() diff --git a/mft_peft_hf/src/model/llama2/tokenization_llama_fast.py b/mft_peft_hf/src/model/llama2/tokenization_llama_fast.py deleted file mode 100644 index fa8cfdd..0000000 --- a/mft_peft_hf/src/model/llama2/tokenization_llama_fast.py +++ /dev/null @@ -1,248 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from shutil import copyfile -from typing import TYPE_CHECKING, Optional, Tuple - -from tokenizers import processors - -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast -from transformers.utils import is_sentencepiece_available, logging -from transformers.utils.versions import require_version - - -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - -require_version("tokenizers>=0.13.3") - -if is_sentencepiece_available(): - from .tokenization_llama import LlamaTokenizer -else: - LlamaTokenizer = None - -logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"} - -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>\n", "\n<>\n\n" - -# fmt: off -DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\ -answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ -that your responses are socially unbiased and positive in nature. 
- -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\ -correct. If you don't know the answer to a question, please don't share false information.""" -# fmt: on - - -class LlamaTokenizerFast(PreTrainedTokenizerFast): - """ - Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. - - This uses notably ByteFallback and no normalization. - - ``` - from transformers import LlamaTokenizerFast - - tokenizer = LlaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") - tokenizer.encode("Hello this is a test") - >>> [1, 15043, 445, 338, 263, 1243] - ``` - - If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or - call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the - values of the first token and final token of an encoded sequence will not be correct). For more details, checkout - [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation. - - - This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should - refer to this superclass for more information regarding those methods. - - Args: - vocab_file (`str`): - [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that - contains the vocabulary necessary to instantiate a tokenizer. - tokenizer_file (`str`): - [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that - contains everything needed to load the tokenizer. - - clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`): - Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra - spaces. - - bos_token (`str`, *optional*, defaults to `""`): - The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. - - eos_token (`str`, *optional*, defaults to `""`): - The end of sequence token. - - unk_token (`str`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - """ - - vocab_files_names = VOCAB_FILES_NAMES - slow_tokenizer_class = LlamaTokenizer - padding_side = "left" - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file=None, - tokenizer_file=None, - clean_up_tokenization_spaces=False, - unk_token="", - bos_token="", - eos_token="", - add_bos_token=True, - add_eos_token=False, - **kwargs, - ): - super().__init__( - vocab_file=vocab_file, - tokenizer_file=tokenizer_file, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - unk_token=unk_token, - bos_token=bos_token, - eos_token=eos_token, - **kwargs, - ) - self._add_bos_token = add_bos_token - self._add_eos_token = add_eos_token - self.update_post_processor() - - self.vocab_file = vocab_file - self.can_save_slow_tokenizer = False if not self.vocab_file else True - - def update_post_processor(self): - """ - Updates the underlying post processor with the current `bos_token` and `eos_token`. 
- """ - bos = self.bos_token - bos_token_id = self.bos_token_id - - eos = self.eos_token - eos_token_id = self.eos_token_id - - single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}" - pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}" - - special_tokens = [] - if self.add_bos_token: - special_tokens.append((bos, bos_token_id)) - if self.add_eos_token: - special_tokens.append((eos, eos_token_id)) - self._tokenizer.post_processor = processors.TemplateProcessing( - single=single, pair=pair, special_tokens=special_tokens - ) - - @property - def add_eos_token(self): - return self._add_eos_token - - @property - def add_bos_token(self): - return self._add_bos_token - - @add_eos_token.setter - def add_eos_token(self, value): - self._add_eos_token = value - self.update_post_processor() - - @add_bos_token.setter - def add_bos_token(self, value): - self._add_bos_token = value - self.update_post_processor() - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - if not self.can_save_slow_tokenizer: - raise ValueError( - "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " - "tokenizer." - ) - - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): - copyfile(self.vocab_file, out_vocab_file) - - return (out_vocab_file,) - - def _build_conversation_input_ids(self, conversation: "Conversation"): - """Builds the input ids for a conversation. - This is the format used in the provided examples. System prompts should be manually added at the beginning of - the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. - ``` - [INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer - [INST] Prompt [/INST] Answer - [INST] Prompt [/INST] - ``` - - If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following: - ```python - >>> from transformers import Conversation - - >>> Conversation( - ... "<>\n Only answer with emojis, and charades\n<>\n\nHow can I build a house in 10 septs?" - ... ) - ``` - Args: - conversation (`Conversation`): - Conversation to build input ids for. - Returns: - `List[int]`: - Input ids for the conversation. 
- """ - dialogue = list(conversation.iter_texts()) - if not all([is_user for is_user, msg in dialogue[::2]]) or not all( - [not is_user for is_user, msg in dialogue[1::2]] - ): - raise ValueError( - "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)" - ) - - dialog_tokens = [] - if len(conversation.past_user_inputs) > 0: - if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]: - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]: - dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1]) - - dialog_tokens += sum( - [ - [self.bos_token_id] - + self.encode( - f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False - ) - + [self.eos_token_id] - for prompt, answer in zip(dialogue[::2], dialogue[1::2]) - ], - [], - ) - if not (dialogue[-1][0]): - raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}") - dialog_tokens += [self.bos_token_id] + self.encode( - f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False - ) - return dialog_tokens diff --git a/mft_peft_hf/src/pefts/configs/qwen_train_config.json b/mft_peft_hf/src/pefts/configs/qwen_train_config.json deleted file mode 100644 index af73ed4..0000000 --- a/mft_peft_hf/src/pefts/configs/qwen_train_config.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "load_raw_dataset": true, - "data_paths": "$DATA_PATHS", - "output_dir": "$OUTPUT_DIR", - "tb_dir": "$TensorBoard_DIR", - "pretrained_model_path": "$MODEL_NAME_OR_PATH", - "vocab_file": "$MODEL_NAME_OR_PATH", - "low_cpu_mem_usage": true, - "data_split": "95,5,0", - "padding_mode": "padding", - "tokenize_mode": "sft", - "weighted_loss_mode": "case3", - "shuffle_before_split": true, - "use_random_sampler": true, - "early_stopping": true, - "early_stopping_stall_num": 5, - "weight_by_num_documents": true, - "make_vocab_size_divisible_by": 128, - "model_parallel_size": 1, - "model_type": "qwen", - "peft_type": "lora", - "lora_rank": 32, - "lora_alpha": 32, - "lora_dropout": 0.05, - "quantization": "16bit", - "tokenizer_type": "AutoTokenizer", - "use_slow_tokenizer": false, - "use_xformers": true, - "trust_remote_code": true, - "use_dynamic_padding": true, - "per_device_train_batch_size": 4, - "per_device_eval_batch_size": 4, - "world_size": 128, - "learning_rate": 1e-4, - "min_lr": 1e-5, - "weight_decay": 0.1, - "gradient_accumulation_steps": 1, - "lr_scheduler_type": "cosine", - "num_warmup_steps": 1000, - "num_train_epochs": 4, - "seed": 1234, - "seq_length": 2048, - "preprocessing_num_workers": 2, - "num_workers": 2, - "resume_from_checkpoint": null, - "log_interval": 10, - "checkpointing_steps": 1000, - "evalation_steps": 1000, - "max_train_steps": null, - "epoch_checkpointing": false, - "checkpoint_activations": true -} \ No newline at end of file diff --git a/mft_peft_hf/src/pefts/configs/starcoder_train_config.json b/mft_peft_hf/src/pefts/configs/starcoder_train_config.json deleted file mode 100644 index 0d0f363..0000000 --- a/mft_peft_hf/src/pefts/configs/starcoder_train_config.json +++ /dev/null @@ -1,49 +0,0 @@ -{ - "load_raw_dataset": true, - "data_paths": "$DATA_PATHS", - "output_dir": "$OUTPUT_DIR", - "tb_dir": "$TensorBoard_DIR", - "pretrained_model_path": "$MODEL_NAME_OR_PATH", - "vocab_file": "$MODEL_NAME_OR_PATH", - "low_cpu_mem_usage": 
true, - "data_split": "95,5,0", - "padding_mode": "pack", - "use_dynamic_padding": true, - "tokenize_mode": "sft", - "weighted_loss_mode": "case3", - "model_type": "starcoder", - "peft_type": "lora", - "lora_rank": 32, - "lora_alpha": 32, - "lora_dropout": 0.05, - "quantization": null, - "per_device_train_batch_size": 4, - "per_device_eval_batch_size": 4, - "tokenizer_type": "AutoTokenizer", - "learning_rate": 1e-04, - "min_lr": 1e-5, - "weight_decay": 0.1, - "gradient_accumulation_steps": 1, - "lr_scheduler_type": "cosine", - "num_warmup_steps": 100, - "num_train_epochs": 8, - "seed": 1234, - "seq_length": 4096, - "resume_from_checkpoint": null, - "log_interval": 50, - "checkpointing_steps": 1000, - "evalation_steps": 1000, - "max_train_steps": null, - "epoch_checkpointing": false, - "shuffle_before_split": true, - "use_random_sampler": true, - "early_stopping": true, - "early_stopping_stall_num": 5, - "weight_by_num_documents": true, - "make_vocab_size_divisible_by": 128, - "model_parallel_size": 1, - "use_slow_tokenizer": false, - "use_xformers": true, - "trust_remote_code": true, - "world_size": 128 -} \ No newline at end of file diff --git a/mft_peft_hf/src/pefts/merge_base_and_lora_to_hf.py b/mft_peft_hf/src/pefts/merge_base_and_lora_to_hf.py deleted file mode 100644 index b46c194..0000000 --- a/mft_peft_hf/src/pefts/merge_base_and_lora_to_hf.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -# @author Chaoyu Chen -# @date 2023/6/19 - -""" -import os -import sys -import time -import shutil -import torch -import transformers -sys.path.append("..") -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel -from peft import LoraConfig, get_peft_model -from peft import PeftModel -from model_mapping import MODEL_SPECIAL_TOKENS - - -model_path='path to base model' -lora_adapter='path to lora adaptor ckpt' -save_path='path to new merged model' -model_type = 'llama/gpt_neox/qwen/chatglm2/starcoder' - -t0 = time.time() -config = {"model_type": model_type} -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - - -base_model = AutoModelForCausalLM.from_pretrained( - model_path, - trust_remote_code=True, - torch_dtype=torch.bfloat16, - return_dict=True, - device_map="auto" -) -print(base_model) - -# DEAL with eos_token_id and pad_token_id -eos_token = MODEL_SPECIAL_TOKENS[config['model_type']]['eos_token'] -pad_token = MODEL_SPECIAL_TOKENS[config['model_type']]['pad_token'] -base_model.config.eos_token = eos_token -base_model.config.pad_token = pad_token -base_model.config.eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) -base_model.config.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) -print(f"Finetuned eos_token: {eos_token}, eos_token_id: {tokenizer.convert_tokens_to_ids(eos_token)}") -print(f"Finetuned pad_token: {pad_token}, pad_token_id: {tokenizer.convert_tokens_to_ids(pad_token)}") - - -# merge, save model and tokenizer -model_to_merge = PeftModel.from_pretrained(base_model, lora_adapter) -merged_model = model_to_merge.merge_and_unload() -print(merged_model.config) -merged_model.save_pretrained(save_path) -tokenizer.save_pretrained(save_path) -print(f"Merge finised: {save_path} saved, Cost {time.time()-t0:.2f}s") \ No newline at end of file diff --git a/mft_peft_hf/src/pefts/mft_accelerate.py b/mft_peft_hf/src/pefts/mft_accelerate.py deleted file mode 100644 index 1c29bc7..0000000 --- a/mft_peft_hf/src/pefts/mft_accelerate.py +++ /dev/null @@ -1,385 +0,0 @@ -""" -# @author Chaoyu Chen -# @date 2023/6/19 -# @module mft_accelerate.py - 
-Hugging face accelerate + deepspeed zero stage2 + DP -QLoRA + MFT entry -""" - -import gc -import os -import sys -import argparse -import math -import logging -import json -import time -import transformers -import numpy as np -import torch -from torch import nn -from dataclasses import dataclass -from datasets import Dataset -import datasets -from torch.utils.data import DataLoader -from tqdm.auto import tqdm -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - LlamaTokenizer, - get_linear_schedule_with_warmup, - set_seed, - BitsAndBytesConfig, - get_scheduler, -) -from peft import ( - LoraConfig, - TaskType, - get_peft_model, - prepare_model_for_int8_training, - PeftModel, -) -from accelerate import Accelerator -from accelerate.logging import get_logger - -sys.path.append("..") -from data.gpt2_multi_task_dataset import load_dataset_from_jsonl, compile_helper -from utils.common_utils import generate_task_id, TASK2ID, ID2TASK -from train_utils import accelerate_train -from model_mapping import MODEL_TYPES, QLORA_TARGETING_MODULES, MODEL_SPECIAL_TOKENS -logger = get_logger(__name__) - - -def get_task_mask(args, task_id): - task_num = len(TASK2ID) - task_mask = torch.zeros(task_id.shape[0], task_num) - task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1 - - return task_mask - - -def get_ltor_masks_and_position_ids(data): - """Build masks and position id for left to right model.""" - - # Extract batch size and sequence length. - batch_size, seq_length = data.size() - - attention_mask = torch.ones((batch_size, seq_length), device=data.device) - - # Position ids. - position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) - position_ids = position_ids.unsqueeze(0).expand_as(data).clone() - - return attention_mask, position_ids - - -@dataclass -class DataCollatorForMFTDataset(object): - args: None - - # tokenizer: None - - def __call__(self, instances): - input_ids, loss_mask, weights, task_id = tuple( - [instance[key] if key in instance else None for instance in instances] for key in - ("input_ids", "loss_mask", "weight", "task_id")) - - result_batch = {} - ''' - outputs = model( - input_ids=batch['input_ids'], - attention_mask=batch['attention_mask'], - # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']), - # labels=(batch['labels'], batch['loss_mask']), - position_ids=batch['position_ids'], - ) - ''' - - # if loss_mask is not None: - loss_mask = torch.tensor(np.array(loss_mask)) - if self.args.use_dynamic_padding: - last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1) - # get last non-padding position - max_pos = last_one_pos.max().item() + 1 - else: - max_pos = loss_mask.shape[-1] - result_batch['loss_mask'] = loss_mask.float()[:, 1:max_pos].contiguous() - - input_ids = torch.tensor(np.array(input_ids)).long() - # print(f"shape of input_ids: {input_ids.shape}") - result_batch['input_ids'] = input_ids[:, :max_pos-1].contiguous() - result_batch['labels'] = input_ids[:, 1:max_pos].contiguous() - - # Get the masks and position ids. 
- result_batch['attention_mask'], result_batch['position_ids'] = get_ltor_masks_and_position_ids( - data=result_batch['input_ids']) - - if task_id is not None: - task_id = torch.tensor(np.array(task_id)) - result_batch['task_mask'] = get_task_mask(self.args, task_id) # bsz * task_num - result_batch['task_id'] = task_id - - return result_batch - - -def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True): - """ - This method wraps the entire protocol for preparing a model before running a training. This includes: - 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm - head to fp32 - - Args: - model, (`transformers.PreTrainedModel`): - The loaded model from `transformers` - """ - loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) - is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq" - for name, param in model.named_parameters(): - # freeze base model's layers - param.requires_grad = False - - if not is_gptq_quantized: - # cast all non INT8 parameters to fp32 - for param in model.parameters(): - if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): - param.data = param.data.to(torch.float32) - - if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing: - # For backward compatibility - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # enable gradient checkpointing for memory efficiency - model.gradient_checkpointing_enable() - - return model - - -def pprint_args(args, accelerator): - message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()]) - accelerator.print('====' * 30) - accelerator.print(message) - accelerator.print('====' * 30) - accelerator.print("GPU: {}".format(torch.cuda.current_device())) - - -def get_configs(): - - parser = argparse.ArgumentParser() - parser.add_argument("--train_config", type=str, default='./train_config.json') - - parser.add_argument("--data_paths", type=str, default='') - parser.add_argument("--output_dir", type=str, default='') - parser.add_argument("--tb_dir", type=str, default='') - parser.add_argument("--pretrained_model_path", type=str, default='') - - return parser.parse_args() - - -def main(): - t0 = time.time() - parser = get_configs() - train_config_file = parser.train_config - # get configs - with open(train_config_file, 'r') as f: - train_config = json.load(f) - - args = argparse.Namespace(**train_config) - - # get eos token和 pad token - args.eos_token = MODEL_SPECIAL_TOKENS[args.model_type]['eos_token'] - args.pad_token = MODEL_SPECIAL_TOKENS[args.model_type]['pad_token'] - - # refactor args - if parser.data_paths: - args.data_paths = parser.data_paths - if parser.output_dir: - args.output_dir = parser.output_dir - if parser.tb_dir: - args.tb_dir = parser.tb_dir - if parser.pretrained_model_path: - args.pretrained_model_path = parser.pretrained_model_path - args.vocab_file = parser.pretrained_model_path - - if args.peft_type == 'qlora' and args.quantization != '4bit' and args.quantization != '8bit': - print(f"[WARNING]peft_type is qlora but quantization is not 4bit or 8bit, setting it to 4bit") - args.quantization = '4bit' - - # define accelerator - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - 
pprint_args(args, accelerator) - - # logger - logging.basicConfig( - format="%(asctime)s-%(levelname)s-%(name)s %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - compile_helper() - time.sleep(10) - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - time.sleep(10) - - if args.seed is not None: - set_seed(args.seed) - - # get world_size and global_rank - # args.world_size = int(os.environ.get('WORLD_SIZE', 1)) - # global_rank = int(os.environ.get('RANK', 0)) - args.world_size = accelerator.num_processes - global_rank = accelerator.process_index - print(f'world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {accelerator.local_process_index}') - - # TASK2ID, ID2TASK - generate_task_id(args.data_paths) - # # multi task blendable dataset(sharded) - train_dataset, valid_dataset = load_dataset_from_jsonl(args, shard_data=True, world_size=args.world_size, - global_rank=global_rank, local_rank=accelerator.local_process_index) - t1 = time.time() - logger.info(f"dataset loading time: {t1 - t0:.4f}") - - # memory - free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - max_memory = f"{free_in_GB - 2}GB" - n_gpus = torch.cuda.device_count() - max_memory = {i: max_memory for i in range(n_gpus)} - accelerator.print("max memory: ", max_memory, n_gpus) - - # peft config - peft_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - target_modules=QLORA_TARGETING_MODULES[args.model_type], - - ) - - # creating base model - ModelClass = MODEL_TYPES[args.model_type] - model = ModelClass.from_pretrained( - args.pretrained_model_path, - # max_memory=max_memory, - # trust_remote_code=True, - load_in_8bit=(args.quantization=='8bit'), - load_in_4bit=(args.quantization=='4bit'), - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=args.low_cpu_mem_usage, # not for zero3 - use_safetensors=False, - quantization_config=BitsAndBytesConfig( - load_in_4bit=(args.quantization=='4bit'), - # llm_int8_threshold=6.0, - # llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - ) if args.quantization=='4bit' else None, - ) - - accelerator.print("load in 8bit: ", args.quantization=='8bit') - accelerator.print("load in 4bit: ", args.quantization=='4bit') - if args.peft_type == 'lora': - # for name, param in model.named_parameters(): - # # cast layer norm in fp32 for stability - # if param.ndim == 1 and "layer_norm" in name: - # param.data = param.data.to(torch.float32) - # if "lm_head" in name: - # param.data = param.data.to(torch.float32) - model.gradient_checkpointing_enable() - - elif args.peft_type == 'qlora': - # prepare base model for 8bit or 4bit model(cast non-8bit or non-4bit layers to fp32) - model = prepare_model_for_kbit_training(model) - logging.info(f"device map: {model.hf_device_map}") - - # Potentially load in the lora from a previous save - if not args.resume_from_checkpoint: - model = get_peft_model(model, peft_config) - else: - - accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") - # accelerator.load_state(args.resume_from_checkpoint) - model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, 
is_trainable=True) - - t2 = time.time() - if accelerator.is_main_process: - logging.info(f"model loading time: {t2 - t1:.4f}") - model.print_trainable_parameters() - model.config.use_cache = False # silence the warnings. Please re-enable for inference! - model.config.use_logn_attn = False # special for qwen model - accelerator.print(model.config) - - # dataloader - train_dataloader = DataLoader( - train_dataset, shuffle=True, collate_fn=DataCollatorForMFTDataset(args), - batch_size=args.per_device_train_batch_size, pin_memory=True, drop_last=True - ) - valid_dataloader = DataLoader( - valid_dataset, collate_fn=DataCollatorForMFTDataset(args), batch_size=args.per_device_eval_batch_size, - pin_memory=True, drop_last=True - ) - - from deepspeed.ops.adam import FusedAdam as Adam - - adam_optimizer = Adam - optimizer = adam_optimizer( - model.parameters(), - weight_decay=args.weight_decay, - lr=args.learning_rate, - betas=(0.9, 0.95), - ) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - ) - - model, train_dataloader, valid_dataloader, optimizer, lr_scheduler = accelerator.prepare( - model, train_dataloader, valid_dataloader, optimizer, lr_scheduler - ) - print(model.device) - accelerator.print(model) - # accelerator.print(model.config) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # zero 3 flag - is_ds_zero_3 = False - if getattr(accelerator.state, "deepspeed_plugin", None): - is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3 - accelerator.print(f"is_ds_zero_3: {is_ds_zero_3}") - - # Train! 
- accelerate_train(accelerator, model, train_dataloader, valid_dataloader, optimizer, lr_scheduler, num_update_steps_per_epoch, len(train_dataset), args) - - -if __name__ == "__main__": - main() diff --git a/mft_peft_hf/src/pefts/model_mapping.py b/mft_peft_hf/src/pefts/model_mapping.py deleted file mode 100644 index 0f01274..0000000 --- a/mft_peft_hf/src/pefts/model_mapping.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -# @author Chaoyu Chen -# @date 2023/6/19 - -""" -# from model.llama2.modeling_llama import LlamaForCausalLM -# from model.llama2.configuration_llama import LlamaConfig -from model.code_llama.modeling_llama import LlamaForCausalLM -from model.code_llama.configuration_llama import LlamaConfig -from model.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM -from model.baichuan.modeling_baichuan import BaichuanForCausalLM -from model.baichuan.configuration_baichuan import BaichuanConfig -from model.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM -from model.qwen.modeling_qwen import QWenLMHeadModel -from model.chatglm2.modeling_chatglm import ChatGLMForConditionalGeneration -from transformers import AutoModelForCausalLM - -MODEL_TYPES = { - "gpt_neox": GPTNeoXForCausalLM, - "llama": LlamaForCausalLM, - "baichuan": BaichuanForCausalLM, - "starcoder": GPTBigCodeForCausalLM, - 'qwen': QWenLMHeadModel, - 'chatglm2': ChatGLMForConditionalGeneration, -} - -MODEL_CONFIGS = { - "gpt_neox": None, - "llama": LlamaConfig, - "baichuan": BaichuanConfig, - "starcoder": None, - 'qwen': None, - 'chatglm2': None, -} - -QLORA_TARGETING_MODULES = { - "gpt_neox": ["query_key_value", 'dense', 'dense_h_to_4h', 'dense_4h_to_h'], - "llama": ["q_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "baichuan": ["W_pack", "o_proj", "gate_proj", "down_proj", "up_proj"], - "starcoder": ["c_proj", "c_attn", "q_attn", "c_fc"], - "qwen": ["c_proj", "c_attn", "w1", "w2"], - "chatglm2": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], -} - -MODEL_SPECIAL_TOKENS = { - "gpt_neox": { - - "eos_token": "<|endoftext|>", - "pad_token": "<|endoftext|>", - - }, - "llama": { - - "eos_token": "", - "pad_token": "", - - }, - "baichuan": { - - "eos_token": "", - "pad_token": "", - - }, - "starcoder": { - - "eos_token": "<|endoftext|>", - "pad_token": "", - - }, - "qwen": { - - "eos_token": "<|endoftext|>", - "pad_token": "<|extra_1|>", - - }, - "chatglm2": { - - "eos_token": "", - "pad_token": "", - - }, -} \ No newline at end of file diff --git a/mft_peft_hf/src/pefts/train_utils.py b/mft_peft_hf/src/pefts/train_utils.py deleted file mode 100644 index 91e2ba5..0000000 --- a/mft_peft_hf/src/pefts/train_utils.py +++ /dev/null @@ -1,384 +0,0 @@ -""" -# @author Chaoyu Chen -# @date 2023/7/19 -# @module train_utils.py - -Hugging face accelerate + deepspeed zero stage2 + DP -QLoRA + MFT Training -""" - -import gc -import os -import sys -import threading -import argparse -import math -import logging -import json -import time -import transformers -import numpy as np -import psutil -import torch -from torch import nn -from tqdm.auto import tqdm -sys.path.append("..") -from utils.common_utils import generate_task_id, TASK2ID, ID2TASK -from utils.auto_accelerate_utils import loss_func_mft -from torch.utils.tensorboard import SummaryWriter -from accelerate.logging import get_logger -logger = get_logger(__name__) - -# Converting Bytes to Megabytes -def b2mb(x): - return int(x / 2 ** 20) - - -# This context manager is used to track the peak memory usage of the process -class TorchTracemalloc: - def 
__enter__(self): - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero - self.begin = torch.cuda.memory_allocated() - self.process = psutil.Process() - - self.cpu_begin = self.cpu_mem_used() - self.peak_monitoring = True - peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) - peak_monitor_thread.daemon = True - peak_monitor_thread.start() - return self - - def cpu_mem_used(self): - """get resident set size memory for the current process""" - return self.process.memory_info().rss - - def peak_monitor_func(self): - self.cpu_peak = -1 - - while True: - self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) - - # can't sleep or will not catch the peak right (this comment is here on purpose) - # time.sleep(0.001) # 1msec - - if not self.peak_monitoring: - break - - def __exit__(self, *exc): - self.peak_monitoring = False - - gc.collect() - torch.cuda.empty_cache() - self.end = torch.cuda.memory_allocated() - self.peak = torch.cuda.max_memory_allocated() - self.used = b2mb(self.end - self.begin) - self.peaked = b2mb(self.peak - self.begin) - - self.cpu_end = self.cpu_mem_used() - self.cpu_used = b2mb(self.cpu_end - self.cpu_begin) - self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin) - # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") - - -def accelerate_train(accelerator, model, train_dataloader, valid_dataloader, optimizer, lr_scheduler, num_update_steps_per_epoch, total_train_dataset_size, args): - # tensorboard writer - summary_writer = SummaryWriter(log_dir=args.tb_dir) - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - logger.info("**************************************** Running training ****************************************") - logger.info(f" Num examples = {total_train_dataset_size}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization(update/completed) steps = {args.max_train_steps}") - logger.info(f" Complete/Optimization steps per Epoch = {args.max_train_steps // args.num_train_epochs}") - # Only show the progress bar once on each machine. 
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - - # 配置starting_epoch and completed_steps 从哪里开始训练 - completed_steps = 0 - starting_epoch = 0 - - if args.resume_from_checkpoint: - path = os.path.basename(args.resume_from_checkpoint) - # Extract `epoch_{i}` or `step_{i}` - training_difference = os.path.splitext(path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - completed_steps = starting_epoch * num_update_steps_per_epoch - print(f"resume from epoch {starting_epoch} and completed_steps {completed_steps}") - else: - # need to multiply `gradient_accumulation_steps` to reflect real steps - completed_steps = int(training_difference.replace("step_", "")) - starting_epoch = completed_steps // num_update_steps_per_epoch - resume_step = (completed_steps % num_update_steps_per_epoch) * args.gradient_accumulation_steps - print(f"resume from epoch {starting_epoch} resusme step {resume_step} and completed_steps {completed_steps}") - - # update the progress_bar if load from checkpoint - progress_bar.update(completed_steps) - - min_eval_loss = float('inf') - stall_num = 0 - - for epoch in range(starting_epoch, args.num_train_epochs): - if args.early_stopping and stall_num == args.early_stopping_stall_num: - break - with TorchTracemalloc() as tracemalloc: - model.train() - total_loss = 0 - reduce_loss = 0 - reduce_task_loss = torch.zeros(len(ID2TASK)).to(model.device) - reduce_task_exist = torch.zeros(len(ID2TASK)).to(model.device) - t3 = time.time() - if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: - # We skip the first `n` batches in the dataloader when resuming from a checkpoint - active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) - else: - active_dataloader = train_dataloader - tail_num = len(active_dataloader) - len(active_dataloader) % args.gradient_accumulation_steps - print(f"length of dataloader: {len(active_dataloader)}") - for step, batch in enumerate(active_dataloader): - if step == tail_num: - break - with accelerator.accumulate(model): - if step == 0: - accelerator.print(f"step 1 batch shape: {batch['input_ids'].shape},\n" - f"last 10 tokens: {batch['input_ids'][:, -10:]}" - f"last 10 loss mask: {batch['loss_mask'][:, -10:]}") - accelerator.print(f"first 10 tokens and loss_mask") - for pt in range(1): - accelerator.print(f"{batch['input_ids'][:, 10 * pt:10 * pt + 10]}") - accelerator.print(f"{batch['loss_mask'][:, 10 * pt:10 * pt + 10]}") - - t4 = time.time() - # forward & loss - outputs = model(input_ids=batch['input_ids'], - attention_mask=batch['attention_mask'], - position_ids=batch['position_ids'], - return_dict=True, - ) - # loss - loss, task_loss, _ = loss_func_mft(outputs, - batch['labels'], - batch['task_mask'], - batch['task_id'], - args.weighted_loss_mode, - batch['loss_mask'] - ) - t5 = time.time() - - # backward - if not torch.isnan(loss): - total_loss += loss.detach().float() - accelerator.backward(loss) - - # update(sync_gradients) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - # support args.min_lr - if optimizer.param_groups[0]['lr'] <= args.min_lr: - optimizer.param_groups[0]['lr'] = args.min_lr - t6 = time.time() - # accumulate resuce_loss - if not torch.isnan(loss): - reduce_loss += loss.detach().float() - # accelerator.print("task loss devices: ", reduce_task_loss.device, task_loss.device) - reduce_task_loss += task_loss.detach().float() - 
reduce_task_exist += (task_loss != 0).detach().float() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - # progress_bar.update(1) - completed_steps += 1 - - # loggging 主进程打印所有卡平均loss - if completed_steps % args.log_interval == 0: - progress_bar.update(args.log_interval) - # gather reduce_loss and reduce_task_loss from all N devices - reduce_losses = accelerator.gather(reduce_loss).detach().float() - # reduce_task_losses = accelerator.gather_for_metrics(reduce_task_loss).reshape(-1, len(ID2TASK)) - # reduce_task_exists = accelerator.gather_for_metrics(reduce_task_exist).reshape(-1, len(ID2TASK)) - reduce_task_losses = accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK)) - reduce_task_exists = accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK)) - # get train loss and per-task train loss - train_loss = torch.mean(reduce_losses) / (args.log_interval * args.gradient_accumulation_steps) - # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (args.log_interval * args.gradient_accumulation_steps) - train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0) - t7 = time.time() - - # logging and tensorboard - logger.info( - f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}][train_task_loss={train_task_loss}]" - f"[gather shape={reduce_losses.shape}][lr={lr_scheduler.get_lr()[0]:.4e}, {optimizer.param_groups[0]['lr']:.4e}]", - # f"dataloader time: {t4 - t3:.4f}, forward time: {t5 - t4:.4f}, gather time: {t7 - t6:.4f}, backward time: {t6 - t5:.4f}", - main_process_only=True) - train_log_dict = {"training_loss": train_loss} - for i in range(len(ID2TASK)): - train_log_dict[f"{ID2TASK[i]}_train_loss"] = train_task_loss[i] - # accelerator.log(train_log_dict, step=completed_steps) - if accelerator.is_main_process: - for key, value in train_log_dict.items(): - summary_writer.add_scalar(f'{key}', value, completed_steps) - # summary_writer.close() - # accelerator.print(optimizer) - reduce_loss = 0 - reduce_task_loss = torch.zeros(len(ID2TASK)).to(model.device) - reduce_task_exist = torch.zeros(len(ID2TASK)).to(model.device) - # steps checkpointing - if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0: - output_dir = f"step_{completed_steps}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - # accelerator.save_state(output_dir) - accelerator.wait_for_everyone() - logger.info( - f"[CHECKPOINT] saving lora checkpoint", - main_process_only=True) - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - output_dir, - is_main_process=accelerator.is_main_process, - save_function=accelerator.save, - state_dict=accelerator.get_state_dict(model) - ) - accelerator.wait_for_everyone() - logger.info( - f"[CHECKPOINT][global_steps={step + 1}][complete_steps={completed_steps}], lora checkpoint {output_dir} saved", - main_process_only=True) - if completed_steps >= args.max_train_steps: - break - - accelerator.wait_for_everyone() - - # evaluation - if completed_steps % args.evalation_steps == 0: - model.eval() - losses = [] - accumulated_task_loss = torch.zeros(len(ID2TASK)).to(model.device) - accumulated_task_exist = torch.zeros(len(ID2TASK)).to(model.device) - for valid_step, valid_batch in enumerate(valid_dataloader): - # if valid_step > args.max_valid_steps: - # break - with torch.no_grad(): - outputs = model( - input_ids=valid_batch['input_ids'], - 
attention_mask=valid_batch['attention_mask'], - position_ids=valid_batch['position_ids'], - return_dict=True, - ) - - loss, task_loss, _ = loss_func_mft(outputs, valid_batch['labels'], - valid_batch['task_mask'], - valid_batch['task_id'], - args.weighted_loss_mode, - valid_batch['loss_mask']) - # losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))) - losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) - # task_losses.append(accelerator.gather_for_metrics(task_loss)) - # [[1, 2, 3, 4, 1, 2, 3, 4], [1, 2, 3, 4, 1, 2, 3, 4]...] - # [[1, 2, 3, 4], .....] - accumulated_task_loss += task_loss.detach().float() - accumulated_task_exist += (task_loss != 0.0).detach().float() - - accelerator.wait_for_everyone() - # if accelerator.is_main_process: - valid_batch_num = len(losses) - gathered_size = losses[0].shape - losses = torch.cat(losses) - # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK)) - # task_losses = accelerator.gather_for_metrics(accumulated_task_loss).reshape(-1, len(ID2TASK)) - # task_exists = accelerator.gather_for_metrics(accumulated_task_exist).reshape(-1, len(ID2TASK)) - task_losses = accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK)) - task_exists = accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK)) - - try: - eval_loss = torch.mean(losses) - # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num - eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0) - if eval_loss <= min_eval_loss: - min_eval_loss = eval_loss - stall_num = 0 - else: - stall_num += 1 - perplexity = math.exp(eval_loss) - except OverflowError: - perplexity = float("inf") - - logger.info(f"[EVAL][global_steps={step + 1}][completed_steps={completed_steps}]" - f"[valid_batch_num={valid_batch_num}], [gather_size={gathered_size}]" - f"[perplexity={perplexity:.4f}][eval_loss={eval_loss:.6f}][eval_task_loss={eval_task_loss}]", - main_process_only=True) - eval_log_dict = {"valid_loss": eval_loss, "perplexity": perplexity} - for i in range(len(ID2TASK)): - eval_log_dict[f"{ID2TASK[i]}_valid_loss"] = eval_task_loss[i] - # accelerator.log(eval_log_dict, step=completed_steps) - if accelerator.is_main_process: - for key, value in eval_log_dict.items(): - summary_writer.add_scalar(f'{key}', value, completed_steps) - - model.train() - if args.early_stopping and stall_num == args.early_stopping_stall_num: - accelerator.print(f"[WARNING] Early stopping at {completed_steps}") - break - - t3 = time.time() - - # epoch checkpointing - if args.epoch_checkpointing: - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - # accelerator.save_state(output_dir) - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - output_dir, - is_main_process=accelerator.is_main_process, - save_function=accelerator.save, - state_dict=accelerator.get_state_dict(model) - ) - logger.info(f"[CHECKPOINGING], lora checkpoint {output_dir} saved", main_process_only=True) - accelerator.wait_for_everyone() - - # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage - accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin))) - accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used)) - accelerator.print("GPU Peak Memory consumed during the train 
(max-begin): {}".format(tracemalloc.peaked)) - accelerator.print( - "GPU Total Peak Memory consumed during the train (max): {}".format( - tracemalloc.peaked + b2mb(tracemalloc.begin) - ) - ) - - accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin))) - accelerator.print("CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used)) - accelerator.print("CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked)) - accelerator.print( - "CPU Total Peak Memory consumed during the train (max): {}".format( - tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin) - ) - ) - train_epoch_loss = total_loss / len(train_dataloader) - train_ppl = torch.exp(train_epoch_loss) - accelerator.print(f"{epoch=}: {train_ppl=} {train_epoch_loss=}") - - # end training if accelerator.init_trackers() - # accelerator.end_training() - summary_writer.close() - - # final save - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - os.path.join(args.output_dir, f"final_step_{completed_steps}"), - is_main_process=accelerator.is_main_process, - save_function=accelerator.save, - state_dict=accelerator.get_state_dict(model) - ) - accelerator.wait_for_everyone() diff --git a/mft_peft_hf/src/tokenizer/__init__.py b/mft_peft_hf/src/tokenizer/__init__.py deleted file mode 100644 index 12ec210..0000000 --- a/mft_peft_hf/src/tokenizer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .tokenizer import build_tokenizer diff --git a/mft_peft_hf/src/tokenizer/tokenizer.py b/mft_peft_hf/src/tokenizer/tokenizer.py deleted file mode 100644 index 8680ba6..0000000 --- a/mft_peft_hf/src/tokenizer/tokenizer.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -# @author Chaoyu Chen -# @date 2023/6/19 -""" - - -import numpy as np -from typing import List, Union -from utils.common_utils import print_rank_0 -from transformers import AutoTokenizer - - -def build_tokenizer(args): - """Initialize tokenizer.""" - print_rank_0("> building {} tokenizer ...".format(args.tokenizer_type)) - # Select and instantiate the tokenizer. - if args.tokenizer_type.lower() == "AutoTokenizer".lower(): - assert args.pretrained_model_path is not None - tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, trust_remote_code=True, use_fast=False) - tokenizer.eod_id = tokenizer.convert_tokens_to_ids(args.eos_token) - tokenizer.pad_id = tokenizer.convert_tokens_to_ids(args.pad_token) - print_rank_0(f"build_tokenizer PAD id: {tokenizer.pad_id}, EOD id: {tokenizer.eod_id}") - print_rank_0(f"build_tokenizer PAD token : {args.pad_token}, EOD token: {args.eos_token}") - else: - raise NotImplementedError( - "{} tokenizer is not " "implemented.".format(args.tokenizer_type) - ) - - # Add vocab size. 
- args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args) - - return tokenizer - - -def _vocab_size_with_padding(orig_vocab_size, args): - """Pad vocab size so it is divisible by model parallel size and - still having GPU friendly size.""" - - after = orig_vocab_size - multiple = args.make_vocab_size_divisible_by * args.model_parallel_size - while (after % multiple) != 0: - after += 1 - print_rank_0( - " > padded vocab (size: {}) with {} dummy tokens " - "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after) - ) - - return after diff --git a/mft_peft_hf/src/utils/__init__.py b/mft_peft_hf/src/utils/__init__.py deleted file mode 100644 index 0cf9434..0000000 --- a/mft_peft_hf/src/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .common_utils import * -from .auto_accelerate_utils import * \ No newline at end of file diff --git a/mft_peft_hf/src/utils/auto_accelerate_utils.py b/mft_peft_hf/src/utils/auto_accelerate_utils.py deleted file mode 100644 index b2f9d9e..0000000 --- a/mft_peft_hf/src/utils/auto_accelerate_utils.py +++ /dev/null @@ -1,109 +0,0 @@ -import sys -sys.path.append("..") -import torch -from utils.common_utils import print_rank_0, TASK2ID, ID2TASK -from torch.nn import CrossEntropyLoss -from dataclasses import dataclass -import numpy as np - - -def get_task_mask(task_id): - task_num = len(TASK2ID) - task_mask = torch.zeros(task_id.shape[0], task_num) - task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1 - - return task_mask - - -def get_task_loss(task_losses, task_id): # TODO - # fix task order - task_loss_per_batch = torch.zeros(len(ID2TASK)).to(device=task_id.device) - # count task samples - task_num_per_batch = torch.zeros(len(ID2TASK)).to(device=task_id.device) - for i in range(len(task_id)): - task_num_per_batch[task_id[i][0]] += 1 - task_loss_per_batch[task_id[i][0]] = task_losses[task_id[i][0]] - - return task_loss_per_batch, task_num_per_batch - - -def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_mask=None): - """ - loss function for MFT loss - :param outputs: - :param labels: - :param task_mask: - :param task_id: - :param weighted_loss_mode: - :param loss_mask: - :return: - """ - # task_id shape: [[1], [2], [4], [3], ..., [1]] - weighted = weighted_loss_mode - lm_logits = outputs["logits"] - labels = labels.to(device=lm_logits.device) - task_mask = task_mask.to(device=lm_logits.device) - task_id = task_id.to(device=lm_logits.device) - shift_logits = lm_logits.contiguous() - labels = labels.contiguous() - - bsz, seq_len = labels.shape - # loss_mask = None - if loss_mask is None: - ineffective_tokens_per_sample = (labels==-100).sum(dim=1) - effective_tokens_per_sample = - (ineffective_tokens_per_sample - seq_len) - effective_tokens = bsz * seq_len - ineffective_tokens_per_sample.sum() - loss_fct = CrossEntropyLoss(reduction='none', ignore_index=-100) - else: - loss_mask = loss_mask.to(device=lm_logits.device) - loss_fct = CrossEntropyLoss(reduction='none') - losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) # [B * L, 1] - losses = losses.contiguous().view(bsz, -1) - token_losses = losses.clone().detach().float() if loss_mask is None else losses.clone().detach().float() * loss_mask # [B, L] - task_mask_trans = torch.transpose(task_mask, 0, 1) - if weighted_loss_mode == "case3" or weighted_loss_mode == "case4": - unique_id = torch.unique(task_id) - loss = 0.0 - - for i, w in enumerate(unique_id): - row_idx = torch.squeeze(task_id) == w.item() - if 
weighted_loss_mode == "case3": - if loss_mask is None: - loss += torch.sum(losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) - else: - loss += torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :]) - elif weighted_loss_mode == "case4": - if loss_mask is None: - loss += torch.mean(torch.sum(losses, dim=1)[row_idx] / effective_tokens_per_sample[row_idx]) - else: - loss += torch.mean(torch.sum(losses * loss_mask, dim=1)[row_idx] / torch.sum(loss_mask, dim=1)[row_idx]) - - loss /= len(unique_id) - - elif weighted_loss_mode == "case2": - if loss_mask is None: - loss = torch.mean(torch.sum(losses, dim=1) / effective_tokens_per_sample) - else: - loss = torch.mean(torch.sum(losses * loss_mask, dim=1) / torch.sum(loss_mask, dim=1)) - elif weighted_loss_mode == "case1": - # flatten losses & loss_mask tensor - if loss_mask is None: - losses = losses.view(-1) - loss = torch.sum(losses) / effective_tokens - else: - loss_mask = loss_mask.view(-1) - losses = losses.view(-1) - loss = torch.sum(losses * loss_mask) / loss_mask.sum() - - # fix task order - task_loss = torch.zeros(len(ID2TASK)).to(device=task_id.device) - task_num = torch.zeros(len(ID2TASK)).to(device=task_id.device) - for i, w in enumerate(unique_id): - row_idx = torch.squeeze(task_id) == w.item() - if loss_mask is None: - task_loss[w] = torch.sum(token_losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) - task_num[w] = len(effective_tokens_per_sample[row_idx]) - else: - task_loss[w] = torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :]) - - return loss, task_loss, task_num diff --git a/mftcoder_accelerate/README.md b/mftcoder_accelerate/README.md new file mode 100644 index 0000000..87b4b63 --- /dev/null +++ b/mftcoder_accelerate/README.md @@ -0,0 +1,441 @@ +# MFTCoder-accelerate: Training Framework with Accelerate and DeepSpeed/FSDP +[![Generic badge](https://img.shields.io/badge/🤗-Huggingface%20Repo-green.svg)](https://huggingface.co/codefuse-ai) + + GitHub + + +[[中文]](README_cn.md) [**English**] + +## 1. Updates +🔥 MFTCoder-accelerate now supports DPO/ORPO training through xxpo module. + +🔥 MFTCoder-accelerate now supports continue training through mpt module along with offline_tokenization module. + +🔥 MFTCoder-accelerate supports MFT with latest implementation of CoBa Loss (selfpaced Loss) for better Convergence Balance. + +🔥 MFTCoder-accelerate now support these modes: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, Full-parameter + DeepSpeed ZeRO3, QLoRA + FSDP, Full-parameter + FSDP. + +🔥 MFTCoder-accelerate supports QLoRA + DeepSpeed ZeRO3 and QLoRA + FSDP, which both work for larger models. + +🔥 MFTCoder-accelerate supports MFT/SFT on more new mainstream open-source base models: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3. + +🔥 MFTCoder-accelerate supports Self-Paced Loss for Convergence Balance. + +🔥 MFTCoder-accelerate supports Full-parameters/QLoRA/LoRA using accelerate + DeepSpeed Framework. + +🔥 MFTCoder-accelerate supports Multitask Fine-Tuning(MFT), which is able to balance diffenrent tasks in data level. + +🔥 MFTCoder-accelerate supports finetuning most of mainstream open-source base models: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen. + +## 2. Data Format +### 2.1 MFT Training Data Format +The training data is required to be a uniformed JSONL format, in which each line of data has the following "chatML"-style JSON format. 
The "chat_rounds" field is required, and other fields can be added or removed based on specific needs. +The reason why we selected "chatML" style as our training and inference data format is that "chatML" style is compatible with both "conversation" and "instruction/response" scenarios. + +For the keys of roles in "chat_rounds", you could use "system/human/bot" tuple or "system/user/assistant" tuple. + +```json +{ + "id":0, + "data_name":"code-helper", + "chat_rounds":[ + { + "role": "system", + "content": "You are a expert in coding and help answer code questions" + }, + { + "role": "human", + "content": "Write a python function of quick sort" + }, + { + "role": "bot", + "content": "Below is the function of quick sort: ..." + }, + { + "role": "human", + "content": "Explain the code" + }, + { + "role": "bot", + "content": "OK, this code ..." + } + ] +} +``` + +### 2.2 Default MFTCoder Inference Template +Inference data format is the real string format consumed by tokenizers and then LLMs. It is also the string format to which the training data is converted before tokenization. +The default inference data format contains strings concatenated by conversation data(system, human and bot contents) in the training data format. +It is used as the data "seen"(before tokenization) by the model in training process. +It is used as input during the inference process as well. +Here is an example format of the inference string: + +``` +""" +system +System instruction +human +User 1st round input +bot +Assistant 1st round output{EOS_TOKEN} +human +User 2nd round input +bot +Assistant 2nd round output{EOS_TOKEN} +... +... +... +human +User nth round input +bot +{Assistant output to be genreated}{EOS_TOKEN} +""" +``` +When applying inference, you always make your input string end with ```bot\n``` to request the model generating answers. + +### 2.3 DPO训练数据格式 +The training data is required to be a uniformed JSONL format, in which each line of data has the following JSON format. The "chosen" and "rejected" fields are required as ```chosen``` and ```rejected``` in DPO training and both includes "chatml-style" contents(only last content of bot differs). +```json +{ + "chosen":[ + { + "role": "system", + "content": "You are a expert in coding and help answer code questions" + }, + { + "role": "human", + "content": "Write a python function of quick sort" + }, + { + "role": "bot", + "content": "Below is the function of quick sort: ..." + }, + { + "role": "human", + "content": "Explain the code" + }, + { + "role": "bot", + "content": "OK, this code ..." + } + ], + "rejected":[ + { + "role": "system", + "content": "You are a expert in coding and help answer code questions" + }, + { + "role": "human", + "content": "Write a python function of quick sort" + }, + { + "role": "bot", + "content": "Below is the function of quick sort: ..." + }, + { + "role": "human", + "content": "Explain the code" + }, + { + "role": "bot", + "content": "Sorry, I can not answer..." + } + ] +} +``` + + +## 3. Model Training +Currently, the "MFTCoder-accelerate" codebase supports Full-parameters/LoRA/QLoR along with Multi-Task FineTuning(MFT). +In theory, this project can be used to train any publicly available model in the HuggingFace Format. + +Here are some excellent pre-trained models weights available on Huggingface that can be finetuned with this codebase: + +🤗 [Latest code pre-trained SOTA, CodeLlama-34b-Python](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) : code-llama-34b, code-llama-34b-python, a new SOTA base model. 
+
+🤗 [Best 10B level pre-trained Code LLM, Starcoder](https://huggingface.co/bigcode/starcoder): WizardCoder-15B, PanGu-Coder2, and other previous SOTA models were trained on it.
+
+🤗 [Multilingual powerhouse, Qwen-7b](https://huggingface.co/Qwen/Qwen-7B): Suitable for multilingual tasks, including Chinese tasks, for instruction fine-tuning.
+
+**mftcoder_accelerate directory structure**
+```
+mftcoder_accelerate
+    |
+    src
+        configs
+        |
+        data
+        |
+        model
+        |
+        *pefts*
+        |
+        *xxpo*
+        |
+        *mpt*
+        |
+        *offline_tokenization*
+        |
+        tokenizer
+        |
+        utils
+        |
+        evals
+```
+
+We extracted the various components used in training into separate modules for easier extension and optimization. You can find the implementations in the ```mftcoder_accelerate/src``` directory.
+
+The entry file for MFT training is ```mftcoder_accelerate/src/pefts/mft_accelerate.py```.
+
+The entry file for DPO/ORPO training is ```mftcoder_accelerate/src/xxpo/xxpo_accelerate.py```.
+
+The entry file for MPT (Continue Training) is ```mftcoder_accelerate/src/mpt/mpt_accelerate.py```. You need to finish offline tokenization of your data via ```mftcoder_accelerate/src/run_offline_tokenization.sh```, which is different from the online tokenization used in MFT/DPO.
+
+Configurations are stored in the ```mftcoder_accelerate/src/configs``` directory for easy management and modification.
+
+**_Therefore, before you start training, change your working directory first:_**
+```
+cd mftcoder_accelerate/src
+```
+
+### 3.1 MFT Tokenization
+During training, we concatenate multi-turn dialogues into the following format (also known as the inference data format mentioned before) and then tokenize it.
+
+In the default format, ```human\n``` starts the user's input (i.e., prompt), and ```bot\n``` starts the assistant's output (i.e., response).
+
+```{EOS_TOKEN}``` represents the proper eos_token.
+We have different eos_tokens in ```src/pefts/model_mapping.py``` which fit different base models.
+
+Here is a readable example of the training data after formatting:
+```
+f"human\n{input1}bot\n{target1}{EOS_TOKEN}\nhuman\n{input2}bot\n{target2}{EOS_TOKEN}\n"
+```
+During the calculation of loss, we use a ```loss mask``` to ensure that the loss from the input part does not contribute to parameter updates. Only the loss from the ```target{EOS_TOKEN}``` part is used for updating parameters.
+This approach takes full advantage of parallel computation: thanks to the left-to-right attention of decoder-only models, the target parts of all turns can be included in a single training iteration, which makes training more efficient.
+
+
+### 3.2 LoRA/QLoRA
+
+#### Intro
+You can refer to the LoRA paper for details about LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2106.09685.pdf)
+
+You can refer to the QLoRA paper for details about QLoRA: [QLORA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/pdf/2305.14314.pdf)
+
+QLoRA (Quantized LoRA) is a method that combines 4-bit nf4 quantization and additional adapters to achieve a balance between reducing GPU memory consumption and approaching the performance of full-parameter fine-tuning.
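+
+To make the QLoRA mode more concrete, here is a minimal, self-contained sketch of the kind of setup it corresponds to, using transformers, peft, and bitsandbytes. It is not the framework's exact code: the model name, LoRA hyperparameters, and target modules below are illustrative placeholders, which in MFTCoder-accelerate come from your *_train_config.json (e.g., lora_rank, lora_alpha, lora_dropout, target_modules, quantization).
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
+
+base_model = "codellama/CodeLlama-34b-Python-hf"    # placeholder: any HF-format base model
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",                      # nf4 quantization as in the QLoRA paper
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+
+model = AutoModelForCausalLM.from_pretrained(
+    base_model,
+    quantization_config=bnb_config,
+    torch_dtype=torch.bfloat16,
+    trust_remote_code=True,
+)
+model = prepare_model_for_kbit_training(model)      # cast norms, enable input grads, etc.
+
+lora_config = LoraConfig(
+    task_type=TaskType.CAUSAL_LM,
+    r=32,                                           # example values; set them in *_train_config.json
+    lora_alpha=32,
+    lora_dropout=0.05,
+    target_modules=["q_proj", "v_proj"],            # illustrative subset; defaults depend on model_type
+)
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()                  # only the low-rank adapter weights are trainable
+```
+
+In the framework itself you do not write this code; you simply set ```peft_type``` to "qlora" and ```quantization``` to "4bit" in the training config and launch with one of the commands below.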
+
+According to the QLoRA paper, this method enables fine-tuning of a 33B model on a single V100 GPU while achieving performance close to that of full-parameter fine-tuning.
+
+To perform LoRA/QLoRA/full-parameter fine-tuning, execute one of the following commands:
+
+#### Launch via DeepSpeed
+With the DeepSpeed config in accelerate_ds_config.yaml:
+```bash
+accelerate launch --config_file accelerate_ds_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "DeepSpeed"
+```
+or with the DeepSpeed ZeRO2 config passed via command-line arguments:
+```bash
+sh ds_single_launch.sh
+```
+or with the DeepSpeed ZeRO3 config passed via command-line arguments:
+```bash
+sh ds_zero3_single_launch.sh
+```
+
+#### Launch via FSDP
+With the FSDP config in accelerate_fsdp_config.yaml:
+```bash
+accelerate launch --config_file accelerate_fsdp_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "FSDP"
+```
+or with the FSDP config passed via command-line arguments:
+```bash
+sh fsdp_single_launch.sh
+```
+
+#### MultiNode Launch
+Refer to the DeepSpeed multi-node launch script below.
+```bash
+sh ds_multinode_launch.sh
+```
+
+#### Training Arguments
+All arguments allowed in ***_train_config.json are defined in ```arguments.py```.
+
+Frequently used arguments are provided in ```configs/***_train_config``` and explained as follows. You can modify these parameters according to your needs:
+
+- **load_raw_dataset**: Must be true at present. Only JSONL format is supported.
+
+- **data_paths**: Input data paths in a string-of-list format, e.g., "[path1,path2,path3]". Each path represents a task directory, and each task directory contains one or more JSONL data files.
+
+- **output_dir**: Training output directory to store checkpoints, LoRA adapters, etc.
+
+- **tb_dir**: TensorBoard directory to store logs, metrics, etc.
+
+- **model_type**: Type of the model to train, e.g., "mixtral | llama | starcoder | chatglm2 | qwen | gpt_neox".
+
+- **attn_implementation**: "flash_attention_2", "eager" or "sdpa"; only takes effect when the model is officially supported by transformers.
+
+- **peft_type**: null, "lora" or "qlora"; null means full-parameter training.
+
+- **lora_rank**: Rank value for LoRA.
+
+- **lora_alpha**: Alpha value for LoRA.
+
+- **lora_dropout**: Dropout rate for LoRA.
+
+- **target_modules**: List of LoRA target modules; default values are used if null.
+
+- **quantization**: "4bit" for QLoRA; null for LoRA and full-parameter training.
+
+- **pretrained_model_path**: Local/shared disk path or HuggingFace model name of the pre-trained model.
+
+- **weighted_loss_mode**: Loss weighting method for multitask training. "case3" is recommended at present; "self-paced" is supported but needs hyperparameter tuning.
+
+- **padding_mode**: How tokenized samples are organized. "padding" pads each sample to seq_length; "pack" packs as many samples as possible into each seq_length-long sequence.
+
+- **num_train_epochs**: Number of training epochs.
+
+- **per_device_train_batch_size**: Batch size per GPU for training.
+
+- **per_device_eval_batch_size**: Batch size per GPU for evaluation.
+
+- **gradient_accumulation_steps**: Number of gradient accumulation steps. The global batch size is num_gpus * per_device_train_batch_size * gradient_accumulation_steps.
+
+- **learning_rate**: Initial learning rate. For full-parameter fine-tuning, a smaller value such as 1e-5 or 5e-6 is recommended. For QLoRA, a larger learning rate is generally used, such as 1e-4 or 2e-4.
+
+- **min_lr**: Minimum learning rate.
Usually set to one-tenth of the learning rate.
+
+- **seq_length**: Maximum input sequence length during training.
+
+- **log_interval**: Log the training loss every ```log_interval``` steps.
+
+- **checkpointing_steps**: Save a checkpoint every ```checkpointing_steps``` steps.
+
+- **evaluation_steps**: Evaluate on the validation set every ```evaluation_steps``` steps.
+
+- **early_stopping**: Whether to enable early stopping.
+
+- **early_stopping_stall_num**: Number of evaluation points without improvement that triggers early stopping.
+
+- **lr_scheduler_type**: Type of learning rate scheduler. "cosine" is usually a good choice.
+
+- **num_warmup_steps**: Number of warm-up steps over which the learning rate gradually increases.
+
+- **seed**: Random seed for reproducibility.
+
+- **saving_limit**: Maximum number of checkpoints to keep; must be set for full-parameter training.
+
+- **role_markers**: {"system": "\system\n", "user": "\human\n", "assistant": "\bot\n"} is used by default (when null). You can set your preferred role_markers as the templates starting the "system", "user" and "assistant" turns, e.g. {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"}
+
+#### CoBa Arguments Configuration
+- **coba_warmup_steps**: The number of warm-up steps for CoBa. During the warm-up period, all task weights are equal; after the warm-up, weights are adjusted dynamically. It is generally recommended to set this close to the total number of validation batches.
+- **coba_history_length**: The length of the validation-loss history window maintained by CoBa, used to fit the convergence slope at the current step. It is generally recommended to set this between 2 and 5 times **coba_warmup_steps**. Typically, the larger this value, the smaller the changes in weights.
+- **coba_tau**: The temperature coefficient for the Divergence Factor (DF). It is generally set to 5.
+- **coba_update_interval**: The frequency at which CoBa updates weights. It is commonly set to 1, meaning weights are updated at every step.
+- **coba_sample_valid_num**: The number of validation batches sampled by CoBa at each step. Theoretically, when this value equals the total number of validation batches, the fitted convergence slope most closely approximates the actual situation. However, considering the computational cost, it is recommended to set it to 1.
+
+#### DPO Arguments Configuration
+- **xxpo**: Preference optimization type, "dpo" or "orpo".
+- **beta**: DPO beta; a smaller beta allows a larger distance between the DPO model and the reference model.
+- **rpo_alpha**: The coefficient of the ```chosen``` NLL loss added to the DPO loss; 0 means the original DPO.
+
+## 4. Model Usage
+
+### 4.1 Merge Adapter Weights
+When training with LoRA or QLoRA, this project saves only the adapter weights and configuration files.
+To merge the adapter weights with the base model:
+```
+python pefts/merge_base_and_lora_to_hf.py \
+    --base_model_or_path model_path \
+    --adaptor_path lora_adapter_path \
+    --model_type model_type \
+    --merged_output_path output_path
+```
+
+### 4.2 Inference demo
+Here is the script for inference on models trained by MFTCoder since v0.3.0, which is compatible with most HuggingFace models:
+```python
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+)
+model_name_or_path = "codefuse-ai/CodeFuse-Deepseek-33B"
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side="left")
+tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<|end▁of▁sentence|>")
+tokenizer.pad_token_id = tokenizer.eos_token_id
+# device_map="auto" places the model on the available GPU(s), matching the inputs moved to "cuda" below
+model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, device_map="auto")
+
+HUMAN_ROLE_START_TAG = "human\n"
+BOT_ROLE_START_TAG = "bot\n"
+texts = ["write a python function of quick sort."]
+texts = [f"{HUMAN_ROLE_START_TAG}{text}{BOT_ROLE_START_TAG}" for text in texts]
+
+inputs = tokenizer(texts, return_tensors='pt', padding=True, add_special_tokens=False).to("cuda")
+outputs = model.generate(
+    inputs=inputs["input_ids"],
+    attention_mask=inputs["attention_mask"],
+    max_new_tokens=512,
+    top_p=0.95,
+    temperature=0.1,
+    do_sample=True,
+    eos_token_id=tokenizer.eos_token_id,
+    pad_token_id=tokenizer.pad_token_id
+)
+gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+print(gen_text)
+```
+
+
+The parameters top_p, temperature, repetition_penalty, do_sample, etc. have a significant impact on the model's generation output; you can adjust them for your specific use case.
+
+In code generation scenarios, if you use sampling mode (do_sample=True), the following parameter settings usually yield good results on the Pass@1 metric:
+
+top_p: Set a higher value, such as 0.95, to keep only highly probable tokens. This helps produce more accurate and fluent generations.
+
+temperature: Set a lower value, such as 0.1, to reduce randomness. Lower temperature values make the output more deterministic.
+
+These parameter combinations control the diversity of the generated outputs while maintaining naturalness. You can additionally adjust related parameters, such as repetition_penalty, to reduce repetition in the generated results.
+
+If you choose the non-sampling mode (do_sample=False), consider the following setting:
+
+num_beams: Set a small value such as 1 or 3. ```num_beams=1``` is greedy decoding, which always selects the single most probable token; ```num_beams=3``` enables beam search, which keeps multiple candidate generation paths and returns the best one.
+
+## 5. FAQ
+#### Q1: What should I do when CUDA OOM happens?
+If OOM happens, you can reduce parameters such as per_device_train_batch_size and seq_length. Since you are dealing with large models (6B, 13B, 34B, 70B, etc.), gradient checkpointing is enabled by default, which significantly reduces GPU memory consumption but may slightly slow down training.
+
+QLoRA + DeepSpeed Zero3 is recommended for larger models to avoid OOM.
+
+#### Q2: How do I install the required packages?
+Please refer to init_env.sh and requirements.txt.
+We highly recommend installing Flash Attention 2 (flash_attn>=2.1.0; we use 2.3.6) first to get memory-efficient and fast training.
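+
+As a quick sanity check before launching training (a minimal sketch, not one of the project scripts), you can verify that flash_attn is installed and a CUDA device is visible:
+
+```python
+# Check the installed flash_attn version and CUDA availability (illustration only).
+import importlib.metadata as metadata
+
+import torch
+
+try:
+    print("flash_attn version:", metadata.version("flash_attn"))
+except metadata.PackageNotFoundError:
+    print("flash_attn is not installed; install it or set attn_implementation to 'eager'/'sdpa'")
+print("CUDA available:", torch.cuda.is_available())
+```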
+ +#### Q3:How should I specify the GPUs for training? +You can specify the visiable GPUs as below: +```bash +CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file accelerate_ds_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json +``` + +#### Q4:Whats is a recommended Distributed Training? +For LoRA, we recommend DeepSpeed ZeRO2 as the underlying framework, because it is easy and stable to use, moreover it is more compatable for different settings. + +For QLoRA, DeepSpeed ZeRO2 and DeepSpeed ZeRO3 are both good, moreover DeepSpeed ZeRO3 is a good choice for very large models. + +For Full-parameter finetuning, DeepSpeed ZeRO3 and FSDP are faster, and may help you with very large models by sharding parameters and gradients. \ No newline at end of file diff --git a/mftcoder_accelerate/README_cn.md b/mftcoder_accelerate/README_cn.md new file mode 100644 index 0000000..39631c5 --- /dev/null +++ b/mftcoder_accelerate/README_cn.md @@ -0,0 +1,385 @@ +# MFTCoder: Accelerate + DeepSpeed/FSDP 框架篇 +[![Generic badge](https://img.shields.io/badge/🤗-Huggingface%20Repo-green.svg)](https://huggingface.co/codefuse-ai) + + GitHub + + +[**中文**] [[English]](README.md) + +## 1. 更新 +🔥 MFTCoder-accelerate 增加了xxpo模块,支持dpo训练。 + +🔥 MFTCoder-accelerate 增加了mpt模块,借助offline_tokenization模块,支持全量参数加训。 + +🔥 MFTCoder-accelerate 增加了CoBa Loss的最新实现(原selfpaced Loss), 让收敛均衡更进一步。 + +🔥 MFTCoder-accelerate 最新支持的训练模式包括: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, 全量 + DeepSpeed ZeRO3, QLoRA + FSDP, 全量 + FSDP。 + +🔥 MFTCoder-accelerate 新增支持QLoRA + DeepSpeed ZeRO3, 支持QLoRA + FSDP, 可以训练更大的模型。 + +🔥 MFTCoder-accelerate 新增支持accelerate + FSDP框架, 支持全量微调和LoRA。 + +🔥 MFTCoder-accelerate 支持最新更多主流开源模型: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3。 + +🔥 MFTCoder-accelerate 新增self-paced Loss, 用于收敛均衡。 + +🔥 MFTCoder-accelerate 支持使用accelerate + DeepSpeed框架下支持 全量参数/QLoRA/LoRA微调。 + +🔥 MFTCoder-accelerate 在训练中支持了多任务微调MFT, 可以同时平衡多个任务的训练,训练的模型支持多任务推理。 + +🔥 MFTCoder-accelerate 在训练中支持多种模型基座: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen等。 + +## 2. 数据格式 +### 2.1 MFT训练数据格式 +训练数据为jsonl格式,每一行的数据格式如下,其中chat_rounds字段是必需的,可以根据实际需求添加或删除其他字段。 +可以参考项目中的xxx.jsonl文件。 +```json +{ + "id":0, + "data_name":"code-helper", + "chat_rounds":[ + { + "role": "system", + "content": "你是一个智能代码助手,可以回复用户与代码相关的问题" + }, + { + "role": "human", + "content": "写一个快速排序" + }, + { + "role": "bot", + "content": "以下是一个快速排序算法xxxxxx" + }, + { + "role": "human", + "content": "解释一下这段代码" + }, + { + "role": "bot", + "content": "好的,这段代码xxx" + } + ] +} +``` + +### 2.2 推理数据格式 +推理数据格式为模型在训练数据格式下拼接的字符串形式,它也是推理时输入prompt拼接的方式: +``` +""" +system +这是System指令 +human +这是第1轮用户输入的问题 +bot +这是第1轮模型生成的内容{EOS_TOKEN} +human +这是第2轮用户输入的问题 +bot +这是第2轮模型生成的内容{EOS_TOKEN} +... +... +... 
+human +这是第n轮用户输入的问题 +bot +{模型现在要生成的内容}{EOS_TOKEN} +""" +``` + +### 2.3 DPO训练数据格式 +训练数据为jsonl格式,每一行的数据格式如下,其中chosen字段和rejected字段分别代表偏好对齐中的```chosen```和```rejected```,其内部依然是MFT的chatml格式,并且只有最后一轮对话的bot content不同。 +```json +{ + "chosen":[ + { + "role": "system", + "content": "你是一个智能代码助手,可以回复用户与代码相关的问题" + }, + { + "role": "human", + "content": "写一个快速排序" + }, + { + "role": "bot", + "content": "以下是一个快速排序算法xxxxxx" + }, + { + "role": "human", + "content": "解释一下这段代码" + }, + { + "role": "bot", + "content": "好的,这段代码xxx" + } + ], + "rejected":[ + { + "role": "system", + "content": "你是一个智能代码助手,可以回复用户与代码相关的问题" + }, + { + "role": "human", + "content": "写一个快速排序" + }, + { + "role": "bot", + "content": "以下是一个快速排序算法xxxxxx" + }, + { + "role": "human", + "content": "解释一下这段代码" + }, + { + "role": "bot", + "content": "对不起,我不会" + } + ] +} +``` + + +## 3. 模型训练 +目前支持全量参数(Full-parameters)指令微调、QLoRA指令微调,LoRA指令微调。 +一些优秀的代码预训练模型权重,理论上,HuggingFace上开源的模型,均可使用本项目进行训练: + +🤗 [最新代码预训练SOTA,CodeLlama](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) :code-llama-34b, code-llama-34b-python, 新的SOTA基座。 + +🤗 [10B级别最佳代码预训练模型Starcoder](https://huggingface.co/bigcode/starcoder) wizardCoder-15B, PanGu-coder2等前SOTA的基座模型。 + +🤗 [多语言能手Qwen-7b](https://huggingface.co/Qwen/Qwen-7B) :适用于多语言任务,也适用中文任务。进行指令微调时。 + +**mftcoder_accelerate文件结构** +``` +mftcoder_accelerate + | + src + configs + | + data + | + model + | + *pefts* + | + *xxpo* + | + *mpt* + | + *offline_tokenization* + | + tokenizer + | + utils + | + evals +``` +我们将训练中使用的各种组件抽取出来,以便后续的扩展和优化, 详见```src```目录下的实现。 + +MFT训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` + +DPO/ORPO训练入口文件是```mftcoder_accelerate/src/xxpo/xxpo_accelerate.py``` + +MPT(全量加训)训练入口文件是```mftcoder_accelerate/src/mpt/mpt_accelerate.py```. MPT加训需要提前做好数据的tokenziation,通过```mftcoder_accelerate/src/run_offline_tokenization.sh```,你可以将数据通过cpu进行离线的tokenization。这和MFT/DPO中使用的在线tokenziation不同。 + +参数配置存储在```mftcoder_accelerate/src/configs```目录下,方便统一管理和更改。 + +**_所以,在你开启训练之前,请进入src目录_** +``` +cd mftcoder_accelerate/src +``` + + + +### 3.1 数据tokenization +MFT/DPO训练时,我们将多轮对话拼接成如下格式(也是上文中的推理数据格式),然后进行tokenize。 +其中,默认情况下: + +```human\n```作为human/user的起始符,```bot\n```作为bot/assistant的起始符,```{EOS_TOKEN}``` 表示eos_token。 +其中eos_token可以根据不同模型修改替换。不同角色的起始符可以配置,用来实现不同的对话/问答模版。 +``` +"human\n{input1}bot\n{target1}{EOS_TOKEN}human\n{input2}bot\n{target2}{EOS_TOKEN}\n" +``` +在计算loss时,我们通过loss mask的方式,input部分的loss不参与参数更新,只有“target{EOS_TOKEN}”部分的loss参与参数更新。 +这种方式充分利用了模型并行计算的优势,训练更加高效,同时也充分利用了decoder-only模型从左到右attention的特性,一次性将多轮对话中的每个target部分都参与了训练,训练更充分高效。 + +### 3.2 LoRA/QLoRA微调 + +#### LoRA/QLoRA微调简介 +关于LoRA的详细介绍可参考论文:[LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2106.09685.pdf) + +关于QLoRA的详细介绍可参考论文:[QLORA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/pdf/2305.14314.pdf) + +QLoRA通过4-bit的nf4量化,且加入更多adapter,在大幅减少显存消耗的同时,尽可能逼近全量参数微调的效果。 +QLoRA论文指出,该方法可以在一张V100上对33B的模型进行微调,并且性能逼近全量参数微调。 + +执行如下命令即可进行 Lora/QLora/全量 微调: +#### Deepspeed 单机启动 +DeepSpeed配置在accelerate_ds_config.yaml中。 +```bash +accelerate launch --config_file accelerate_ds_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "DeepSpeed" +``` +或者 + +DeepSpeed Zero2 配置在脚本中通过命令行输入。 +```bash +sh ds_single_launch.sh +``` + +DeepSpeed Zero3 配置在脚本中通过命令行输入 +```bash +sh ds_zero3_single_launch.sh +``` + +#### FSDP 单机启动 +FSDP配置在accelerate_fsdp_config.yaml中。 +```bash +accelerate launch --config_file accelerate_fsdp_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json 
--distributed_type "FSDP" +``` +或者 + +FSDP配置在脚本中通过命令行输入。 +```bash +sh fsdp_single_launch.sh +``` + +#### 多机启动 +多机启动请参考如下deepspeed多机启动脚本 +```bash +sh ds_multinode_launch.sh +``` + +#### 训练参数 +_**训练需要的参数配置在```configs/*_train_config```中,主要参数说明如下:**_ + +- **load_raw_dataset**: 需要保持true,后续会支持其它模式数据,当前仅支持jsonl输入 +- **data_paths**: "[path1,path2,path3]" 输入数据地址,字符串,开头结尾用[],中间用```,```间隔不同path,每个path是一个目录,目录的最后一级名字作为任务名称,下面包含1到多个jsonl数据 +- **output_dir**:训练输出目录,存储checkpoint(全量训练时)、lora_adaptor(Lora或者Qlora时)等 +- **tb_dir**: 存储tensorboard等 +- **model_type**: "mixtral|mistral|deepseek|llama|starcoder|chatglm2|qwen|gpt_neox" +- **attn_implementation**: "flash_attention_2" 或者 "eager" +- **peft_type**: lora或者qlora或者null(全量微调) +- **lora_rank**: lora rank +- **lora_alpha**: lora alpha +- **lora_dropout**: lora dropout +- **target_modules**: List[str], lora目标模块,如果null,会使用默认,参考model_mapping.py +- **quantization**: 是否量化,"4bit", "8bit" 或者null, qlora推荐4bit量化 +- **pretrained_model_path**:预训练模型的本地目录,或者在huggingface上的模型名称。 +- **weighted_loss_mode**: 多任务loss加权模式, "case3"是当前推荐。 +- **padding_mode**: 数据的样本组织方式, "padding"是将每个原始样本填充到seq_length, "pack"是将尽量多的样本打包到每个seq_length的序列中。 +- **num_train_epochs**:训练的轮次。如果数据量足够大,一般建议只训1-2个epoch。 +- **per_device_train_batch_size**:每张显卡train的batch size。 +- **per_device_eval_batch_size**:每张显卡eval的batch size。 +- **gradient_accumulation_steps**:梯度累计步数。global batch=num_gpus * per_device_train_batch_size * gradient_accumulation_steps。 +- **learning_rate**:学习率。全量参数微调的时候,建议小一些,1e-5或5e-6。qlora中的学习率设置更大一些,一般为1e-4、2e-4。 +- **min_lr**: 最低学习率, 一般是learning_rate的十分之一 +- **seq_length**:训练时的最大长度。按照自己的设备进行设置,越长需要占用越多显存。 +- **log_interval**:每隔多少步统计一次train loss。 +- **checkpointing_steps**:每隔多少步保存一个模型。 +- **evaluation_steps**:每隔多少步在验证集上evaluate一次。 +- **early_stopping** : 是否执行early_stop +- **early_stopping_stall_num**: 多少个eval point不继续收敛,则停止训练 +- **lr_scheduler_type**:学习率变化策略。常用"cosine" +- **warmup_steps**:warm up步数。学习率经过多少步,增长到指定的数值。 +- **seed**:随机种子,用于复现实验结果。 +- **saving_limit**:整数,ckpt存储数量上限, 全量训练必须设置。默认null即不限制数量。 +- **role_markers**: null,即使用{"system": "\system\n", "user": "\human\n", "assistant": "\bot\n"}。 你可以自定义 "system", "user" and "assistant"的模板, 用于定制自己的问答或者对话模板,比如 {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} + +#### CoBa相关参数配置 +- **coba_warmup_steps**: CoBa的warm-up步数。在warm-up期间,各任务权重相等,warm-up之后,开始动态调整权重。一般建议设置为与valid batch总数量相近即可。 +- **coba_history_length**: CoBa维护的valid loss的历史窗口长度,用于拟合当前步收敛斜率。一般建议设置为2倍**coba_warmup_steps**至5倍**coba_warmup_steps**之间。一般该值越大,权重的变化幅度就会越小。 +- **coba_tau**: 发散因子(DF)的温度系数。一般设置为5即可。 +- **coba_update_interval**: CoBa更新权重的频率。一般设置为1,即每一步都对权重做更新。 +- **coba_sample_valid_num**: CoBa每一步要取的valid batch数。理论上当该值等于valid batch总数量时,拟合出的收敛斜率最逼近真实情况,但考虑到计算需求,建议设置为1。 + +#### DPO 相关参数配置 +- **xxpo**: 偏好对齐方法, "dpo" 或者 "orpo"。 +- **beta**: DPO beta, beta 越小,允许对齐后的dpo模型与ref模型的距离越远。 +- **rpo_alpha**: 加到dop损失的```chosen``` NLL损失的系数,0的话就是原始DPO。 +- +## 4. 
模型使用 + +### 4.1 权重合并 +如果使用LoRA或者QLoRA进行训练,本项目仅保存adapter的权重和配置文件,需要将adapter权重与base model进行合并。 +可以使用如下merge_base_and_lora_to_hf.py脚本。 +``` +python pefts/merge_base_and_lora_to_hf.py \ + --base_model_or_path model_path \ + --adaptor_path lora_adapter_path \ + --model_type model_type \ + --merged_output_path output_path +``` + +### 4.2 模型推理 +我们提供了单轮对话和多轮对话的如下脚本,该脚本可同时兼容大部分huggingface格式的模型。 +```python +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, +) +model_name_or_path = "codefuse-ai/CodeFuse-Deepseek-33B" +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side="left") +tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<|end▁of▁sentence|>") +tokenizer.pad_token_id = tokenizer.eos_token_id +model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True) + +HUMAN_ROLE_START_TAG = "human\n" +BOT_ROLE_START_TAG = "bot\n" +texts = ["write a python function of quick sort."] +texts = [f"{HUMAN_ROLE_START_TAG}{text}{BOT_ROLE_START_TAG}" for text in texts] + +inputs = tokenizer(texts, return_tensors='pt', padding=True, add_special_tokens=False).to("cuda") +outputs = model.generate( + inputs=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=512, + top_p=0.95, + temperature=0.1, + do_sample=True, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + ) +gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True) +print(gen_text) +``` + + +生成脚本中的top_p、temperature、repetition_penalty、do_sample等参数对模型的生成效果影响较大,可按照自己的使用场景进行调试修改。 +实践中,在代码生成场景中,如果采样模式,do_sample=True, top_p=0.95, temperature=0.1是pass@1指标的不错选择; +如果非采样模式, do_sample=False, beam_num=1或者3是不错的选择,其中beam_num=1即为greedy decoding。 + +## 5. FAQ +#### 问题1:OOM如何解决? +如果发生OOM,可以缩小per_device_train_batch_size、seq_length等参数来缓解。由于面对的模型普遍较大(6b, 13b, 34b, 70b等)我们已经默认使用gradient_checkpointing技术,可以大幅降低显存占用,但训练速度会稍慢一些。 +如果是模型太大,可以使用QLoRA + DeepSpeed ZeRO3(配置 zero stage = 3),这个方案可以在卡数足够的情况下,微调更大的模型。 +#### 问题2:安装包错误 +参考init_env.sh和requirements.txt + +#### 问题3:如何指定使用某些卡训练? +通过如下方式,即可指定使用0和1号卡进行训练: +```bash +CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file pefts/accelerate_ds_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "deepspeed" +``` + +#### 问题4:关于Flash Attention, 该如何配置训练? +首先,我们强烈建议您安装Flash Attention 2(FA2),(>=2.1.0, 2.3.6功能更齐全)。 + +训练参数中"attn_implementation" 设置成 "eager" 可以用naive attention,也就是未经加速的attention。 + +训练参数中"attn_implementation" 设置成 "flash_attention_2" 可以用FA2,速度快,省显存。 + +如果你可以自行安装环境并使用torch>=2.1.1,可以尝试设置参数"attn_implementation"为 "sdpa"。这样会尝试使用transformers兼容的torch.nn.functional.scaled_dot_product_attention。支持的模型还不全面。 + +#### 问题5:推荐的分布式框架是怎样的? +对于LoRA, 我们推荐使用DeepSpeed Zero2作为底层分布式框架,它具有易用性和兼容性好的特点,并且速度很快, 模型加载模式上类似DDP。 +对于QLoRA, DeepSpeed Zero2 适合中小模型, DeepSpeed Zero3 适合很大的模型。 + +对于全量微调,可以使用DeepSpeed Zero3, 或者FSDP。二者都是Fully Sharding模式,即模型加载平分在每张卡。 + +#### 问题6:当前支持的模型中,有什么区别 +国产大模型比如chatglm2, chatglm3, baichuan2, qwen, aquila2等,使用的是和模型共同发布的modeling_xxx.py. 
+其它被transformers官方支持的大模型,比如llama, qwen2, starcoder2, mistral等,全面切换到官方的modeling支持训练,之前的自定义modeling会被deprecated。 + + + + + diff --git a/mftcoder_accelerate/inference/hf_inference.py b/mftcoder_accelerate/inference/hf_inference.py new file mode 100644 index 0000000..67f9ba0 --- /dev/null +++ b/mftcoder_accelerate/inference/hf_inference.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# @author Chaoyu Chen +# @date 2024/1/4 +# @module hf_inference.py +""" +# @author qumu +# @date 2023/9/19 +# @module hf_inference.py +""" +import os +import sys +import torch +import textwrap +from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList +from peft import PeftModel + + +def load_model_tokenizer( + path, + model_type=None, + peft_path=None, + torch_dtype=torch.bfloat16, + quantization=None, + eos_token=None, + pad_token=None, + batch_size=1, +): + """ + load model and tokenizer by transfromers + """ + + # load tokenizer first + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + tokenizer.padding_side = "left" + + config, unused_kwargs = AutoConfig.from_pretrained(path, trust_remote_code=True, return_unused_kwargs=True) + print("unused_kwargs:", unused_kwargs) + print("config input:\n", config) + + # eos token parsing + if eos_token: + eos_token = eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) + print(f"eos_token {eos_token} from user input") + elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id: + print(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer") + eos_token_id = tokenizer.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(eos_token_id) + elif hasattr(tokenizer, "eos_token") and tokenizer.eos_token: + print(f"Initial eos_token {tokenizer.eos_token} from tokenizer") + eos_token = tokenizer.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) + elif hasattr(config, "eos_token_id") and config.eos_token_id: + print(f"Initial eos_token_id {config.eos_token_id} from config.json") + eos_token_id = config.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(config.eos_token_id) + elif hasattr(config, "eos_token") and config.eos_token: + print(f"Initial eos_token {config.eos_token} from config.json") + eos_token = config.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(config.eos_token) + else: + raise ValueError( + "No available eos_token or eos_token_id, please provide eos_token by params or eos_token_id by config.json" + ) + + try: + tokenizer.eos_token = eos_token + tokenizer.eos_token_id = eos_token_id + # set pad_token to be same as eos_token, it is ok because is will be masked out. 
+ tokenizer.pad_token = eos_token + tokenizer.pad_token_id = eos_token_id + except: + print(f"[WARNING]Cannot set tokenizer.eos_token") + + print(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}") + print(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}") + print(type(tokenizer)) + + base_model = AutoModelForCausalLM.from_pretrained( + path, + config=config, + load_in_8bit=(quantization == "8bit"), + load_in_4bit=(quantization == "4bit"), + device_map="auto", + torch_dtype=torch_dtype, + trust_remote_code=True, + low_cpu_mem_usage=True, + ) + + if peft_path: + print("Loading PEFT MODEL...") + model = PeftModel.from_pretrained(base_model, peft_path) + else: + print("Loading Original MODEL...") + model = base_model + + model.eval() + + print("=======================================MODEL Configs=====================================") + print(model.config) + print("=========================================================================================") + print("=======================================MODEL Archetecture================================") + print(model) + print("=========================================================================================") + + return model, tokenizer + + +def hf_inference(model, tokenizer, text_list, args=None, max_new_tokens=512, do_sample=True, **kwargs): + """ + transformers models inference by huggingface + """ + # text_list = [tokenizer.apply_chat_template([{"role": "user", "content": text}], tokenize=False) for text in text_list] + inputs = tokenizer(text_list, return_tensors="pt", padding=True, add_special_tokens=False).to("cuda") + # inputs["attention_mask"][0][:100] = 0 + # print(inputs) + print("================================Prompts and Generations=============================") + + outputs = model.generate( + inputs=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=max_new_tokens, + do_sample=do_sample, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + **kwargs, + ) + + gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + for i in range(len(text_list)): + print("=========" * 10) + print(f"Prompt:\n{text_list[i]}") + gen_text[i] = gen_text[i].replace(tokenizer.pad_token, "") + print(f"Generation:\n{gen_text[i]}") + # print(f"Outputs ids:\n{outputs[i]}") + sys.stdout.flush() + + return gen_text + + +if __name__ == "__main__": + # Default template used in MFTCoder training + HUMAN_ROLE_START_TAG = "human\n" + BOT_ROLE_START_TAG = "bot\n" + + instruction = "Write quick sort function in python." 
+ + prompts = [f"{HUMAN_ROLE_START_TAG}{instruction}\n{BOT_ROLE_START_TAG}"] + + # if you use base + adaptor for inference, provide peft_path or left it None for normal inference + base_model = "path/to/basemodel" + peft_path = None + model, tokenizer = load_model_tokenizer( + base_model, model_type="", peft_path=peft_path, eos_token="", pad_token="" + ) + + # hf_inference(model, tokenizer, prompts, do_sample=False, num_beams=1, num_return_sequences=1) + hf_inference(model, tokenizer, prompts, do_sample=True, temperature=0.8) diff --git a/mft_peft_hf/src/pefts/accelerate_ds_config.yaml b/mftcoder_accelerate/src/accelerate_ds_config.yaml similarity index 95% rename from mft_peft_hf/src/pefts/accelerate_ds_config.yaml rename to mftcoder_accelerate/src/accelerate_ds_config.yaml index 882eafb..dc2ece7 100644 --- a/mft_peft_hf/src/pefts/accelerate_ds_config.yaml +++ b/mftcoder_accelerate/src/accelerate_ds_config.yaml @@ -7,6 +7,7 @@ deepspeed_config: zero3_init_flag: false zero3_save_16bit_model: true zero_stage: 2 + # steps_per_print: 1 distributed_type: DEEPSPEED downcast_bf16: 'no' dynamo_backend: 'NO' diff --git a/mftcoder_accelerate/src/accelerate_fsdp_config.yaml b/mftcoder_accelerate/src/accelerate_fsdp_config.yaml new file mode 100644 index 0000000..2846853 --- /dev/null +++ b/mftcoder_accelerate/src/accelerate_fsdp_config.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: FSDP +downcast_bf16: 'no' +dynamo_backend: 'NO' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch_policy: BACKWARD_PRE + fsdp_offload_params: false + fsdp_sharding_strategy: 1 + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: bf16 +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +use_cpu: false \ No newline at end of file diff --git a/mft_peft_hf/src/pefts/configs/gpt_neox_train_config.json b/mftcoder_accelerate/src/configs/coba_train_config.json similarity index 52% rename from mft_peft_hf/src/pefts/configs/gpt_neox_train_config.json rename to mftcoder_accelerate/src/configs/coba_train_config.json index 3ea6a54..63167f1 100644 --- a/mft_peft_hf/src/pefts/configs/gpt_neox_train_config.json +++ b/mftcoder_accelerate/src/configs/coba_train_config.json @@ -1,52 +1,46 @@ { - "load_raw_dataset": true, "data_paths": "$DATA_PATHS", "output_dir": "$OUTPUT_DIR", "tb_dir": "$TensorBoard_DIR", "pretrained_model_path": "$MODEL_NAME_OR_PATH", - "vocab_file": "$MODEL_NAME_OR_PATH", - "low_cpu_mem_usage": true, + "model_type": "$MODEL_TYPE", + "load_raw_dataset": true, "data_split": "95,5,0", "padding_mode": "padding", + "use_dynamic_padding": true, "tokenize_mode": "sft", - "weighted_loss_mode": "case3", - "shuffle_before_split": true, - "use_random_sampler": true, - "early_stopping": true, - "early_stopping_stall_num": 5, - "weight_by_num_documents": true, - "make_vocab_size_divisible_by": 128, - "model_parallel_size": 1, - "model_type": "gpt_neox", - "peft_type": "lora", - "lora_rank": 32, + "tokenizer_type": "AutoTokenizer", + "weighted_loss_mode": "coba", + "coba_warmup_steps": 100, + "coba_history_length": 200, + "coba_tau": 5, + "coba_update_interval": 1, + "coba_sample_valid_num": 1, + "attn_implementation": "flash_attention_2", + "seq_length": 4096, + "seed": 1234, + "peft_type": "qlora", + "quantization": "4bit", + "lora_rank": 96, "lora_alpha": 32, "lora_dropout": 0.05, - 
"quantization": "16bit", - "tokenizer_type": "AutoTokenizer", - "use_slow_tokenizer": false, - "use_xformers": true, - "trust_remote_code": true, - "use_dynamic_padding": true, "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, - "world_size": 128, - "learning_rate": 1e-4, - "min_lr": 1e-5, + "learning_rate": 5e-5, + "min_lr": 5e-6, "weight_decay": 0.1, "gradient_accumulation_steps": 1, "lr_scheduler_type": "cosine", - "num_warmup_steps": 30, - "num_train_epochs": 8, - "seed": 42, - "seq_length": 4096, - "preprocessing_num_workers": 2, - "num_workers": 2, + "num_warmup_steps": 300, + "num_train_epochs": 4, "resume_from_checkpoint": null, "log_interval": 10, - "checkpointing_steps": 500, - "evalation_steps": 500, + "checkpointing_steps": 100, + "evaluation_steps": 100, "max_train_steps": null, "epoch_checkpointing": true, - "checkpoint_activations": true -} \ No newline at end of file + "shuffle_before_split": true, + "early_stopping": true, + "early_stopping_stall_num": 5, + "saving_limit": null + } \ No newline at end of file diff --git a/mftcoder_accelerate/src/configs/dpo_train_config.json b/mftcoder_accelerate/src/configs/dpo_train_config.json new file mode 100644 index 0000000..5a93db9 --- /dev/null +++ b/mftcoder_accelerate/src/configs/dpo_train_config.json @@ -0,0 +1,34 @@ +{ + "xxpo": "dpo", + "data_paths": "$DATA_PATHS", + "output_dir": "$OUTPUT_DIR", + "tb_dir": "$TensorBoard_DIR", + "pretrained_model_path": "$MODEL_NAME_OR_PATH", + "model_type": "$MODEL_TYPE", + "data_split": "99,1", + "attn_implementation": "flash_attention_2", + "beta": 0.1, + "rpo_alpha": 0.5, + "peft_type": "lora", + "lora_rank": 64, + "lora_alpha": 128, + "lora_dropout": 0.0, + "per_device_train_batch_size": 1, + "per_device_eval_batch_size": 1, + "tokenizer_type": "AutoTokenizer", + "dataset_num_proc": 1, + "learning_rate": 5e-7, + "weight_decay": 0.01, + "gradient_accumulation_steps": 8, + "lr_scheduler_type": "cosine", + "warmup_steps": 100, + "num_train_epochs": 2, + "seed": 1105, + "max_prompt_length": 2048, + "max_length": 4096, + "logging_steps": 20, + "save_steps": 500, + "eval_steps": 500, + "epoch_checkpointing": false, + "saving_limit": 5 +} \ No newline at end of file diff --git a/mftcoder_accelerate/src/configs/full_train_config.json b/mftcoder_accelerate/src/configs/full_train_config.json new file mode 100644 index 0000000..a63c8dd --- /dev/null +++ b/mftcoder_accelerate/src/configs/full_train_config.json @@ -0,0 +1,37 @@ +{ + "data_paths": "$DATA_PATHS", + "output_dir": "$OUTPUT_DIR", + "tb_dir": "$TensorBoard_DIR", + "pretrained_model_path": "$MODEL_NAME_OR_PATH", + "model_type": "$MODEL_TYPE", + "load_raw_dataset": true, + "data_split": "98,2,0", + "padding_mode": "padding", + "use_dynamic_padding": true, + "tokenize_mode": "sft", + "tokenizer_type": "AutoTokenizer", + "weighted_loss_mode": "case3", + "attn_implementation": "flash_attention_2", + "seq_length": 4096, + "seed": 1234, + "peft_type": null, + "per_device_train_batch_size": 2, + "per_device_eval_batch_size": 2, + "learning_rate": 5e-5, + "min_lr": 5e-6, + "weight_decay": 0.1, + "gradient_accumulation_steps": 1, + "lr_scheduler_type": "cosine", + "num_warmup_steps": 300, + "num_train_epochs": 4, + "resume_from_checkpoint": null, + "log_interval": 10, + "checkpointing_steps": 100, + "evaluation_steps": 100, + "max_train_steps": null, + "epoch_checkpointing": true, + "shuffle_before_split": true, + "early_stopping": true, + "early_stopping_stall_num": 5, + "saving_limit": 3 +} \ No newline at end of file diff --git 
a/mft_peft_hf/src/pefts/configs/chatglm2_train_config.json b/mftcoder_accelerate/src/configs/lora_train_config.json similarity index 51% rename from mft_peft_hf/src/pefts/configs/chatglm2_train_config.json rename to mftcoder_accelerate/src/configs/lora_train_config.json index f0097cc..2034342 100644 --- a/mft_peft_hf/src/pefts/configs/chatglm2_train_config.json +++ b/mftcoder_accelerate/src/configs/lora_train_config.json @@ -1,49 +1,41 @@ { - "load_raw_dataset": true, "data_paths": "$DATA_PATHS", "output_dir": "$OUTPUT_DIR", "tb_dir": "$TensorBoard_DIR", "pretrained_model_path": "$MODEL_NAME_OR_PATH", - "vocab_file": "$MODEL_NAME_OR_PATH", - "low_cpu_mem_usage": true, - "data_split": "95,5,0", - "padding_mode": "pack", + "model_type": "$MODEL_TYPE", + "load_raw_dataset": true, + "data_split": "98,2,0", + "padding_mode": "padding", + "use_dynamic_padding": true, "tokenize_mode": "sft", + "tokenizer_type": "AutoTokenizer", "weighted_loss_mode": "case3", - "model_type": "chatglm2", + "attn_implementation": "flash_attention_2", + "seq_length": 4096, + "seed": 1234, "peft_type": "lora", - "lora_rank": 32, + "quantization": null, + "lora_rank": 96, "lora_alpha": 32, "lora_dropout": 0.05, - "quantization": null, - "per_device_train_batch_size": 4, - "per_device_eval_batch_size": 4, - "tokenizer_type": "AutoTokenizer", - "learning_rate": 1e-04, - "min_lr": 1e-5, + "per_device_train_batch_size": 2, + "per_device_eval_batch_size": 2, + "learning_rate": 5e-5, + "min_lr": 5e-6, "weight_decay": 0.1, "gradient_accumulation_steps": 1, "lr_scheduler_type": "cosine", - "num_warmup_steps": 1000, - "num_train_epochs": 6, - "seed": 1234, - "seq_length": 4096, + "num_warmup_steps": 300, + "num_train_epochs": 4, "resume_from_checkpoint": null, - "log_interval": 50, - "checkpointing_steps": 1000, - "evalation_steps": 1000, + "log_interval": 10, + "checkpointing_steps": 100, + "evaluation_steps": 100, "max_train_steps": null, "epoch_checkpointing": true, "shuffle_before_split": true, - "use_random_sampler": true, "early_stopping": true, "early_stopping_stall_num": 5, - "weight_by_num_documents": true, - "make_vocab_size_divisible_by": 128, - "model_parallel_size": 1, - "use_slow_tokenizer": false, - "use_xformers": true, - "trust_remote_code": true, - "use_dynamic_padding": true, - "world_size": 128 + "saving_limit": null } \ No newline at end of file diff --git a/mft_peft_hf/src/pefts/configs/llama_train_config.json b/mftcoder_accelerate/src/configs/qlora_train_config.json similarity index 60% rename from mft_peft_hf/src/pefts/configs/llama_train_config.json rename to mftcoder_accelerate/src/configs/qlora_train_config.json index 2c64a43..851b48d 100644 --- a/mft_peft_hf/src/pefts/configs/llama_train_config.json +++ b/mftcoder_accelerate/src/configs/qlora_train_config.json @@ -1,49 +1,41 @@ { - "load_raw_dataset": true, "data_paths": "$DATA_PATHS", "output_dir": "$OUTPUT_DIR", "tb_dir": "$TensorBoard_DIR", "pretrained_model_path": "$MODEL_NAME_OR_PATH", - "vocab_file": "$MODEL_NAME_OR_PATH", - "low_cpu_mem_usage": true, - "data_split": "95,5,0", - "padding_mode": "pack", + "model_type": "$MODEL_TYPE", + "load_raw_dataset": true, + "data_split": "98,2,0", + "padding_mode": "padding", + "use_dynamic_padding": true, "tokenize_mode": "sft", + "tokenizer_type": "AutoTokenizer", "weighted_loss_mode": "case3", - "model_type": "llama", + "attn_implementation": "flash_attention_2", + "seq_length": 4096, + "seed": 1234, "peft_type": "qlora", "quantization": "4bit", - "lora_rank": 32, + "lora_rank": 96, "lora_alpha": 32, 
"lora_dropout": 0.05, "per_device_train_batch_size": 2, "per_device_eval_batch_size": 2, - "tokenizer_type": "AutoTokenizer", - "learning_rate": 1e-04, - "min_lr": 1e-5, + "learning_rate": 5e-5, + "min_lr": 5e-6, "weight_decay": 0.1, "gradient_accumulation_steps": 1, "lr_scheduler_type": "cosine", "num_warmup_steps": 300, - "num_train_epochs": 8, - "seed": 1234, - "seq_length": 4096, + "num_train_epochs": 4, "resume_from_checkpoint": null, "log_interval": 10, - "checkpointing_steps": 1000, - "evalation_steps": 1000, + "checkpointing_steps": 100, + "evaluation_steps": 100, "max_train_steps": null, "epoch_checkpointing": true, "shuffle_before_split": true, - "use_random_sampler": true, "early_stopping": true, "early_stopping_stall_num": 5, - "weight_by_num_documents": true, - "make_vocab_size_divisible_by": 128, - "model_parallel_size": 1, - "use_slow_tokenizer": false, - "use_xformers": true, - "trust_remote_code": true, - "use_dynamic_padding": true, - "world_size": 128 + "saving_limit": null } \ No newline at end of file diff --git a/mft_atorch/data/Makefile b/mftcoder_accelerate/src/data/Makefile similarity index 100% rename from mft_atorch/data/Makefile rename to mftcoder_accelerate/src/data/Makefile diff --git a/mft_atorch/data/__init__.py b/mftcoder_accelerate/src/data/__init__.py similarity index 100% rename from mft_atorch/data/__init__.py rename to mftcoder_accelerate/src/data/__init__.py diff --git a/mftcoder_accelerate/src/data/blendable_dataset.py b/mftcoder_accelerate/src/data/blendable_dataset.py new file mode 100644 index 0000000..84b9756 --- /dev/null +++ b/mftcoder_accelerate/src/data/blendable_dataset.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blendable dataset.""" + +import time + +import numpy as np +import torch + +from utils.common_utils import print_rank_0 + + +class BlendableDataset(torch.utils.data.Dataset): + def __init__(self, datasets, weights): + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = 0 + for dataset in self.datasets: + self.size += len(dataset) + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # recompute weights + weights = self.calc_weights() + + # Build indices. 
+ start_time = time.time() + assert num_datasets < 255 + self.dataset_index = np.zeros(self.size, dtype=np.uint8) + self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from data import helpers + + helpers.build_blending_indices( + self.dataset_index, + self.dataset_sample_index, + weights, + num_datasets, + self.size, + torch.distributed.get_rank() == 0, + ) + + print( + "> RANK {} elapsed time for building blendable dataset indices: " + "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time) + ) + + def calc_weights(self): + dataset_sample_cnt = [len(ds) for ds in self.datasets] + total_cnt = sum(dataset_sample_cnt) + weights = np.array([(cnt + 0.0) / total_cnt for cnt in dataset_sample_cnt], dtype=np.float64) + return weights + + def __len__(self): + return self.size + + def __getitem__(self, idx): + try: + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return self.datasets[dataset_idx][sample_idx] + except IndexError: + new_idx = idx % len(self) + print( + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) + return self[new_idx] diff --git a/mftcoder_accelerate/src/data/data_utils.py b/mftcoder_accelerate/src/data/data_utils.py new file mode 100644 index 0000000..8d168bd --- /dev/null +++ b/mftcoder_accelerate/src/data/data_utils.py @@ -0,0 +1,420 @@ +# Copyright (c) 2021, EleutherAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import math +import torch +import numpy as np +from typing import List, Tuple +from functools import partial + +from utils.common_utils import print_rank_0, TASK2ID, ID2TASK +from data.indexed_dataset import make_dataset as make_indexed_dataset +from data.blendable_dataset import BlendableDataset +from data.gpt2_dataset import GPT2Dataset, GPT2PromptDataset + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(" > building dataset index ...") + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + print_rank_0(" > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time)) + print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +def build_train_valid_test_datasets( + data_prefix, + use_shared_fs, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + build_index_mappings=True, + shuffle_before_split=False, + weighted_loss_mode=None, + ds_weights=[1.0, 1.0, 1.0], + train_mode="sft", +): + """Build train, valid, and test datasets.""" + + # Indexed dataset. + assert os.path.exists( + data_prefix + "_input_ids.bin" + ), f"Input tokens datafile not found: {data_prefix}_input_ids.bin" + + # Indexed dataset. 
+ input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup) + if train_mode == "sft": + loss_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_loss_mask", data_impl, skip_warmup) + else: + print(f"pretrain mode, loss mask is ones") + loss_mask_indexed_dataset = None + + total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + print_rank_0(" > dataset split:") + + def print_split_stats(name, index): + print_rank_0(" {}:".format(name)) + print_rank_0( + " document indices in [{}, {}) total of {} " + "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) + + # shuffle index before build_dataset + shuffle_doc_index = [] + if shuffle_before_split: + total_num_docs = splits[-1] - splits[0] + shuffle_doc_index = np.arange(start=0, stop=total_num_docs, step=1, dtype=np.uint32) + np_rng = np.random.RandomState(seed=seed) + np_rng.shuffle(shuffle_doc_index) + + def build_dataset(index, name, ds_weight=1.0): + dataset = None + if splits[index + 1] > splits[index]: + if shuffle_before_split: + documents = shuffle_doc_index[splits[index] : splits[index + 1]] + else: + documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) + + dataset = GPT2PromptDataset( + name, + data_prefix, + documents, + input_ids_indexed_dataset, + loss_mask_indexed_dataset, + train_valid_test_num_samples[index], + seq_length, + seed, + build_index_mappings=build_index_mappings, + use_shared_fs=use_shared_fs, + weighted_loss_mode=weighted_loss_mode, + ds_weight=ds_weight, + train_mode=train_mode, + ) + return dataset + + train_dataset = build_dataset(0, "train", ds_weights[0]) + valid_dataset = build_dataset(1, "valid", ds_weights[1]) + test_dataset = build_dataset(2, "test", ds_weights[2]) + + return train_dataset, valid_dataset, test_dataset, total_num_of_documents + + +def build_multiple_train_valid_test_datasets( + args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False +): + """Build multiple train, valid, and test datasets.""" + data_prefixes = list(args.data_paths[1:-1].split(",")) + + data_weights = list(map(float, args.data_weights[1:-1].split(","))) + print("data weights: ") + print(data_weights) + use_shared_fs = use_shared_fs + data_impl = data_impl + splits_string = args.data_split + seq_length = args.seq_length + # seq_length = args.block_size + seed = args.seed + skip_warmup = not mmap_warmup + weight_by_num_documents = args.weight_by_num_documents + shuffle_before_split = args.shuffle_before_split + weighted_loss_mode = args.weighted_loss_mode + + weights, weighted_train_valid_test_num_samples = get_datasets_normalized_weights_and_num_samples( + data_weights, train_valid_test_num_samples + ) + + train_weights, valid_weights, test_weights = weights, weights, weights + + # Build individual datasets. 
+ train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(data_prefixes)): + train_ds, valid_ds, test_ds, _ = build_train_valid_test_datasets( + data_prefixes[i], + use_shared_fs, + data_impl, + splits_string, + weighted_train_valid_test_num_samples[i], + seq_length, + seed, + skip_warmup, + build_index_mappings=not weight_by_num_documents, + shuffle_before_split=shuffle_before_split, + weighted_loss_mode=weighted_loss_mode, + train_mode=args.tokenize_mode, + ) + if train_ds is not None: + train_datasets.append(train_ds) + if valid_ds is not None: + valid_datasets.append(valid_ds) + if test_ds is not None: + test_datasets.append(test_ds) + + factor = 1 + if weight_by_num_documents: + # gets the number of documents in each data path + get_num_docs_list = lambda datasets: [dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets] + train_num_docs, valid_num_docs, test_num_docs = ( + get_num_docs_list(train_datasets), + get_num_docs_list(valid_datasets), + get_num_docs_list(test_datasets), + ) + + # builds weights according to the number of docs + fn = partial(weights_by_num_docs_sft) + train_weights, valid_weights, test_weights = ( + fn(train_num_docs), + fn(valid_num_docs), + fn(test_num_docs), + ) + assert sum(train_weights) != 0.0, "found train weights to be 0.0" + assert sum(valid_weights) != 0.0, "found valid weights to be 0.0" + + train_weights, train_num_samples = get_normalized_weights_and_num_samples( + train_weights, train_valid_test_num_samples[0] + ) + print_rank_0(f"> train sample weights: {train_weights}") + valid_weights, valid_num_samples = get_normalized_weights_and_num_samples( + valid_weights, train_valid_test_num_samples[1] + ) + print_rank_0(f"> valid sample weights: {valid_weights}") + if sum(test_weights) == 0.0: + test_weights = weights # use original weights + test_weights, test_num_samples = get_normalized_weights_and_num_samples( + test_weights, train_valid_test_num_samples[2] + ) + + # weighted loss + num_tokens = [] + ds_fn = partial(ds_weights_by_num_docs_sft) + train_ds_weights, valid_ds_weights, test_ds_weights = ( + ds_fn(train_num_docs), + ds_fn(valid_num_docs), + ds_fn(test_num_docs), + ) + + assert sum(train_ds_weights) != 0.0, "found train loss weights to be 0.0" + assert sum(valid_ds_weights) != 0.0, "found valid loss weights to be 0.0" + + if sum(test_ds_weights) == 0.0: + test_ds_weights = weights # use original weights + print_rank_0(f"> train loss weights: {train_ds_weights}") + print_rank_0(f"> valid loss weights: {valid_ds_weights}") + + train_datasets = [] + valid_datasets = [] + test_datasets = [] + total_sample_cnt = [] + for i in range(len(data_prefixes)): + train_ds, valid_ds, test_ds, total_num_of_documents = build_train_valid_test_datasets( + data_prefixes[i], + use_shared_fs, + data_impl, + splits_string, + [train_num_samples[i], valid_num_samples[i], test_num_samples[i]], + seq_length, + seed, + skip_warmup, + build_index_mappings=True, + shuffle_before_split=shuffle_before_split, + weighted_loss_mode=weighted_loss_mode, + ds_weights=[train_ds_weights[i], valid_ds_weights[i], test_ds_weights[i]], + train_mode=args.tokenize_mode, + ) + total_sample_cnt.append(total_num_of_documents) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # calcualte common factor based on token cnt and total sample cnt + if num_tokens: + factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length) + factor /= 
sum([1.0 / w for w in train_ds_weights]) / len(train_ds_weights) + + print_rank_0(f"> common denomination factor for CE loss: {factor}") + + # Blend. + blending_train_dataset = None + if train_datasets: + i = 0 + for ds in train_datasets: + ds.update_ds_weight(ds.ds_weight / factor) + print(f"loss weight of dataset {i} after update: {ds.ds_weight}") + i += 1 + blending_train_dataset = BlendableDataset(train_datasets, train_weights) + blending_valid_dataset = None + if valid_datasets: + for ds in valid_datasets: + ds.update_ds_weight(ds.ds_weight / factor) + blending_valid_dataset = BlendableDataset(valid_datasets, valid_weights) + blending_test_dataset = None + if test_datasets: + for ds in test_datasets: + ds.update_ds_weight(ds.ds_weight / factor) + blending_test_dataset = BlendableDataset(test_datasets, test_weights) + + return blending_train_dataset, blending_valid_dataset, blending_test_dataset + + +def get_train_valid_test_split_(splits_string, size): + """Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.0) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index + + +def get_normalized_weights_and_num_samples(weights: List[float], num_samples: int) -> Tuple[List[float], List[int]]: + # Normalize weights + weight_sum = sum(weights) + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + # Add 0.5% (the 1.005 factor) so in case the blending dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + weighted_num_samples = [] + for weight in weights: + weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005))) + return weights, weighted_num_samples + + +def get_datasets_normalized_weights_and_num_samples( + weights: List[float], num_samples: List[int] +) -> Tuple[List[float], List[List[int]]]: + # Normalize weights + weight_sum = sum(weights) + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + # Add 0.5% (the 1.005 factor) so in case the blending dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. 
+ weighted_num_samples = [] + for weight in weights: + weighted_num_samples.append([int(math.ceil(val * weight * 1.005)) for val in num_samples]) + return weights, weighted_num_samples + + +def ds_weights_by_num_docs_sft(l, alpha=0.3): + # ignore alpha + weights = [1 / i for i in l] + weights_sum = sum(weights) + weights = [weight / weights_sum for weight in weights] + return weights + + +def weights_by_num_docs_sft(l, alpha=0.3): + # ignore alpha + total_n_docs = sum(l) + unbiased_sample_probs = [i / total_n_docs for i in l] + + return unbiased_sample_probs + + +def weights_by_num_docs(l: list, alpha=0.3): + """ + Builds weights from a multinomial distribution over groups of data according to the number of + samples in each group. + + We sample from a group according to the probability p(L) ∝ |L| ** α, + where p(L) is the probability of sampling from a given group, + |L| is the number of examples in that datapoint, + and α is a coefficient that acts to upsample data from underrepresented groups + + Hence α (`alpha`) allows us to control how much to 'boost' the probability of training on low-resource groups. + + See https://arxiv.org/abs/1911.02116 for more details + """ + if len(l) == 1: + return [1.0] + + total_n_docs = sum(l) + unbiased_sample_probs = [i / total_n_docs for i in l] + + probs = [i**alpha for i in unbiased_sample_probs] + + # normalize + total = sum(probs) + probs = [i / total for i in probs] + + # weights should be the inverse of the number of samples + unbiased_sample_probs_inverse = [1 - p for p in unbiased_sample_probs] + weights = [p * p2 for p, p2 in zip(probs, unbiased_sample_probs_inverse)] + + # normalize + total = sum(weights) + weights = [i / total for i in weights] + + return weights + + +def load_dataset_from_bin(args): + """XXX""" + + print_rank_0("> building train, validation, and test datasets ...") + + # Number of train/valid/test samples. + train_iters = 2 + valid_iters = 2 + test_iters = 2 + train_val_test_num_samples = [ + train_iters * 10, + valid_iters * 10, + test_iters * 10, + ] + + # multiple data paths for SFT task + train_ds, valid_ds, test_ds = build_multiple_train_valid_test_datasets( + args=args, + train_valid_test_num_samples=train_val_test_num_samples, + ) + + return train_ds, valid_ds, test_ds diff --git a/mftcoder_accelerate/src/data/gpt2_dataset.py b/mftcoder_accelerate/src/data/gpt2_dataset.py new file mode 100644 index 0000000..12eeb87 --- /dev/null +++ b/mftcoder_accelerate/src/data/gpt2_dataset.py @@ -0,0 +1,456 @@ +# Copyright (c) 2021, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""GPT2 style dataset.""" + +import os +import time + +import numpy as np +import torch + +from utils.common_utils import print_rank_0, TASK2ID, ID2TASK + + +class GPT2PromptDataset(torch.utils.data.Dataset): + def __init__( + self, + name, + data_prefix, + documents, + input_ids_indexed_dataset, + loss_mask_indexed_dataset, + num_samples, + seq_length, + seed, + build_index_mappings=True, + use_shared_fs=True, + weighted_loss_mode=None, + ds_weight=1.0, + train_mode="sft", + ): + + self.name = name + self.input_ids_indexed_dataset = input_ids_indexed_dataset + self.loss_mask_indexed_dataset = loss_mask_indexed_dataset + + self.weighted_loss_mode = weighted_loss_mode + self.ds_weight = ds_weight + + self.task_name = data_prefix.split("/")[-1] + + self.task_id = TASK2ID[self.task_name] + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < self.input_ids_indexed_dataset.sizes.shape[0] + + # print("sequence length is {}".format(seq_length)) + if build_index_mappings: + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, + data_prefix, + documents, + self.input_ids_indexed_dataset.sizes, + num_samples, + seq_length, + seed, + use_shared_fs=use_shared_fs, + ) + self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 + self.sample_idx_len = self.sample_idx.shape[0] - 1 + + if self.shuffle_idx_len != self.sample_idx_len: + print( + f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" + ) + + def update_ds_weight(self, weight): + self.ds_weight = weight + + def __len__(self): + return min(self.shuffle_idx_len, self.sample_idx_len) + + def __getitem__(self, idx): + try: + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + if doc_index_f == doc_index_l: + input_ids = self.input_ids_indexed_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + + if self.loss_mask_indexed_dataset is not None: + loss_mask = self.loss_mask_indexed_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + else: + loss_mask = None + + else: + # Otherwise, get the rest of the initial document. + input_ids_list = [self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] + + if self.loss_mask_indexed_dataset is not None: + loss_mask_list = [self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] + else: + loss_mask_list = [] + + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + input_ids_list.append(self.input_ids_indexed_dataset.get(self.doc_idx[i])) + if self.loss_mask_indexed_dataset is not None: + loss_mask_list.append(self.loss_mask_indexed_dataset.get(self.doc_idx[i])) + + # And finally add the relevant portion of last document. 
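# (only the first offset_l + 1 tokens of the last document belong to this sample,
#  hence the explicit length argument)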
+ input_ids_list.append( + self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + + if self.loss_mask_indexed_dataset is not None: + loss_mask_list.append( + self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + + input_ids = np.concatenate(input_ids_list) + if self.loss_mask_indexed_dataset is not None: # sst do not need to read loss_mask.bin + loss_mask = np.concatenate(loss_mask_list) + + if self.weighted_loss_mode: + # res = { + # "input_ids": np.array(input_ids, dtype=np.float32), + # "loss_mask": np.array(loss_mask, dtype=np.float32), + # "weight": np.array([self.ds_weight], dtype=np.float32), + # } + if self.loss_mask_indexed_dataset is not None: + res = { + "input_ids": np.array(input_ids, dtype=np.float32), + "loss_mask": np.array(loss_mask, dtype=np.float32), + "weight": np.array([self.ds_weight], dtype=np.float32), + "task_id": np.array([self.task_id], dtype=np.int), + } + else: + res = { + "input_ids": np.array(input_ids, dtype=np.float32), + "loss_mask": np.ones_like(input_ids, dtype=np.int64), + "weight": np.array([self.ds_weight], dtype=np.float32), + "task_id": np.array([self.task_id], dtype=np.int), + } + else: + res = { + "input_ids": np.array(input_ids, dtype=np.int64), + "loss_mask": np.array(loss_mask, dtype=np.int64), + "task_id": np.array([self.task_id], dtype=np.int), + } + return res + except IndexError: + new_idx = idx % len(self) + print( + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) + return self[new_idx] + + +class GPT2Dataset(torch.utils.data.Dataset): + def __init__( + self, + name, + data_prefix, + documents, + indexed_dataset, + num_samples, + seq_length, + seed, + build_index_mappings=True, + use_shared_fs=True, + ): + + self.name = name + self.indexed_dataset = indexed_dataset + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + if build_index_mappings: + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + num_samples, + seq_length, + seed, + use_shared_fs=use_shared_fs, + ) + self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 + self.sample_idx_len = self.sample_idx.shape[0] - 1 + + if self.shuffle_idx_len != self.sample_idx_len: + print( + f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" + ) + + def __len__(self): + return min(self.shuffle_idx_len, self.sample_idx_len) + + def __getitem__(self, idx): + try: + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + if doc_index_f == doc_index_l: + sample = self.indexed_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + else: + # Otherwise, get the rest of the initial document. + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. 
+ sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)) + sample = np.concatenate(sample_list) + + return {"text": np.array(sample, dtype=np.int64)} + except IndexError: + new_idx = idx % len(self) + print( + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) + return self[new_idx] + + +def _build_index_mappings( + name, + data_prefix, + documents, + sizes, + num_samples, + seq_length, + seed, + use_shared_fs=True, +): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + _filename = data_prefix + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + doc_idx_filename = _filename + "_doc_idx.npy" + sample_idx_filename = _filename + "_sample_idx.npy" + shuffle_idx_filename = _filename + "_shuffle_idx.npy" + + if not use_shared_fs: + should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0 + else: + should_process_dataset = torch.distributed.get_rank() == 0 + + # Build the indexed mapping if not exist. + if should_process_dataset: + if ( + (not os.path.isfile(doc_idx_filename)) + or (not os.path.isfile(sample_idx_filename)) + or (not os.path.isfile(shuffle_idx_filename)) + ): + print_rank_0(" > WARNING: could not find index map files, building " "the indices on rank 0 ...") + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + from data import helpers + + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + + # 2.0这里做了更新:num_samples = num_epochs * token_per_epoch + # 而1.0版本是:train_batch_size * train_iter * weighted_rate + # 我理解这里的num_samples应该是和入参的num_samples重名,这里只是为了计算构建所有索引的长度,从而决定是用int64还是int32 + num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length + if 2 * (num_samples + 1) < np.iinfo(np.int32).max: + sample_idx = helpers.build_sample_idx_int32(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) + else: + sample_idx = helpers.build_sample_idx_int64(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) + # shuffle-idx. 
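# The shuffle index is a seeded random permutation over sample indices, so sample
# order is reproducible across runs while neighbouring documents stay decorrelated.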
+ start_time = time.time() + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) + + torch.distributed.barrier() # TODO: model parallel + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + # torch.distributed.all_reduce(counts) + # torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group()) + # assert counts[0].item() == torch.distributed.get_world_size( + # group=mpu.get_io_parallel_group() + # ) + + # Load mappings. + start_time = time.time() + print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading sample-idx mapping from {}".format(sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)) + print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) + print_rank_0(" total number of epochs: {}".format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence length, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng): + """Build an array with length = number-of-epochs * number-of-documents. + Each index is mapped to a corresponding document.""" + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int64) + + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Beginning offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. 
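# (each sample packs seq_length + 1 tokens so that inputs and right-shifted labels
#  can both be sliced from the same window)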
+ remaining_seq_length = seq_length + 1 + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += remaining_seq_length + doc_length - 1 + remaining_seq_length = 0 + else: + # Otherwise, start from the beginning of the next document. + doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(size, np_rng): + """Build the range [0, size) and shuffle.""" + dtype_ = np.uint32 + if size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx) + return shuffle_idx diff --git a/mft_atorch/data/helpers.cpp b/mftcoder_accelerate/src/data/helpers.cpp similarity index 100% rename from mft_atorch/data/helpers.cpp rename to mftcoder_accelerate/src/data/helpers.cpp diff --git a/mftcoder_accelerate/src/data/indexed_dataset.py b/mftcoder_accelerate/src/data/indexed_dataset.py new file mode 100644 index 0000000..12ea9c2 --- /dev/null +++ b/mftcoder_accelerate/src/data/indexed_dataset.py @@ -0,0 +1,576 @@ +# Copyright (c) 2021, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. 
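# A dataset is addressed by a path prefix: <prefix>.idx stores the index metadata
# (dtype code, sizes, offsets and document boundaries) and <prefix>.bin stores the
# raw token data; "lazy", "cached" and "mmap" are alternative readers over the
# same pair of files.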
+ +import os +import shutil +import struct +from functools import lru_cache +from itertools import accumulate + +import numpy as np +import torch + +from utils.common_utils import print_rank_0 + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return "cached" + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return "mmap" + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == "mmap": + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == "infer": + impl = infer_dataset_impl(path) + if impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == "cached" and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == "mmap": + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float32, + 7: np.float64, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." 
+ ) + version = f.read(8) + assert struct.unpack("= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float32: 4, + np.float64: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, np_array): + assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype + bytes = self.out_file.write(np_array) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in 
np_array.shape: + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(np_array.shape)) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack(" 0: + self.i = max(map(lambda x: int(x.split("_")[1].split(".")[0]), os.listdir(out_dir))) + 1 + + def add_data(self, data): + self.data.append(data) + + def commit(self, archive_name=None): + # TODO: streaming + cctx = zstandard.ZstdCompressor(level=3) + + if archive_name is None: + archive_name = str(int(time.time())) + + res = b"".join( + map(lambda x: ("%016d" % len(x)).encode("UTF-8") + x, map(lambda x: x.encode("UTF-8"), self.data)) + ) + cdata = cctx.compress(res) + + with open(self.out_dir + "/data_" + str(self.i) + "_" + archive_name + ".dat.zst", "wb") as fh: + fh.write(cdata) + + self.i += 1 + self.data = [] + + +class JSONArchive: + def __init__(self, out_dir): + self.out_dir = out_dir + os.makedirs(out_dir, exist_ok=True) + self.data = [] + self.i = 0 + if os.path.exists(out_dir) and len(os.listdir(out_dir)) > 0: + self.i = max(map(lambda x: int(x.split("_")[1].split(".")[0]), os.listdir(out_dir))) + 1 + + def add_data(self, data): + self.data.append(data) + + def commit(self): + cctx = zstandard.ZstdCompressor(level=3) + + cdata = cctx.compress(json.dumps(self.data).encode("UTF-8")) + with open(self.out_dir + "/data_" + str(self.i) + "_" + str(int(time.time())) + ".json.zst", "wb") as fh: + fh.write(cdata) + + self.i += 1 + self.data = [] diff --git a/mft_peft_hf/src/data/gpt2_multi_task_dataset.py b/mftcoder_accelerate/src/data/multi_task_dataset.py similarity index 63% rename from mft_peft_hf/src/data/gpt2_multi_task_dataset.py rename to mftcoder_accelerate/src/data/multi_task_dataset.py index 4bb3c66..63c4b27 100644 --- a/mft_peft_hf/src/data/gpt2_multi_task_dataset.py +++ b/mftcoder_accelerate/src/data/multi_task_dataset.py @@ -2,16 +2,19 @@ # @author Chaoyu Chen # @date 2023/8/18 +Load dataset in a distributed way. """ + import os import json import math import time +import glob import numpy as np import torch from functools import partial -from data.tokenization.preprocess_data import UniformEncoder -from utils.common_utils import TASK2ID, ID2TASK +from data.preprocess_data import UniformEncoder +from utils.common_utils import TASK2ID class GPT2FromRawDataset(torch.utils.data.Dataset): @@ -27,12 +30,12 @@ def __init__( self.name = name self.input_dataset = input_dataset - self.num_samples = len(self.input_dataset['input_ids']) + self.num_samples = len(self.input_dataset["input_ids"]) self.seq_length = seq_length self.weighted_loss_mode = weighted_loss_mode self.ds_weight = ds_weight - self.task_name = data_prefix.split('/')[-1] + self.task_name = data_prefix.split("/")[-1] self.task_id = TASK2ID[self.task_name] # Checks @@ -47,8 +50,7 @@ def __getitem__(self, idx): try: # Get the shuffled index. 
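# (wrap the index so that oversampled tasks revisit their samples instead of
#  going out of range)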
idx = idx % self.num_samples - idx_data = {key: self.input_dataset[key][idx] - for key in self.input_dataset} + idx_data = {key: self.input_dataset[key][idx] for key in self.input_dataset} if self.weighted_loss_mode: idx_data["weight"] = np.array([self.ds_weight], dtype=np.float32) @@ -115,9 +117,7 @@ def __init__(self, datasets, weights, global_num_samples, local_num_samples): print( "> RANK {} elapsed time for building blendable dataset indices: " - "{:.2f} (sec)".format( - torch.distributed.get_rank(), time.time() - start_time - ) + "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time) ) def calc_weights(self): @@ -146,32 +146,28 @@ def __getitem__(self, idx): def shuffle_arrays(arrays, set_seed=-1): - """Shuffles arrays in-place, in the same order, along axis=0 + """Shuffles arrays in-place, in the same order, along axis=0 - Parameters: - ----------- - arrays : List of NumPy arrays. - set_seed : Seed value if int >= 0, else seed is random. - """ - assert all(len(arr) == len(arrays[0]) for arr in arrays) - seed = np.random.randint(0, 2**(32 - 1) - 1) if set_seed < 0 else set_seed + Parameters: + ----------- + arrays : List of NumPy arrays. + set_seed : Seed value if int >= 0, else seed is random. + """ + assert all(len(arr) == len(arrays[0]) for arr in arrays) + seed = np.random.randint(0, 2 ** (32 - 1) - 1) if set_seed < 0 else set_seed - for arr in arrays: - rstate = np.random.RandomState(seed) - rstate.shuffle(arr) + for arr in arrays: + rstate = np.random.RandomState(seed) + rstate.shuffle(arr) def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, local_rank=0): - # tokenization编码器 encoder = UniformEncoder(args, args.tokenize_mode) encoder.initializer() - data_prefixes = list(args.data_paths[1:-1].split(',')) - - # data_weights = list(map(float, args.data_weights[1:-1].split(','))) - # print("data weights: ") - # print(data_weights) + data_prefixes = list(args.data_paths[1:-1].split(",")) + splits = [] splits_string = args.data_split if splits_string.find(",") != -1: @@ -183,7 +179,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, while len(splits) < 3: splits.append(0.0) splits = splits[:3] - print(f'data splits: {splits}') + print(f"data splits: {splits}") all_train_datasets = [] all_valid_datasets = [] @@ -199,46 +195,50 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, # 不同数据集在不同文件夹下 for dataset_index in range(len(data_prefixes)): - files = os.listdir(data_prefixes[dataset_index]) + # files = os.listdir(data_prefixes[dataset_index]) + # get all jsonl files and corresponding reading handler + if data_prefixes[dataset_index].endswith(".jsonl"): + files = [data_prefixes[dataset_index]] + else: + files = glob.glob(os.path.join(data_prefixes[dataset_index], "**/*.jsonl"), recursive=True) + cur_dataset_input_ids = [] cur_dataset_loss_mask = [] # support multiple jsonl files under task dir - for file in files: - file_name = data_prefixes[dataset_index] + '/' + file - if os.path.isdir(file_name): - continue - fin = open(file_name, 'r') - print(f'[Global Rank {global_rank}] open file {file_name}') - - if args.padding_mode == 'padding' or args.padding_mode == 'pack': + for file_name in files: + fin = open(file_name, "r") + print(f"[Global Rank {global_rank}] open file {file_name}") + + if args.padding_mode == "padding" or args.padding_mode == "pack" or args.padding_mode == "concat": for i, line in enumerate(fin): # pre-sharding if shard_data and i % world_size != 
global_rank: continue - data = json.loads(line.rstrip('\n\r')) - features, length = encoder.encode(data) + data = json.loads(line.rstrip("\n\r")) + features, length = encoder.encode(data, verbose=(i < 1)) + # features, length = encoder.encode(data) # may have more samples - for idx in range(len(features['input_ids'])): - cur_dataset_input_ids.append(features['input_ids'][idx]) - cur_dataset_loss_mask.append(features['loss_mask'][idx]) - + for idx in range(len(features["input_ids"])): + cur_dataset_input_ids.append(features["input_ids"][idx]) + cur_dataset_loss_mask.append(features["loss_mask"][idx]) + fin.close() else: i = 0 for line in fin: - data = json.loads(line.rstrip('\n\r')) + data = json.loads(line.rstrip("\n\r")) features, length = encoder.encode(data) # 一个document可能编码不出sample,可能编码出多个sample - for idx in range(len(features['input_ids'])): + for idx in range(len(features["input_ids"])): # post-sharding if shard_data and i % world_size != global_rank: i += 1 continue i += 1 - cur_dataset_input_ids.append(features['input_ids'][idx]) - cur_dataset_loss_mask.append(features['loss_mask'][idx]) + cur_dataset_input_ids.append(features["input_ids"][idx]) + cur_dataset_loss_mask.append(features["loss_mask"][idx]) fin.close() - + cur_dataset_input_ids = np.array(cur_dataset_input_ids, dtype=np.float32) cur_dataset_loss_mask = np.array(cur_dataset_loss_mask, dtype=np.float32) cur_dataset_num_tokens = np.sum(cur_dataset_loss_mask, dtype=np.int32) @@ -246,60 +246,57 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, num_tokens.append(cur_dataset_num_tokens) total_sample_cnt.append(cur_dataset_sample_num) effective_token_rate.append(cur_dataset_num_tokens / (cur_dataset_sample_num * args.seq_length)) - + # shuffle before split shuffle_arrays([cur_dataset_input_ids, cur_dataset_loss_mask], args.seed) train_ratio = splits[0] / 100.0 train_num = int(math.ceil(train_ratio * cur_dataset_sample_num)) # split train/valid - cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[: train_num], cur_dataset_input_ids[train_num: ] - cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[: train_num], cur_dataset_loss_mask[train_num: ] + cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[:train_num], cur_dataset_input_ids[train_num:] + cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[:train_num], cur_dataset_loss_mask[train_num:] local_train_num += train_num - local_valid_num += (cur_dataset_sample_num - train_num) - - cur_train_dataset = {'input_ids': cur_train_input_ids, - 'loss_mask': cur_train_loss_mask - } - cur_valid_dataset = {'input_ids': cur_valid_input_ids, - 'loss_mask': cur_valid_loss_mask - } + local_valid_num += cur_dataset_sample_num - train_num + + cur_train_dataset = {"input_ids": cur_train_input_ids, "loss_mask": cur_train_loss_mask} + cur_valid_dataset = {"input_ids": cur_valid_input_ids, "loss_mask": cur_valid_loss_mask} print(f"[Global Rank {global_rank}]shape of cur train dataset: {cur_train_dataset['input_ids'].shape}") - print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}") + if local_valid_num > 0: + print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}") cur_train_ds = GPT2FromRawDataset( - 'train', + "train", data_prefixes[dataset_index], cur_train_dataset, args.seq_length, weighted_loss_mode=args.weighted_loss_mode, - ds_weight=splits[0] - ) - cur_valid_ds = GPT2FromRawDataset( - 'valid', - 
data_prefixes[dataset_index], - cur_valid_dataset, - args.seq_length, - weighted_loss_mode=args.weighted_loss_mode, - ds_weight=splits[1] + ds_weight=splits[0], ) - all_train_datasets.append(cur_train_ds) - all_valid_datasets.append(cur_valid_ds) all_train_datasets_length.append(len(cur_train_ds)) - all_valid_datasets_length.append(len(cur_valid_ds)) - - print(f'[Global Rank {global_rank}]num tokens: {num_tokens}') - print(f'[Global Rank {global_rank}]effective token rate: {effective_token_rate}') + if local_valid_num > 0: + cur_valid_ds = GPT2FromRawDataset( + "valid", + data_prefixes[dataset_index], + cur_valid_dataset, + args.seq_length, + weighted_loss_mode=args.weighted_loss_mode, + ds_weight=splits[1], + ) + all_valid_datasets.append(cur_valid_ds) + all_valid_datasets_length.append(len(cur_valid_ds)) + else: + cur_valid_ds = None + + print(f"[Global Rank {global_rank}]num tokens: {num_tokens}") + print(f"[Global Rank {global_rank}]effective token rate: {effective_token_rate}") num_tokens = [] ds_fn = partial(ds_weights_by_num_docs_sft) - train_loss_weights, valid_loss_weights = ( - ds_fn(all_train_datasets_length), - ds_fn(all_valid_datasets_length), - ) - + train_loss_weights = ds_fn(all_train_datasets_length) print(f"> train loss weights in rank {global_rank}: {train_loss_weights}") - print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}") + if all_valid_datasets_length: + valid_loss_weights = ds_fn(all_valid_datasets_length) + print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}") factor = 1 # calcualte common factor based on token cnt and total sample cnt @@ -307,51 +304,65 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length) factor /= sum([1.0 / w for w in train_loss_weights]) / len(train_loss_weights) print(f"> common denomination factor for CE loss in rank {global_rank}: {factor}") - + train_sample_weights = [x / sum(all_train_datasets_length) for x in all_train_datasets_length] - valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length] print(f"> train sample weights in rank {global_rank}: {train_sample_weights}") - print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}") + if all_valid_datasets_length: + valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length] + print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}") # recompute global_train_num and global_valid_num - + torch.distributed.barrier() device = f"cuda:{local_rank}" - + global_train_num_samples_tensor = torch.tensor(local_train_num, dtype=torch.int32) global_train_num_samples_tensor = global_train_num_samples_tensor.to(device) torch.distributed.all_reduce(global_train_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) global_train_num = global_train_num_samples_tensor.item() - - global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32) - global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device) - torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) - global_valid_num = global_valid_num_samples_tensor.item() print(f"> global train num in rank {global_rank}: {global_train_num}") - print(f"> global valid num in rank {global_rank}: {global_valid_num}") - + + if local_valid_num > 0: + global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32) + 
global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device) + torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) + global_valid_num = global_valid_num_samples_tensor.item() + print(f"> global valid num in rank {global_rank}: {global_valid_num}") + torch.distributed.barrier() - for i in range(len(all_train_datasets)): - print(f'loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}') blending_train_dataset = None if all_train_datasets: + for i in range(len(all_train_datasets)): + print( + f"loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}" + ) args.do_train = True for i in range(len(all_train_datasets)): all_train_datasets[i].update_ds_weight(train_loss_weights[i] / factor) - print(f'loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}') - blending_train_dataset = GPT2BlendableDataset(all_train_datasets, train_sample_weights, global_train_num, local_train_num) - - for i in range(len(all_train_datasets)): - print(f'loss weight of valid dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}') + print( + f"loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}" + ) + blending_train_dataset = GPT2BlendableDataset( + all_train_datasets, train_sample_weights, global_train_num, local_train_num + ) + blending_valid_dataset = None if all_valid_datasets: + for i in range(len(all_valid_datasets)): + print( + f"loss weight of valid dataset {i} before update in rank {global_rank}: {all_valid_datasets[i].ds_weight}" + ) args.do_valid = True for i in range(len(all_valid_datasets)): all_valid_datasets[i].update_ds_weight(valid_loss_weights[i] / factor) - print(f'loss weight of valid dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}') - blending_valid_dataset = GPT2BlendableDataset(all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num) - + print( + f"loss weight of valid dataset {i} after update in rank {global_rank}: {all_valid_datasets[i].ds_weight}" + ) + blending_valid_dataset = GPT2BlendableDataset( + all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num + ) + return blending_train_dataset, blending_valid_dataset @@ -360,9 +371,13 @@ def compile_helper(): is invoked on a single process.""" import os import subprocess + path = os.path.abspath(os.path.dirname(__file__)) ret = subprocess.run(["make", "-C", path]) if ret.returncode != 0: print("Making C++ dataset helpers module failed, exiting.") import sys + sys.exit(1) + else: + print("Making C++ dataset helpers module successfully.") diff --git a/mftcoder_accelerate/src/data/preprocess_data.py b/mftcoder_accelerate/src/data/preprocess_data.py new file mode 100644 index 0000000..f7226bd --- /dev/null +++ b/mftcoder_accelerate/src/data/preprocess_data.py @@ -0,0 +1,356 @@ +""" +# @author Chaoyu Chen +# @date 2023/9/13 +Preprocessing data and tokenization. 
+""" + +import os +import sys +import ftfy +import glob + +# print("In preprocess_data_new.py, sys path:", sys.path) + +from tokenizer import build_tokenizer + +CHAT_COL = "chat_rounds" +ROLE_COL = "role" +CONTENT_COL = "content" + +SYSTEM_COL = "system" +PROMPT_COL = "prompt" +ANSWER_COL = "answer" + +TEXT_COL = "text" + +table = {ord(f): ord(t) for f, t in zip(",。!?:【】()%#@&1234567890", ",.!?:[]()%#@&1234567890")} + + +def content_format(content: str): + # Replace non-breaking space with space + content = content.replace("\u202f", " ").replace("\xa0", " ") + + # change chinese punctuation to english ones + # text = text.translate(table) + # if not content.endswith("\n"): + content += "\n" + + return content + + +def is_text_format(data): + if "text" in data: + return True + else: + return False + + +def is_chatml_format(data): + if "chat_rounds" in data and len(data["chat_rounds"]) > 0: + return True + else: + return False + + +def is_prompt_answer_format(data): + if "prompt" in data and "answer" in data: + return True + else: + return False + + +def is_prompt_response_format(data): + if "prompt" in data and "response" in data: + return True + else: + return False + + +def is_input_output_format(data): + if "input" in data and "output" in data: + return True + else: + return False + + +def is_instruction_output_format(data): + if "instruction" in data and "output" in data: + return True + else: + return False + + +def is_instruction_response_format(data): + if "instruction" in data and "response" in data: + return True + else: + return False + + +def is_question_response_format(data): + if "question" in data and "response" in data: + return True + else: + return False + + +def is_question_answer_format(data): + if "question" in data and "answer" in data: + return True + else: + return False + + +def is_query_answer_format(data): + if "query" in data and "answer" in data: + return True + else: + return False + + +class Encoder(object): + tokenizer = None + + def __init__(self, args): + self.args = args + + def initializer(self): + # Use Encoder class as a container for global data + Encoder.tokenizer = build_tokenizer(self.args) + # self.tokenizer = build_tokenizer(self.args) + + def pure_encode(self, content): + return Encoder.tokenizer.encode(content, add_special_tokens=False) + + def encode(self, text): + if self.args.ftfy: + text = ftfy.fix_text(text) + ids = {} + for key in self.args.jsonl_keys: + doc_ids = [] + text_ids = self.pure_encode(text) + if len(text_ids) > 0: + doc_ids.append(text_ids) + if self.args.append_eod: + doc_ids[-1].append(Encoder.tokenizer.eos_token_id) + ids[key] = doc_ids + return ids, len(text) + + +class UniformEncoder(Encoder): + def __init__(self, args, mode="sft"): + super().__init__(args) + self.verbose = False + self.mode = mode + # seq_length + 1 for shifting + if args.load_raw_dataset: + self.seq_length = args.seq_length + 1 + self.stride = args.seq_length + else: + self.seq_length = args.seq_length + + self.remain_input_ids = [] + self.remain_loss_mask = [] + + def encode(self, data, verbose=False): + self.verbose = verbose + encode_res = {"input_ids": [], "loss_mask": []} + + if is_prompt_answer_format(data): + data_type = "prompt_answer" + elif is_prompt_response_format(data): + data_type = "prompt_response" + elif is_input_output_format(data): + data_type = "input_output" + elif is_instruction_output_format(data): + data_type = "instruction_output" + elif is_instruction_response_format(data): + data_type = "instruction_response" + elif 
is_question_response_format(data): + data_type = "question_response" + elif is_question_answer_format(data): + data_type = "question_answer" + elif is_query_answer_format(data): + data_type = "query_answer" + elif is_chatml_format(data): + data_type = "chatML" + elif is_text_format(data): + data_type = "text" + else: + raise ValueError( + f"data_type does not support" + f"please use chatML or prompt/answer, prompt/response, question/response, " + f"instruction/output, input/output, instruction/output or text(only for pretrain)" + ) + + length = 0 + if data_type == "chatML": + for chat in data["chat_rounds"]: + length += len(chat["content"]) + elif data_type == "text": + length += len(data["text"]) + else: + # update key + global PROMPT_COL, ANSWER_COL + PROMPT_COL, ANSWER_COL = tuple(data_type.split("_")) + length = len(data[PROMPT_COL]) + len(data[ANSWER_COL]) + + for token_res in self._tokenize_fields(data, data_type=data_type): + for k, v in token_res.items(): + encode_res[k].append(v) + + return encode_res, length + + def _tokenize_fields(self, data, data_type): + if self.mode == "sft": + if self.args.role_markers: + system_marker = self.args.role_markers["system"] + user_marker = self.args.role_markers["user"] + assistant_marker = self.args.role_markers["assistant"] + else: + system_marker = "system\n" + user_marker = "human\n" + assistant_marker = "bot\n" + elif self.mode == "pretrain": + system_marker = "" + user_marker = "" + assistant_marker = "" + else: + raise ValueError(f"tokenize_mode does not support {self.mode}, please use sft or pretrain") + + sft_end_marker_ids = [Encoder.tokenizer.eos_token_id] + # uniform SST,SFT,MFT + input_ids = [] + loss_mask = [] + + if data_type == "chatML": + chat = data[CHAT_COL] + if chat[0][ROLE_COL] == "system": + sys_content_ids = self.pure_encode(system_marker + content_format(chat[0][CONTENT_COL])) + chat = chat[1:] + input_ids += sys_content_ids + loss_mask += [0] * len(sys_content_ids) + + for i, r in enumerate(chat): + role = r[ROLE_COL] + content = r[CONTENT_COL] + content = content_format(content) + if (role == "human" or role == "user") != (i % 2 == 0): + raise ValueError( + "Conversation roles must alternate user/assistant/user/assistant/... 
or human/bot/human/bot/...')" + ) + + # compute loss only for assistant's content and eos token afterward + if role == "human" or role == "user": + content_ids = self.pure_encode(user_marker + content + assistant_marker) + input_ids += content_ids + loss_mask += [0] * len(content_ids) + elif role == "bot" or role == "assistant" or role == "gpt": + content_ids = self.pure_encode(content) + sft_end_marker_ids + input_ids += content_ids + loss_mask += [1] * len(content_ids) + extra_ids = self.pure_encode("\n") + input_ids += extra_ids + loss_mask += [0] * len(extra_ids) + else: + raise ValueError(f"Role {role} not supported.") + + elif data_type == "text": + text = data[TEXT_COL] + text = content_format(text) + text_ids = self.pure_encode(text) + sft_end_marker_ids + input_ids += text_ids + loss_mask += [1] * len(text_ids) + else: + system = data.get(SYSTEM_COL, "") + prompt = data[PROMPT_COL] + answer = data[ANSWER_COL] + + system = content_format(system_marker + system) if system else "" + prompt = content_format(prompt) + answer = content_format(answer) + + prompt_ids = self.pure_encode(system + user_marker + prompt + assistant_marker) + answer_ids = self.pure_encode(answer) + sft_end_marker_ids + + input_ids += prompt_ids + answer_ids + loss_mask += [0] * len(prompt_ids) + [1] * len(answer_ids) + + # print(self.mode) + if self.mode == "pretrain": + # change loss mask to all 1s + input_ids = input_ids + loss_mask = [1] * len(loss_mask) + elif self.mode == "sft": + # do nothing + input_ids = input_ids + loss_mask = loss_mask + + if self.verbose: + print(f"original data:\n{data}") + print(f"decoding back:\n{Encoder.tokenizer.decode(input_ids)}") + + assert len(input_ids) == len(loss_mask) + if self.args.padding_mode == "padding": + if len(input_ids) <= self.seq_length: + yield self.padding(input_ids, loss_mask) + + # drop if too long + else: + yield {} + elif self.args.padding_mode == "concat": + input_ids = self.remain_input_ids + input_ids + loss_mask = self.remain_loss_mask + loss_mask + if len(input_ids) < self.seq_length: + self.remain_input_ids = input_ids + self.remain_loss_mask = loss_mask + assert len(self.remain_input_ids) == len(self.remain_loss_mask) + yield {} + else: + cursor = 0 + while cursor + self.seq_length <= len(input_ids): + yield { + "input_ids": input_ids[cursor : cursor + self.seq_length], + "loss_mask": loss_mask[cursor : cursor + self.seq_length], + } + cursor = cursor + self.stride + self.remain_input_ids = input_ids[cursor:] + self.remain_loss_mask = loss_mask[cursor:] + assert len(self.remain_input_ids) == len(self.remain_loss_mask) + yield {} + elif self.args.padding_mode == "pack": + if len(input_ids) > self.seq_length: + yield {} + elif len(self.remain_input_ids) + len(input_ids) > self.seq_length: + input_ids, self.remain_input_ids = self.remain_input_ids, input_ids + loss_mask, self.remain_loss_mask = self.remain_loss_mask, loss_mask + assert len(input_ids) == len(loss_mask) + yield self.padding(input_ids, loss_mask) + else: + self.remain_input_ids = self.remain_input_ids + input_ids + self.remain_loss_mask = self.remain_loss_mask + loss_mask + assert len(self.remain_input_ids) == len(self.remain_loss_mask) + yield {} + + def padding(self, input_ids, loss_mask): + pad_id = Encoder.tokenizer.pad_token_id + assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)} > {self.seq_length}" + input_ids += [pad_id] * (self.seq_length - len(input_ids)) + loss_mask += [0] * (self.seq_length - len(loss_mask)) + return {"input_ids": input_ids, 
"loss_mask": loss_mask} + + +def find_jsonl_fnames(paths): + fnames = [] + for p in paths: + if not os.path.isdir(p): + if p.endswith(".jsonl"): + print(f"loading from {p}") + fnames.append(p) + else: + p_list = glob.glob(p + "/*") + for p_ in p_list: + if p_.endswith(".jsonl"): + print(f"loading from {p_}") + fnames.append(p_) + return fnames diff --git a/mftcoder_accelerate/src/ds_multinode_launch.sh b/mftcoder_accelerate/src/ds_multinode_launch.sh new file mode 100755 index 0000000..dca0670 --- /dev/null +++ b/mftcoder_accelerate/src/ds_multinode_launch.sh @@ -0,0 +1,44 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2024/5/20 +# Description: # Launch script on Multiple Nodes + +# Run this script on all Nodes. + +# You need to export your number of nodes and number of GPUs per node first. +N_NODE=4 +N_GPU_PER_NODE=8 + +# You need to export $MACHINE_RANK, $MASTER_ADDR, $MASTER_PORT automatically for each Node. + +# config path +CONFIG="configs/xxx_train_config.json" + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_ds_config.yaml \ +accelerate launch \ + --num_machines $N_NODE \ + --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \ + --use_deepspeed \ + --deepspeed_multinode_launcher 'standard' \ + --zero_stage 2 \ + --offload_optimizer_device 'cpu' \ + --offload_param_device 'none' \ + --gradient_accumulation_steps 1 \ + --gradient_clipping 1.0 \ + --zero3_init_flag false \ + --zero3_save_16bit_model false \ + --main_training_function 'main' \ + --mixed_precision 'bf16' \ + --dynamo_backend 'no' \ + --same_network \ + --machine_rank $MACHINE_RANK \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + --rdzv_backend 'static' \ + pefts/mft_accelerate.py --train_config "$CONFIG" --distributed_type "deepspeed" \ No newline at end of file diff --git a/mftcoder_accelerate/src/ds_single_launch.sh b/mftcoder_accelerate/src/ds_single_launch.sh new file mode 100755 index 0000000..d6c84bb --- /dev/null +++ b/mftcoder_accelerate/src/ds_single_launch.sh @@ -0,0 +1,38 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2023/12/11 +# Description: An alternative(Command line) way to launch DeepSpeed training + +# Launch script on single node +N_GPU_PER_NODE=8 + +# config path +CONFIG="configs/xxx_train_config.json" + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_ds_config.yaml \ +accelerate launch \ + --num_machines 1 \ + --num_processes $N_GPU_PER_NODE \ + --use_deepspeed \ + --zero_stage 2 \ + --offload_optimizer_device 'cpu' \ + --offload_param_device 'none' \ + --gradient_accumulation_steps 1 \ + --gradient_clipping 1.0 \ + --zero3_init_flag false \ + --zero3_save_16bit_model false \ + --main_training_function 'main' \ + --mixed_precision 'bf16' \ + --dynamo_backend 'no' \ + --same_network \ + --machine_rank 0 \ + --rdzv_backend 'static' \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ + --distributed_type "deepspeed" \ + > MFTCoder-training-"$TODAY".log 2>&1 & diff --git a/mftcoder_accelerate/src/ds_zero3_single_launch.sh b/mftcoder_accelerate/src/ds_zero3_single_launch.sh new file mode 100755 index 0000000..5f581c9 --- /dev/null +++ b/mftcoder_accelerate/src/ds_zero3_single_launch.sh @@ -0,0 +1,38 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2024/5/20 +# Description: An alternative(Command line) way to launch 
DeepSpeed training + +# Launch script on single node +N_GPU_PER_NODE=8 + +# config path +CONFIG="configs/xxx_train_config.json" + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_ds_config.yaml \ +accelerate launch \ + --num_machines 1 \ + --num_processes $N_GPU_PER_NODE \ + --use_deepspeed \ + --zero_stage 3 \ + --offload_optimizer_device 'cpu' \ + --offload_param_device 'cpu' \ + --gradient_accumulation_steps 1 \ + --gradient_clipping 1.0 \ + --zero3_init_flag true \ + --zero3_save_16bit_model true \ + --main_training_function 'main' \ + --mixed_precision 'bf16' \ + --dynamo_backend 'no' \ + --same_network \ + --machine_rank 0 \ + --rdzv_backend 'static' \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ + --distributed_type "deepspeed" \ + > MFTCoder-training-"$TODAY".log 2>&1 & diff --git a/mftcoder_accelerate/src/fsdp_single_launch.sh b/mftcoder_accelerate/src/fsdp_single_launch.sh new file mode 100755 index 0000000..2dc8f89 --- /dev/null +++ b/mftcoder_accelerate/src/fsdp_single_launch.sh @@ -0,0 +1,43 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2023/12/11 +# Description: An alternative(command line) way to launch FSDP training + +# Launch script on single node +N_GPU_PER_NODE=8 + +# config path +CONFIG="configs/xxx_train_config.json" + +# fsdp_transformer_layer_cls_to_wrap, choose the DecoderLayer +WRAP_MODULE="LlamaDecoderLayer" + + + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_fsdp_config.yaml \ +accelerate launch \ + --use_fsdp \ + --num_machines=1 \ + --num_processes=$N_GPU_PER_NODE \ + --fsdp_sharding_strategy=1 \ + --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \ + --fsdp_state_dict_type=FULL_STATE_DICT \ + --fsdp_backward_prefetch_policy=BACKWARD_PRE \ + --fsdp_transformer_layer_cls_to_wrap=$WRAP_MODULE \ + --fsdp_offload_params=false \ + --main_training_function=main \ + --mixed_precision=bf16 \ + --dynamo_backend=no \ + --same_network \ + --machine_rank=0 \ + --rdzv_backend=static \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ + --distributed_type "fsdp" \ + > MFTCoder-training-"$TODAY".log 2>&1 & + diff --git a/mft_peft_hf/src/model/__init__.py b/mftcoder_accelerate/src/model/__init__.py similarity index 100% rename from mft_peft_hf/src/model/__init__.py rename to mftcoder_accelerate/src/model/__init__.py diff --git a/mft_peft_hf/src/model/llama2/configuration_llama.py b/mftcoder_accelerate/src/model/aquila2/configuration_aquila.py similarity index 53% rename from mft_peft_hf/src/model/llama2/configuration_llama.py rename to mftcoder_accelerate/src/model/aquila2/configuration_aquila.py index 83132bd..62097dd 100644 --- a/mft_peft_hf/src/model/llama2/configuration_llama.py +++ b/mftcoder_accelerate/src/model/aquila2/configuration_aquila.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its @@ -17,22 +17,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" LLaMA model configuration""" +""" Aquila model configuration""" -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging +from transformers import PretrainedConfig -logger = logging.get_logger(__name__) -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): +class AquilaConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + This is the configuration class to store the configuration of a [`AquilaModel`]. It is used to instantiate an Aquila model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. + defaults will yield a similar configuration to that of the Aquila-7B. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -40,8 +35,8 @@ class LlamaConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] + Vocabulary size of the Aquila model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`AquilaModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 11008): @@ -50,19 +45,6 @@ class LlamaConfig(PretrainedConfig): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): @@ -77,35 +59,26 @@ class LlamaConfig(PretrainedConfig): relevant if `config.is_decoder=True`. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear and dynamic. 
Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - Example: ```python - >>> from transformers import LlamaModel, LlamaConfig + >>> from transformers import AquilaModel, AquilaConfig - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() + >>> # Initializing a Aquila aquila-7b style configuration + >>> configuration = AquilaConfig() - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) + >>> # Initializing a model from the aquila-7b style configuration + >>> model = AquilaModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" - model_type = "llama" + model_type = "aquila" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, - vocab_size=32000, + vocab_size=100008, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, @@ -121,6 +94,7 @@ def __init__( eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, + rope_theta=10000.0, rope_scaling=None, use_xformers=True, **kwargs, @@ -130,21 +104,22 @@ def __init__( self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_xformers = use_xformers # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads + + self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.pretraining_tp = pretraining_tp self.use_cache = use_cache + self.rope_theta = rope_theta self.rope_scaling = rope_scaling - self._rope_scaling_validation() + self.use_xformers = use_xformers super().__init__( pad_token_id=pad_token_id, @@ -154,23 +129,3 @@ def __init__( **kwargs, ) - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/mft_peft_hf/src/model/llama2/modeling_llama.py b/mftcoder_accelerate/src/model/aquila2/modeling_aquila.py similarity index 78% rename from mft_peft_hf/src/model/llama2/modeling_llama.py rename to mftcoder_accelerate/src/model/aquila2/modeling_aquila.py index 6aa560d..d0d9309 100644 --- a/mft_peft_hf/src/model/llama2/modeling_llama.py +++ b/mftcoder_accelerate/src/model/aquila2/modeling_aquila.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its @@ -17,12 +17,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch LLaMA model.""" +""" PyTorch Aquila model.""" import math from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -31,13 +30,23 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig +from .configuration_aquila import AquilaConfig +from transformers import ( + LogitsProcessorList, + MinLengthLogitsProcessor, + TopKLogitsWarper, + TemperatureLogitsWarper, + TopPLogitsWarper, + StoppingCriteriaList, + MaxLengthCriteria, + BitsAndBytesConfig, +) import xformers.ops logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "LlamaConfig" +_CONFIG_FOR_DOC = "AquilaConfig" # Copied from transformers.models.bart.modeling_bart._make_causal_mask @@ -48,7 +57,7 @@ def _make_causal_mask( Make causal mask used for bi-directional self-attention. 
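The renamed `AquilaConfig` keeps the LLaMA-style arguments, adds `rope_theta`, and drops the `_rope_scaling_validation` helper, so the `rope_scaling` dict format that the validator used to enforce (`{"type": "linear" | "dynamic", "factor": > 1.0}`) is now only documented in the hunks above. A hedged instantiation sketch follows, assuming `mftcoder_accelerate/src` is the working directory or on `sys.path`; all values are defaults or examples, not recommendations.

```python
# Illustrative only: constructing the renamed AquilaConfig with the new
# rope_theta field and a rope_scaling dict in the previously validated format.
from model.aquila2.configuration_aquila import AquilaConfig

config = AquilaConfig(
    vocab_size=100008,          # new Aquila default, see __init__ in this diff
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    rope_theta=10000.0,         # argument added in this port
    rope_scaling={"type": "linear", "factor": 2.0},  # "linear" or "dynamic", factor > 1.0
)
print(config.model_type)        # "aquila"
```

Note that with `_rope_scaling_validation` removed, a malformed `rope_scaling` dict is no longer rejected at construction time; an unknown scaling type only surfaces later, when `AquilaAttention._init_rope` raises.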
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -73,10 +82,11 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) -class LlamaRMSNorm(nn.Module): +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Aquila +class AquilaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - LlamaRMSNorm is equivalent to T5LayerNorm + AquilaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -84,13 +94,14 @@ def __init__(self, hidden_size, eps=1e-6): def forward(self, hidden_states): input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + + return (self.weight * hidden_states).to(input_dtype) -class LlamaRotaryEmbedding(torch.nn.Module): +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Aquila +class AquilaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -98,7 +109,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) + self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( @@ -125,9 +136,9 @@ def forward(self, x, seq_len=None): self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Aquila +class AquilaLinearScalingRotaryEmbedding(AquilaRotaryEmbedding): + """AquilaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor @@ -144,9 +155,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Aquila +class AquilaDynamicNTKScalingRotaryEmbedding(AquilaRotaryEmbedding): + """AquilaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor @@ -160,7 +171,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) + self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) @@ -189,10 +200,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed -class LlamaMLP(nn.Module): +# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Aquila +class AquilaMLP(nn.Module): def __init__(self, config): super().__init__() - self.pretraining_tp = config.pretraining_tp + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -201,17 +213,21 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) up_proj_slices = self.up_proj.weight.split(slice, dim=0) down_proj_slices = self.down_proj.weight.split(slice, dim=1) - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] down_proj = sum(down_proj) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) @@ -231,10 +247,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class LlamaAttention(nn.Module): +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Aquila +class AquilaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): + def __init__(self, config: AquilaConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -242,8 +258,8 @@ def __init__(self, config: LlamaConfig): self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta self.use_xformers = config.use_xformers if (self.head_dim * self.num_heads) != 
self.hidden_size: @@ -259,17 +275,27 @@ def __init__(self, config: LlamaConfig): def _init_rope(self): if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + self.rotary_emb = AquilaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + self.rotary_emb = AquilaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, ) elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + self.rotary_emb = AquilaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") @@ -288,19 +314,21 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: @@ -334,12 +362,13 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask(), - op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) + attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) attn_output = attn_output.contiguous().view(bsz, q_len, -1) else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - + 
attn_weights = torch.clamp(attn_weights, min=-1024., max=1024.) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" @@ -366,10 +395,10 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) @@ -379,14 +408,15 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Aquila +class AquilaDecoderLayer(nn.Module): + def __init__(self, config: AquilaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = AquilaAttention(config=config) + self.mlp = AquilaMLP(config) + self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -442,8 +472,7 @@ def forward( return outputs - -LLAMA_START_DOCSTRING = r""" +AQUILA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -453,7 +482,7 @@ def forward( and behavior. Parameters: - config ([`LlamaConfig`]): + config ([`AquilaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
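When `use_xformers` is off, `AquilaAttention.forward` falls back to the eager path shown above: scaled dot-product scores are clamped to ±1024 before the softmax. Below is a minimal, CPU-runnable sketch of that path with made-up shapes; the real module additionally applies RoPE, key/value caching, grouped-query repetition, and the attention mask.

```python
# Standalone sketch of the eager attention path (shapes are illustrative).
import math
import torch

bsz, num_heads, q_len, head_dim = 2, 4, 16, 64
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.clamp(attn_weights, min=-1024.0, max=1024.0)  # numerical guard
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

attn_output = torch.matmul(attn_weights, value)                    # (bsz, heads, q_len, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)                                           # torch.Size([2, 16, 256])
```

The xformers branch replaces this with `memory_efficient_attention` plus a `LowerTriangularMask`, which requires a CUDA build of xformers.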
@@ -461,14 +490,15 @@ def forward( @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, + "The bare Aquila Model outputting raw hidden-states without any specific head on top.", + AQUILA_START_DOCSTRING, ) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Aquila +class AquilaPreTrainedModel(PreTrainedModel): + config_class = AquilaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] + _no_split_modules = ["AquilaDecoderLayer"] _skip_keys_device_placement = "past_key_values" def _init_weights(self, module): @@ -483,11 +513,11 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): + if isinstance(module, AquilaModel): module.gradient_checkpointing = value -LLAMA_INPUTS_DOCSTRING = r""" +AQUILA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -552,25 +582,26 @@ def _set_gradient_checkpointing(self, module, value=False): @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, + "The bare Aquila Model outputting raw hidden-states without any specific head on top.", + AQUILA_START_DOCSTRING, ) -class LlamaModel(LlamaPreTrainedModel): +# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->AQUILA,Llama->Aquila +class AquilaModel(AquilaPreTrainedModel): """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`AquilaDecoderLayer`] Args: - config: LlamaConfig + config: AquilaConfig """ - def __init__(self, config: LlamaConfig): + def __init__(self, config: AquilaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layers = nn.ModuleList([AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -582,7 +613,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -606,7 +636,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em return combined_attention_mask - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -689,7 +719,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, output_attentions, None) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -698,7 +728,6 @@ def custom_forward(*inputs): hidden_states, attention_mask, position_ids, - None, ) else: layer_outputs = decoder_layer( @@ -734,14 +763,13 @@ def custom_forward(*inputs): attentions=all_self_attns, ) - -class LlamaForCausalLM(LlamaPreTrainedModel): +# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->AQUILA,Llama->Aquila +class AquilaForCausalLM(AquilaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp + self.model = AquilaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -766,7 +794,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -793,18 +821,18 @@ def forward( Example: ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from transformers import AutoTokenizer, AquilaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> model = AquilaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> prompt = "Hey, are you consciours? Can you talk to me?" 
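Inside `AquilaModel.forward`, the gradient-checkpointing closure now forwards `past_key_value` and `output_attentions` into each decoder layer instead of padding the call with a trailing `None` (see the `create_custom_forward` change in this hunk). The toy sketch below reproduces that wrapper pattern with a stand-in layer so the control flow is easy to see; `ToyLayer` and all shapes are invented for illustration.

```python
# Toy reproduction of the checkpointing wrapper pattern used in AquilaModel.forward.
import torch
import torch.utils.checkpoint


class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, hidden_states, attention_mask, position_ids,
                past_key_value=None, output_attentions=False):
        return (self.proj(hidden_states),)


layer = ToyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
attention_mask = position_ids = past_key_value = None
output_attentions = False


def create_custom_forward(module):
    def custom_forward(*inputs):
        # extra arguments come from the enclosing scope, not the checkpoint call
        return module(*inputs, past_key_value, output_attentions)
    return custom_forward


(layer_out,) = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer), hidden_states, attention_mask, position_ids
)
layer_out.sum().backward()  # activations are recomputed during the backward pass
```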
>>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -827,9 +855,9 @@ def forward( ) hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) @@ -899,12 +927,118 @@ def _reorder_cache(past_key_values, beam_idx): ) return reordered_past + def predict(self, text, tokenizer=None, + max_gen_len=200, top_p=0.95, + seed=1234, topk=100, + temperature=0.9, + sft=True, convo_template = "aquila-chat", + device = "cuda"): + + vocab = tokenizer.get_vocab() + #device = device + id2word = {v:k for k, v in vocab.items()} + + + set_random_seed(seed) + if temperature == 0: + topk = 1 + temperature = 1.0 + if sft: + tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=2048, convo_template=convo_template) + tokens = torch.tensor(tokens)[None,].to(device) + else : + tokens = tokenizer.encode_plus(text)["input_ids"] + print(tokenizer.decode(tokens)) + tokens = torch.tensor(tokens)[None,].to(device) + input_length = len(tokens[0]) + with torch.no_grad(): + + # instantiate logits processors + logits_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(1, eos_token_id=100007), + ] + ) + # instantiate logits processors + logits_warper = LogitsProcessorList( + [ + TopPLogitsWarper(top_p), + TopKLogitsWarper(topk), + TemperatureLogitsWarper(temperature), + + ] + ) + + stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=input_length + max_gen_len)]) + out = self.sample( + tokens, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + return_dict_in_generate=True, + output_scores=True, + ) + + + # print(out) + out_ids = out["sequences"][0][input_length:].cpu().numpy() + + out_scores = out["scores"] + + out_scores = torch.cat(out_scores, dim=0) + out_scores = torch.nn.functional.softmax(out_scores, dim=-1).cpu().numpy() + + probs = [] + for i in range(len(out_ids)): + probs.append(float(out_scores[i][out_ids[i]])) + + # print(f"probs is {probs}") + + convert_tokens = [] + for t in out_ids: + if t == 100006: + convert_tokens.append("[CLS]") + else : + convert_tokens.append(id2word.get(t, "[unkonwn_token]")) + + out_text = tokenizer.decode(out_ids.tolist()) + + + out = out_text + + if "###" in out: + special_index = out.index("###") + out = out[: special_index] + token_length = len(tokenizer.encode_plus(out)["input_ids"]) + convert_tokens = convert_tokens[:token_length] + probs = probs[:token_length] + + if "[UNK]" in out: + special_index = out.index("[UNK]") + out = out[:special_index] + token_length = 
len(tokenizer.encode_plus(out)["input_ids"]) + convert_tokens = convert_tokens[:token_length] + probs = probs[:token_length] + + if "" in out: + special_index = out.index("") + out = out[: special_index] + token_length = len(tokenizer.encode_plus(out)["input_ids"]) + convert_tokens = convert_tokens[:token_length] + probs = probs[:token_length] + + if len(out) > 0 and out[0] == " ": + out = out[1:] + + convert_tokens = convert_tokens[1:] + probs = probs[1:] + return out @add_start_docstrings( """ The LLaMa Model transformer with a sequence classification head on top (linear layer). - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + [`AquilaForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a @@ -913,13 +1047,16 @@ def _reorder_cache(past_key_values, beam_idx): padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, - LLAMA_START_DOCSTRING, + AQUILA_START_DOCSTRING, ) -class LlamaForSequenceClassification(LlamaPreTrainedModel): +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->AQUILA,Llama->Aquila +class AquilaForSequenceClassification(AquilaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.model = LlamaModel(config) + self.model = AquilaModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing @@ -931,7 +1068,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -978,7 +1115,9 @@ def forward( sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( + logits.device + ) else: sequence_lengths = -1 @@ -1018,3 +1157,4 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + diff --git a/mft_peft_hf/src/model/baichuan/configuration_baichuan.py b/mftcoder_accelerate/src/model/baichuan2/configuration_baichuan.py similarity index 91% rename from mft_peft_hf/src/model/baichuan/configuration_baichuan.py rename to mftcoder_accelerate/src/model/baichuan2/configuration_baichuan.py index 9da5dfc..592d274 100644 --- a/mft_peft_hf/src/model/baichuan/configuration_baichuan.py +++ b/mftcoder_accelerate/src/model/baichuan2/configuration_baichuan.py @@ -2,6 +2,7 @@ from transformers.configuration_utils import PretrainedConfig + class BaichuanConfig(PretrainedConfig): model_type = "baichuan" keys_to_ignore_at_inference = ["past_key_values"] @@ -23,7 +24,7 @@ def __init__( eos_token_id=2, tie_word_embeddings=False, gradient_checkpointing=False, - use_xformers=False, + z_loss_weight=0, **kwargs, ): self.vocab_size = vocab_size @@ -36,8 +37,8 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache - 
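The `predict()` helper added to `AquilaForCausalLM` above builds its sampling loop from Hugging Face logits processors, warpers, and stopping criteria. The hedged, standalone sketch below shows how those objects are composed and applied to a dummy logits tensor; the numeric values mirror `predict()`'s defaults (`top_p=0.95`, `topk=100`, `temperature=0.9`, EOS id 100007), everything else is illustrative.

```python
# Hedged sketch of the sampling plumbing used by predict(); no model is loaded here.
import torch
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TemperatureLogitsWarper,
    StoppingCriteriaList,
    MaxLengthCriteria,
)

eos_token_id = 100007  # Aquila end-of-sequence id used in predict()
logits_processor = LogitsProcessorList([MinLengthLogitsProcessor(1, eos_token_id=eos_token_id)])
logits_warper = LogitsProcessorList([
    TopPLogitsWarper(0.95),
    TopKLogitsWarper(100),
    TemperatureLogitsWarper(0.9),
])
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=256)])

# Applying the warpers to a dummy next-token distribution:
input_ids = torch.tensor([[1, 2, 3]])
scores = torch.randn(1, 100008)            # vocab_size from AquilaConfig
warped = logits_warper(input_ids, scores)  # low-probability logits are masked to -inf
print(warped.shape)                        # torch.Size([1, 100008])
```

In `predict()` these objects are handed to `self.sample(...)`; with recent transformers versions the usual public entry point is `model.generate(..., do_sample=True)`.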
self.gradient_checkpointing = gradient_checkpointing - self.use_xformers = use_xformers + self.z_loss_weight = z_loss_weight + self.gradient_checkpointing = (gradient_checkpointing,) super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, @@ -45,4 +46,3 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) - diff --git a/mftcoder_accelerate/src/model/baichuan2/generation_utils.py b/mftcoder_accelerate/src/model/baichuan2/generation_utils.py new file mode 100644 index 0000000..5771699 --- /dev/null +++ b/mftcoder_accelerate/src/model/baichuan2/generation_utils.py @@ -0,0 +1,83 @@ +from typing import List +from queue import Queue + +import torch + + +def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0): + def _parse_messages(messages, split_role="user"): + system, rounds = "", [] + round = [] + for i, message in enumerate(messages): + if message["role"] == "system": + assert i == 0 + system = message["content"] + continue + if message["role"] == split_role and round: + rounds.append(round) + round = [] + round.append(message) + if round: + rounds.append(round) + return system, rounds + + max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens + max_input_tokens = model.config.model_max_length - max_new_tokens + system, rounds = _parse_messages(messages, split_role="user") + system_tokens = tokenizer.encode(system) + max_history_tokens = max_input_tokens - len(system_tokens) + + history_tokens = [] + for round in rounds[::-1]: + round_tokens = [] + for message in round: + if message["role"] == "user": + round_tokens.append(model.generation_config.user_token_id) + else: + round_tokens.append(model.generation_config.assistant_token_id) + round_tokens.extend(tokenizer.encode(message["content"])) + if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens: + history_tokens = round_tokens + history_tokens # concat left + if len(history_tokens) < max_history_tokens: + continue + break + + input_tokens = system_tokens + history_tokens + if messages[-1]["role"] != "assistant": + input_tokens.append(model.generation_config.assistant_token_id) + input_tokens = input_tokens[-max_input_tokens:] # truncate left + return torch.LongTensor([input_tokens]).to(model.device) + + +class TextIterStreamer: + def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False): + self.tokenizer = tokenizer + self.skip_prompt = skip_prompt + self.skip_special_tokens = skip_special_tokens + self.tokens = [] + self.text_queue = Queue() + self.next_tokens_are_prompt = True + + def put(self, value): + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + else: + if len(value.shape) > 1: + value = value[0] + self.tokens.extend(value.tolist()) + self.text_queue.put( + self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)) + + def end(self): + self.text_queue.put(None) + + def __iter__(self): + return self + + def __next__(self): + value = self.text_queue.get() + if value is None: + raise StopIteration() + else: + return value + diff --git a/mftcoder_accelerate/src/model/baichuan2/modeling_baichuan.py b/mftcoder_accelerate/src/model/baichuan2/modeling_baichuan.py new file mode 100644 index 0000000..9f9875b --- /dev/null +++ b/mftcoder_accelerate/src/model/baichuan2/modeling_baichuan.py @@ -0,0 +1,827 @@ +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. 
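The new `generation_utils.py` above defines `build_chat_input`, which packs role-tagged messages into Baichuan's user/assistant token layout with left truncation, and `TextIterStreamer`, a simple queue-backed streamer. The sketch below illustrates the expected message format; `model` and `tokenizer` are placeholders for an actual Baichuan2 checkpoint and are only referenced in comments.

```python
# Illustrative only: the role-tagged message format consumed by build_chat_input
# and BaichuanForCausalLM.chat defined above.
messages = [
    {"role": "system", "content": "You are a helpful coding assistant."},
    {"role": "user", "content": "Write a Python function that reverses a string."},
]

# Hypothetical usage once `model` and `tokenizer` have been loaded:
#   input_ids = build_chat_input(model, tokenizer, messages, max_new_tokens=256)
#   reply = model.chat(tokenizer, messages)                 # blocking call
#   for text_so_far in model.chat(tokenizer, messages, stream=True):
#       print(text_so_far)                                  # TextIterStreamer yields the growing decode
```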
+ +from .configuration_baichuan import BaichuanConfig +from .generation_utils import build_chat_input, TextIterStreamer + +import math +from threading import Thread +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F +from transformers import PreTrainedModel, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.generation.utils import GenerationConfig +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging, ContextManagers + +import os +from contextlib import contextmanager +from accelerate import init_empty_weights + +logger = logging.get_logger(__name__) + +try: + from xformers import ops as xops +except ImportError: + xops = None + logger.warning( + "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers." + ) + + +def _get_interleave(n): + def _get_interleave_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + _get_interleave_power_of_2(closest_power_of_2) + + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def _fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + +def _buffered_future_mask(tensor, maxpos, alibi, attn_heads): + _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1) + _future_mask = _future_mask.unsqueeze(0) + alibi + new_future_mask = _future_mask.to(tensor) + return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos] + + +def _gen_alibi_mask(tensor, n_head, max_pos): + slopes = torch.Tensor(_get_interleave(n_head)) + position_point = torch.arange(max_pos) - max_pos + 1 + position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1) + diag = torch.diag(position_point[0]) + position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + alibi = alibi.view(n_head, 1, max_pos) + alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1) + alibi_mask = alibi_mask.unsqueeze(0) + alibi + return alibi_mask + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, epsilon=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(hidden_size)) + self.epsilon = epsilon + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) + + # convert into half-precision + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class MLP(torch.nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = 
ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class BaichuanAttention(torch.nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.model_max_length + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}" + ) + self.W_pack = torch.nn.Linear( + self.hidden_size, 3 * self.hidden_size, bias=False + ) + self.o_proj = torch.nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = ( + proj.unflatten(-1, (3, self.hidden_size)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + ) + query_states = ( + proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + key_states = ( + proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + value_states = ( + proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + if xops is not None: + attn_weights = None + # query_states = query_states.transpose(1, 2) + # key_states = key_states.transpose(1, 2) + # value_states = value_states.transpose(1, 2) + # attn_output = xops.memory_efficient_attention( + # query_states, key_states, value_states, attn_bias=attention_mask + # ) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask) + attn_output = attn_output.transpose(1, 2) + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: + if q_len == 1: # inference with cache + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -1:, :] + else: + attention_mask = attention_mask[:, -1:, :] + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, 
attn_weights, past_key_value + + +class BaichuanLayer(torch.nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = BaichuanAttention(config=config) + self.mlp = MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, epsilon=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BaichuanPreTrainedModel(PreTrainedModel): + config_class = BaichuanConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BaichuanLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, torch.nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, torch.nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BaichuanModel): + module.gradient_checkpointing = value + + +class BaichuanModel(BaichuanPreTrainedModel): + def __init__(self, config: BaichuanConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.n_head = config.num_attention_heads + self.embed_tokens = torch.nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = torch.nn.ModuleList( + [BaichuanLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + + self.gradient_checkpointing = config.gradient_checkpointing + self.post_init() + self.max_cache_pos = config.model_max_length + self.first_run = True + self.alibi_mask = None + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_alibi_mask(self, tensor, seq_length_with_past): + if self.training: + slopes = torch.Tensor(_get_interleave(self.n_head)) + position_point = ( + torch.arange(seq_length_with_past) - seq_length_with_past + 1 + ) + position_point = ( + position_point.unsqueeze(0) + .unsqueeze(0) + .expand(self.n_head, seq_length_with_past, -1) + ) + 
diag = torch.diag(position_point[0]) + position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose( + -1, -2 + ) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + mask = _buffered_future_mask( + tensor, seq_length_with_past, alibi, self.n_head + ) + else: + if self.first_run: + self.first_run = False + self.register_buffer( + "future_mask", + _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to( + tensor + ), + persistent=False, + ) + if seq_length_with_past > self.max_cache_pos: + self.max_cache_pos = seq_length_with_past + self.register_buffer( + "future_mask", + _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to( + tensor + ), + persistent=False, + ) + mask = self.future_mask[ + : self.n_head, :seq_length_with_past, :seq_length_with_past + ] + return mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot provide both input_ids and inputs_embeds simultaneously" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You need to provide input_ids or inputs_embeds") + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + seq_length_with_past = seq_length + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.training: + if ( + self.alibi_mask is None + or self.alibi_mask.shape[-1] != seq_length_with_past + ): + self.alibi_mask = self.get_alibi_mask( + inputs_embeds, seq_length_with_past + ) + alibi_mask = self.alibi_mask + else: + alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) + + if attention_mask is not None: + if len(attention_mask.shape) == 2: + expanded_mask = attention_mask.to(alibi_mask.dtype) + expanded_mask = torch.tril( + torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0) + ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0) + else: + expanded_mask = attention_mask + bsz = inputs_embeds.size(0) + src_len, tgt_len = alibi_mask.size()[-2:] + expanded_mask = ( + expanded_mask.unsqueeze(1) + .expand(bsz, 1, src_len, tgt_len) + .to(alibi_mask.dtype) + ) + inverted_mask = 1.0 - expanded_mask + inverted_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min + ) + attention_mask = inverted_mask + alibi_mask.unsqueeze(0) + else: + attention_mask = alibi_mask + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class NormHead(nn.Module): + def __init__(self, hidden_size, vocab_size, bias=False): + super().__init__() + self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size))) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.first_flag = True + + def forward(self, hidden_states): + if self.training: + norm_weight = nn.functional.normalize(self.weight) + self.first_flag = True + elif self.first_flag: + self.first_flag = False + self.weight.data = nn.functional.normalize(self.weight) + norm_weight = self.weight + else: + norm_weight = self.weight + return nn.functional.linear(hidden_states, norm_weight) + +_init_weights = True +@contextmanager +def no_init_weights(_enable=True): + global _init_weights + old_init_weights = _init_weights + if _enable: + _init_weights = False + try: + yield + finally: + _init_weights = old_init_weights + + +class BaichuanForCausalLM(BaichuanPreTrainedModel): + def __init__(self, config, *model_args, **model_kwargs): + super().__init__(config, *model_args, **model_kwargs) + self.model = BaichuanModel(config) + self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False) + #if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']: + if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False): + try: + from .quantizer import quantize_offline, init_model_weight_int4 + except ImportError: + raise ImportError(f"Needs quantize_offline to run quantize.") + quantize_offline(self, 4) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head 
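The `NormHead` module defined in this hunk replaces a plain `nn.Linear` LM head: during training it projects hidden states with an L2-normalized copy of its weight, and at inference it normalizes the weight once and reuses it. A toy, self-contained sketch of the training-time behaviour (dimensions are made up):

```python
# Standalone sketch of the NormHead idea: project with a row-normalized weight.
import torch
from torch import nn


class TinyNormHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        norm_weight = nn.functional.normalize(self.weight)   # row-wise L2 normalization
        return nn.functional.linear(hidden_states, norm_weight)


head = TinyNormHead(hidden_size=16, vocab_size=32)
logits = head(torch.randn(2, 4, 16))
print(logits.shape)  # torch.Size([2, 4, 32])
```

The production class above additionally caches the normalized weight on the first inference call (`first_flag`) so the normalization is not recomputed per token.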
+ + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=False, + proxies=None, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder="", + _from_auto=False, + _from_pipeline=None, + **kwargs, + ) + else: + model_kwargs = kwargs + + if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']: + try: + from .quantizer import init_model_weight_int4 + from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map + from accelerate.utils import CustomDtype + from accelerate.utils import get_balanced_memory + except ImportError: + raise ImportError(f"Needs import model weight init func to run quantize.") + # Instantiate model. + init_contexts = [no_init_weights(_enable=True)] + init_contexts.append(init_empty_weights()) + with ContextManagers(init_contexts): + model = cls(config) + + model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin') + state_dict = torch.load(model_file, map_location="cpu") + model.is_quantized = True + + device_map = kwargs.pop("device_map", None) + torch_dtype = kwargs.pop("torch_dtype", None) + if device_map is not None: + kwargs = {"no_split_module_classes": model._no_split_modules} + target_dtype = CustomDtype.INT4 + max_memory = get_balanced_memory( + model, + dtype=target_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=None, + **kwargs, + ) + kwargs["max_memory"] = max_memory + device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs) + model = init_model_weight_int4(config, model, state_dict) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=False, + proxies=None, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder="", + _from_auto=False, + _from_pipeline=None, + **kwargs, + ) + except (OSError, TypeError): + logger.info( + "Generation config file not found, using a generation config created from the model config." 
+ ) + pass + + if device_map is not None: + dispatch_model(model, device_map=device_map) + + return model + + return super(BaichuanForCausalLM, cls).from_pretrained(pretrained_model_name_or_path, *model_args, + config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, + use_safetensors=use_safetensors, **kwargs) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + softmax_normalizer = shift_logits.max(-1).values ** 2 + z_loss = self.config.z_loss_weight * softmax_normalizer.mean() + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + z_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def quantize(self, bits: int): + try: + from .quantizer import quantize_online + except ImportError: + raise ImportError(f"Needs QLinear to run quantize.") + return quantize_online(self, bits) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + return tuple( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) + for layer_past in past_key_values + ) + + def _build_chat_input( + self, tokenizer, messages: List[dict], 
max_new_tokens: int = 0 + ): + max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens + max_input_tokens = self.config.model_max_length - max_new_tokens + max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens) + total_input, round_input = [], [] + for i, message in enumerate(messages[::-1]): + content_tokens = tokenizer.encode(message["content"]) + if message["role"] == "user": + round_input = ( + [self.generation_config.user_token_id] + + content_tokens + + round_input + ) + if ( + total_input + and len(total_input) + len(round_input) > max_input_tokens + ): + break + else: + total_input = round_input + total_input + if len(total_input) >= max_input_tokens: + break + else: + round_input = [] + elif message["role"] == "assistant": + round_input = ( + [self.generation_config.assistant_token_id] + + content_tokens + + [self.generation_config.eos_token_id] + + round_input + ) + else: + raise ValueError(f"message role not supported yet: {message['role']}") + total_input = total_input[-max_input_tokens:] # truncate left + total_input.append(self.generation_config.assistant_token_id) + total_input = torch.LongTensor([total_input]).to(self.device) + return total_input + + def chat(self, tokenizer, messages: List[dict], stream=False, + generation_config: Optional[GenerationConfig]=None): + generation_config = generation_config or self.generation_config + input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens) + if stream: + streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + Thread(target=self.generate, kwargs=dict( + inputs=input_ids, streamer=streamer, + generation_config=generation_config, + )).start() + return streamer + else: + outputs = self.generate(input_ids, generation_config=generation_config) + response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True) + return response diff --git a/mftcoder_accelerate/src/model/baichuan2/quantizer.py b/mftcoder_accelerate/src/model/baichuan2/quantizer.py new file mode 100644 index 0000000..232fc74 --- /dev/null +++ b/mftcoder_accelerate/src/model/baichuan2/quantizer.py @@ -0,0 +1,215 @@ +try: + import bitsandbytes as bnb + from bitsandbytes.nn.modules import Params4bit, Int8Params +except ImportError: + print('import bitsandbytes Error') + +from accelerate import init_empty_weights +import torch + +def Params4bitCuda(self, device): + self.data = self.data.cuda(device) + self.quant_state[0] = self.quant_state[0].cuda(device) + self.quant_state[4][0] = self.quant_state[4][0].cuda(device) + self.quant_state[4][1][0] = self.quant_state[4][1][0].cuda(device) + self.quant_state[4][1][1] = self.quant_state[4][1][1].cuda(device) + + self.quant_state[6] = self.quant_state[6].cuda(device) + return self + +class Linear4bitOnline(torch.nn.Module): + def __init__(self, weight, bias, quant_type): + super().__init__() + self.weight = Params4bit( + weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type + ) + self.compute_dtype = None + #self.weight.cuda(weight.device) + self.bias = bias + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, "quant_state", None) is None: + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." 
+ ) + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ) + + out = out.to(inp_dtype) + + return out + +class Linear8bitLtOnline(torch.nn.Module): + def __init__( + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): + super().__init__() + assert ( + not memory_efficient_backward + ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index + + # Necessary for stacked layers + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + weight.data, + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, + ) + self.bias = bias + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + return out + +def quantize_offline(model, bits: int): + assert (bits == 4), f'bits: {bits} is not supported' + + for i, layer in enumerate(model.model.layers): + layer.self_attn.W_pack = bnb.nn.Linear4bit( + layer.self_attn.W_pack.weight.shape[1], + layer.self_attn.W_pack.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + layer.self_attn.o_proj = bnb.nn.Linear4bit( + layer.self_attn.o_proj.weight.shape[1], + layer.self_attn.o_proj.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + + layer.mlp.gate_proj = bnb.nn.Linear4bit( + layer.mlp.gate_proj.weight.shape[1], + layer.mlp.gate_proj.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + layer.mlp.down_proj = bnb.nn.Linear4bit( + layer.mlp.down_proj.weight.shape[1], + layer.mlp.down_proj.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + layer.mlp.up_proj = bnb.nn.Linear4bit( + layer.mlp.up_proj.weight.shape[1], + layer.mlp.up_proj.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + return model + +def quantize_online(model, bits: int): + def quant(weight, bias=None): + if bits == 8: + linear = Linear8bitLtOnline( + weight, + bias, + has_fp16_weights=False, + threshold=6.0, + ) + if bias is not None: + linear.bias = torch.nn.Parameter(bias) + elif bits == 4: + linear = Linear4bitOnline( + weight, + bias, + quant_type="nf4", #fp4/nf4 + ) + else: + raise ValueError("quantize 
only support 4/8 bit") + return linear + + for i, layer in enumerate(model.model.layers): + layer.self_attn.W_pack = quant(layer.self_attn.W_pack.weight) + layer.self_attn.o_proj = quant(layer.self_attn.o_proj.weight) + layer.mlp.gate_proj = quant(layer.mlp.gate_proj.weight) + layer.mlp.down_proj = quant(layer.mlp.down_proj.weight) + layer.mlp.up_proj = quant(layer.mlp.up_proj.weight) + return model + +def init_model_weight_int4(config, model, state_dict): + #replace Params4bit.cuda with Params4bitCuda + Params4bit.cuda = Params4bitCuda + + for i in range(config.num_hidden_layers): + weight_data = state_dict[f'model.layers.{i}.self_attn.W_pack.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.self_attn.W_pack.weight.quant_state'] + model.model.layers[i].self_attn.W_pack.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + weight_data = state_dict[f'model.layers.{i}.self_attn.o_proj.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.self_attn.o_proj.weight.quant_state'] + model.model.layers[i].self_attn.o_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + weight_data = state_dict[f'model.layers.{i}.mlp.gate_proj.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.mlp.gate_proj.weight.quant_state'] + model.model.layers[i].mlp.gate_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + weight_data = state_dict[f'model.layers.{i}.mlp.up_proj.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.mlp.up_proj.weight.quant_state'] + model.model.layers[i].mlp.up_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + weight_data = state_dict[f'model.layers.{i}.mlp.down_proj.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.mlp.down_proj.weight.quant_state'] + model.model.layers[i].mlp.down_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + model.model.layers[i].input_layernorm.weight = state_dict[f'model.layers.{i}.input_layernorm.weight'] + model.model.layers[i].post_attention_layernorm.weight = state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] + + model.model.embed_tokens.weight = state_dict['model.embed_tokens.weight'] + model.model.norm.weight = state_dict['model.norm.weight'] + model.lm_head.weight = state_dict['lm_head.weight'] + return model \ No newline at end of file diff --git a/mft_peft_hf/src/model/baichuan/tokenization_baichuan.py b/mftcoder_accelerate/src/model/baichuan2/tokenization_baichuan.py similarity index 85% rename from mft_peft_hf/src/model/baichuan/tokenization_baichuan.py rename to mftcoder_accelerate/src/model/baichuan2/tokenization_baichuan.py index 1275eb0..978963e 100644 --- a/mft_peft_hf/src/model/baichuan/tokenization_baichuan.py +++ b/mftcoder_accelerate/src/model/baichuan2/tokenization_baichuan.py @@ -48,10 +48,26 @@ def __init__( **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + bos_token = ( + 
AddedToken(bos_token, lstrip=False, rstrip=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False) + if isinstance(pad_token, str) + else pad_token + ) super().__init__( bos_token=bos_token, eos_token=eos_token, @@ -122,7 +138,9 @@ def convert_tokens_to_string(self, tokens): out_string += self.sp_model.decode(current_sub_tokens) return out_string - def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + def save_vocabulary( + self, save_directory, filename_prefix: Optional[str] = None + ) -> Tuple[str]: """ Save the vocabulary and special tokens file to a directory. @@ -137,10 +155,14 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) logger.error(f"Vocabulary path ({save_directory}) should be a directory") return out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -161,7 +183,10 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): return output def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. 
This method is called when adding @@ -180,7 +205,9 @@ def get_special_tokens_mask( """ if already_has_special_tokens: return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, ) bos_token_id = [1] if self.add_bos_token else [] @@ -229,4 +256,3 @@ def create_token_type_ids_from_sequences( output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) return output - diff --git a/mft_peft_hf/src/model/chatglm2/config.json b/mftcoder_accelerate/src/model/chatglm2/config.json similarity index 100% rename from mft_peft_hf/src/model/chatglm2/config.json rename to mftcoder_accelerate/src/model/chatglm2/config.json diff --git a/mft_peft_hf/src/model/chatglm2/configuration_chatglm.py b/mftcoder_accelerate/src/model/chatglm2/configuration_chatglm.py similarity index 100% rename from mft_peft_hf/src/model/chatglm2/configuration_chatglm.py rename to mftcoder_accelerate/src/model/chatglm2/configuration_chatglm.py diff --git a/mft_peft_hf/src/model/chatglm2/modeling_chatglm.py b/mftcoder_accelerate/src/model/chatglm2/modeling_chatglm.py similarity index 100% rename from mft_peft_hf/src/model/chatglm2/modeling_chatglm.py rename to mftcoder_accelerate/src/model/chatglm2/modeling_chatglm.py diff --git a/mft_peft_hf/src/model/chatglm2/quantization.py b/mftcoder_accelerate/src/model/chatglm2/quantization.py similarity index 100% rename from mft_peft_hf/src/model/chatglm2/quantization.py rename to mftcoder_accelerate/src/model/chatglm2/quantization.py diff --git a/mft_peft_hf/src/model/chatglm2/tokenization_chatglm.py b/mftcoder_accelerate/src/model/chatglm2/tokenization_chatglm.py similarity index 99% rename from mft_peft_hf/src/model/chatglm2/tokenization_chatglm.py rename to mftcoder_accelerate/src/model/chatglm2/tokenization_chatglm.py index d4ce416..01d8278 100644 --- a/mft_peft_hf/src/model/chatglm2/tokenization_chatglm.py +++ b/mftcoder_accelerate/src/model/chatglm2/tokenization_chatglm.py @@ -66,7 +66,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask", "position_ids"] def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs): - super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs) + self.name = "GLMTokenizer" self.vocab_file = vocab_file @@ -76,6 +76,7 @@ def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces "": self.tokenizer.eos_id, "": self.tokenizer.pad_id } + super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs) def get_command(self, token): if token in self.special_tokens: diff --git a/mft_peft_hf/src/model/chatglm2/tokenizer_config.json b/mftcoder_accelerate/src/model/chatglm2/tokenizer_config.json similarity index 100% rename from mft_peft_hf/src/model/chatglm2/tokenizer_config.json rename to mftcoder_accelerate/src/model/chatglm2/tokenizer_config.json diff --git a/mftcoder_accelerate/src/model/chatglm3/config.json b/mftcoder_accelerate/src/model/chatglm3/config.json new file mode 100644 index 0000000..37933c8 --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/config.json @@ -0,0 +1,42 @@ +{ + "_name_or_path": "THUDM/chatglm3-6b", + "model_type": "chatglm", + "architectures": [ + "ChatGLMModel" + ], + "auto_map": { + "AutoConfig": "configuration_chatglm.ChatGLMConfig", + "AutoModel": 
"modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification" + }, + "add_bias_linear": false, + "add_qkv_bias": true, + "apply_query_key_layer_scaling": true, + "apply_residual_connection_post_layernorm": false, + "attention_dropout": 0.0, + "attention_softmax_in_fp32": true, + "bias_dropout_fusion": true, + "ffn_hidden_size": 13696, + "fp32_residual_connection": false, + "hidden_dropout": 0.0, + "hidden_size": 4096, + "kv_channels": 128, + "layernorm_epsilon": 1e-05, + "multi_query_attention": true, + "multi_query_group_num": 2, + "num_attention_heads": 32, + "num_layers": 28, + "original_rope": true, + "padded_vocab_size": 65024, + "post_layer_norm": true, + "rmsnorm": true, + "seq_length": 8192, + "use_cache": true, + "torch_dtype": "float16", + "transformers_version": "4.30.2", + "tie_word_embeddings": false, + "eos_token_id": 2, + "pad_token_id": 0 +} \ No newline at end of file diff --git a/mftcoder_accelerate/src/model/chatglm3/configuration_chatglm.py b/mftcoder_accelerate/src/model/chatglm3/configuration_chatglm.py new file mode 100644 index 0000000..3560018 --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/configuration_chatglm.py @@ -0,0 +1,61 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + classifier_dropout=None, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.classifier_dropout = classifier_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) \ No newline at end of file diff --git a/mftcoder_accelerate/src/model/chatglm3/modeling_chatglm.py 
b/mftcoder_accelerate/src/model/chatglm3/modeling_chatglm.py new file mode 100644 index 0000000..e75568b --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/modeling_chatglm.py @@ -0,0 +1,1293 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any +from copy import deepcopy + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" +_CONFIG_FOR_DOC = "ChatGLMConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm3-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. 
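+    # Below, the tensor is cut into `num_partitions` equal chunks along its final
+    # dimension (which must be evenly divisible). In this file the helper is used on the
+    # fused QKV projection: a [sq, b, np, 3 * hn] tensor becomes three [sq, b, np, hn] chunks.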
+ last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = 
config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, + device=query_layer.device + ) + + # Raw attention scores. 
[b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], + device=attention_scores.device, dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. 
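+        # With multi_query_attention (grouped key/value heads), the fused QKV projection
+        # keeps num_attention_heads query heads but only multi_query_group_num key/value
+        # heads, so qkv_hidden_size = projection_size
+        #                             + 2 * hidden_size_per_attention_head * multi_query_group_num.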
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, **_config_to_kwargs(config) + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, + device=device, **_config_to_kwargs(config) + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = 
key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # Self attention. 
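+        # Block layout: pre-norm (RMSNorm or LayerNorm) -> self-attention -> dropout + residual ->
+        # post-attention norm -> SwiGLU MLP -> dropout + residual. When
+        # apply_residual_connection_post_layernorm is set, the residual is taken from the
+        # layernorm output instead of the block input.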
+ self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, + device=input_ids.device), full_attention_mask), dim=-1) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. 
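+        # (Casting here keeps the residual stream in fp32, intended to trade memory for
+        # numerical stability when config.fp32_residual_connection is enabled.)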
+ if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, + dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, + dtype=config.torch_dtype, **init_kwargs) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask], dim=-1) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary 
positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + is_first_forward: bool = True, + **kwargs + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + if past_key_values is not None: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + "use_cache": use_cache + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: 
Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. 
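+        Note: the ChatGLM KV cache is stored sequence-first ([seq_len, batch * beams, ...]),
+        so beams are gathered along dim 1 via index_select(1, beam_idx) rather than dim 0
+        as in batch-first caches.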
+ """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, output, history): + content = "" + history = deepcopy(history) + for response in output.split("<|assistant|>"): + metadata, content = response.split("\n", maxsplit=1) + if not metadata.strip(): + content = content.strip() + history.append({"role": "assistant", "metadata": metadata, "content": content}) + content = content.replace("[[训练时间]]", "2023年") + else: + history.append({"role": "assistant", "metadata": metadata, "content": content}) + if history[0]["role"] == "system" and "tools" in history[0]: + content = "\n".join(content.split("\n")[1:-1]) + def tool_call(**kwargs): + return kwargs + parameters = eval(content) + content = {"name": metadata.strip(), "parameters": parameters} + else: + content = {"name": metadata.strip(), "content": content} + return content, history + + @torch.inference_mode() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user", + max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + inputs = tokenizer.build_chat_input(query, history=history, role=role) + inputs = inputs.to(self.device) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + history.append({"role": role, "content": query}) + response, history = self.process_response(response, history) + return response, history + + @torch.inference_mode() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user", + past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, + logits_processor=None, return_past_key_values=False, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if past_key_values is None: + inputs = tokenizer.build_chat_input(query, history=history, role=role) + else: + inputs = tokenizer.build_chat_input(query, role=role) + inputs = inputs.to(self.device) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs['attention_mask'] = attention_mask + history.append({"role": role, "content": query}) + for outputs 
in self.stream_generate(**inputs, past_key_values=past_key_values, + eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, + **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response, new_history = self.process_response(response, history) + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.inference_mode() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + model_kwargs["use_cache"] = generation_config.use_cache + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. 
Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, + **kwargs) + return self + + +class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.num_labels = config.num_labels + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + + self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) + if config.classifier_dropout is not None: + self.dropout = nn.Dropout(config.classifier_dropout) + else: + self.dropout = None + self.config = config + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: 
Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + full_attention_mask=full_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + pooled_hidden_states = hidden_states[-1] + if self.dropout is not None: + pooled_hidden_states = self.dropout(pooled_hidden_states) + logits = self.classifier_head(pooled_hidden_states) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze().float(), labels.squeeze()) + else: + loss = loss_fct(logits.float(), labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/mftcoder_accelerate/src/model/chatglm3/quantization.py b/mftcoder_accelerate/src/model/chatglm3/quantization.py new file mode 100644 index 0000000..cb95bfe --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/quantization.py @@ -0,0 +1,188 @@ +from torch.nn import Linear +from torch.nn.parameter import Parameter + +import bz2 +import torch +import base64 +import ctypes +from transformers.utils import logging + +from typing import List +from functools import partial + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = 
"$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqy
sTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2euno
r+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + "int4WeightCompression", + "int4WeightExtractionFloat", + "int4WeightExtractionHalf", + "int8WeightExtractionFloat", + "int8WeightExtractionHalf", + ], + ) +except Exception as exception: + kernels = None + logger.warning("Failed to load cpm_kernels:" + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features,))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): + assert scale_list.dtype in [torch.half, torch.bfloat16] + assert weight.dtype in [torch.int8] + if source_bit_width == 8: + return weight.to(scale_list.dtype) * scale_list[:, None] + elif source_bit_width == 4: + func = ( + kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16 + ) + else: + assert False, "Unsupported bit-width" + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(torch.nn.Module): + def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args, + **kwargs): + super().__init__() + self.weight_bit_width = weight_bit_width + + shape = weight.shape + + if weight is None or empty_init: + self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device) + self.weight_scale = 
torch.empty(shape[0], dtype=dtype, device=device) + else: + self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) + self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter(self.weight.to(device), requires_grad=False) + self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False) + self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, device=None): + """Replace fp16 linear with quantized linear""" + for layer in model.layers: + layer.self_attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()), + bias=layer.self_attention.query_key_value.bias, + dtype=layer.self_attention.query_key_value.weight.dtype, + device=layer.self_attention.query_key_value.weight.device if device is None else device, + empty_init=empty_init + ) + layer.self_attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()), + bias=layer.self_attention.dense.bias, + dtype=layer.self_attention.dense.weight.dtype, + device=layer.self_attention.dense.weight.device if device is None else device, + empty_init=empty_init + ) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), + bias=layer.mlp.dense_h_to_4h.bias, + dtype=layer.mlp.dense_h_to_4h.weight.dtype, + device=layer.mlp.dense_h_to_4h.weight.device if device is None else device, + empty_init=empty_init + ) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), + bias=layer.mlp.dense_4h_to_h.bias, + dtype=layer.mlp.dense_4h_to_h.weight.dtype, + device=layer.mlp.dense_4h_to_h.weight.device if device is None else device, + empty_init=empty_init + ) + + return model diff --git a/mftcoder_accelerate/src/model/chatglm3/tokenization_chatglm.py b/mftcoder_accelerate/src/model/chatglm3/tokenization_chatglm.py new file mode 100644 index 0000000..e50d329 --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/tokenization_chatglm.py @@ -0,0 +1,283 @@ +import json +import os +import torch +from typing import List, Optional, Union, Dict +from sentencepiece import SentencePieceProcessor +from transformers import PreTrainedTokenizer +from transformers.utils import logging, PaddingStrategy +from transformers.tokenization_utils_base import EncodedInput, BatchEncoding + + +class SPTokenizer: + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.unk_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop", "<|system|>", "<|user|>", 
"<|assistant|>", + "<|observation|>"] + self.special_tokens = {} + self.index_special_tokens = {} + for token in special_tokens: + self.special_tokens[token] = self.n_words + self.index_special_tokens[self.n_words] = token + self.n_words += 1 + + def tokenize(self, s: str): + return self.sp_model.EncodeAsPieces(s) + + def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]: + assert type(s) is str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + text, buffer = "", [] + for token in t: + if token in self.index_special_tokens: + if buffer: + text += self.sp_model.decode(buffer) + buffer = [] + text += self.index_special_tokens[token] + else: + buffer.append(token) + if buffer: + text += self.sp_model.decode(buffer) + return text + + def decode_tokens(self, tokens: List[str]) -> str: + text = self.sp_model.DecodePieces(tokens) + return text + + def convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + if token in self.special_tokens: + return self.special_tokens[token] + return self.sp_model.PieceToId(token) + + def convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.index_special_tokens: + return self.index_special_tokens[index] + if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0: + return "" + return self.sp_model.IdToPiece(index) + + +class ChatGLMTokenizer(PreTrainedTokenizer): + vocab_files_names = {"vocab_file": "tokenizer.model"} + + model_input_names = ["input_ids", "attention_mask", "position_ids"] + + def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs): + self.name = "GLMTokenizer" + + self.vocab_file = vocab_file + self.tokenizer = SPTokenizer(vocab_file) + self.special_tokens = { + "": self.tokenizer.bos_id, + "": self.tokenizer.eos_id, + "": self.tokenizer.pad_id + } + super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs) + + def get_command(self, token): + if token in self.special_tokens: + return self.special_tokens[token] + assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}" + return self.tokenizer.special_tokens[token] + + @property + def unk_token(self) -> str: + return "" + + @property + def pad_token(self) -> str: + return "" + + @property + def pad_token_id(self): + return self.get_command("") + + @property + def eos_token(self) -> str: + return "" + + @property + def eos_token_id(self): + return self.get_command("") + + @property + def vocab_size(self): + return self.tokenizer.n_words + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + return self.tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. 
""" + return self.tokenizer.convert_token_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.tokenizer.decode_tokens(tokens) + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, self.vocab_files_names["vocab_file"] + ) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, "wb") as writer: + writer.write(proto_str) + + return (vocab_file,) + + def get_prefix_tokens(self): + prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")] + return prefix_tokens + + def build_single_message(self, role, metadata, message): + assert role in ["system", "user", "assistant", "observation"], role + role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n") + message_tokens = self.tokenizer.encode(message) + tokens = role_tokens + message_tokens + return tokens + + def build_chat_input(self, query, history=None, role="user"): + if history is None: + history = [] + input_ids = [] + for item in history: + content = item["content"] + if item["role"] == "system" and "tools" in item: + content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False) + input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content)) + input_ids.extend(self.build_single_message(role, "", query)) + input_ids.extend([self.get_command("<|assistant|>")]) + return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + prefix_tokens = self.get_prefix_tokens() + token_ids_0 = prefix_tokens + token_ids_0 + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("")] + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). 
+ max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * seq_length + + if "position_ids" not in encoded_inputs: + encoded_inputs["position_ids"] = list(range(seq_length)) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/mft_peft_hf/src/model/code_llama/__init__.py b/mftcoder_accelerate/src/model/code_llama/__init__.py similarity index 51% rename from mft_peft_hf/src/model/code_llama/__init__.py rename to mftcoder_accelerate/src/model/code_llama/__init__.py index cfeaa6b..c2d50fb 100644 --- a/mft_peft_hf/src/model/code_llama/__init__.py +++ b/mftcoder_accelerate/src/model/code_llama/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# Copyright 2023 MetaAI and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +13,10 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from transformers.utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_sentencepiece_available, - is_tokenizers_available, - is_torch_available, -) +from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_tokenizers_available -_import_structure = { - "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"], -} +_import_structure = {} try: if not is_sentencepiece_available(): @@ -32,7 +24,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["tokenization_llama"] = ["LlamaTokenizer"] + _import_structure["tokenization_code_llama"] = ["CodeLlamaTokenizer"] try: if not is_tokenizers_available(): @@ -40,32 +32,16 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"] - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_llama"] = [ - "LlamaForCausalLM", - "LlamaModel", - "LlamaPreTrainedModel", - "LlamaForSequenceClassification", - ] - + _import_structure["tokenization_code_llama_fast"] = ["CodeLlamaTokenizerFast"] if TYPE_CHECKING: - from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig - try: if not is_sentencepiece_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: pass else: - from .tokenization_llama import LlamaTokenizer + from .tokenization_code_llama import CodeLlamaTokenizer try: if not is_tokenizers_available(): @@ -73,16 +49,7 @@ except OptionalDependencyNotAvailable: pass else: - from .tokenization_llama_fast import LlamaTokenizerFast - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel - + from .tokenization_code_llama_fast import CodeLlamaTokenizerFast else: import sys diff --git a/mft_peft_hf/src/model/code_llama/configuration_llama.py b/mftcoder_accelerate/src/model/code_llama/configuration_llama.py similarity index 95% rename from mft_peft_hf/src/model/code_llama/configuration_llama.py rename to mftcoder_accelerate/src/model/code_llama/configuration_llama.py index 430bf9f..2c914c8 100644 --- a/mft_peft_hf/src/model/code_llama/configuration_llama.py +++ b/mftcoder_accelerate/src/model/code_llama/configuration_llama.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" LLaMA model configuration""" +"""4.33.1 LLaMA model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -66,8 +66,8 @@ class LlamaConfig(PretrainedConfig): hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. 
initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-12): @@ -77,6 +77,8 @@ class LlamaConfig(PretrainedConfig): relevant if `config.is_decoder=True`. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format @@ -121,6 +123,7 @@ def __init__( eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, + rope_theta=10000.0, rope_scaling=None, use_xformers=True, **kwargs, @@ -133,7 +136,6 @@ def __init__( self.num_attention_heads = num_attention_heads self.use_xformers = use_xformers - # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -144,6 +146,7 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.pretraining_tp = pretraining_tp self.use_cache = use_cache + self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() @@ -164,14 +167,14 @@ def _rope_scaling_validation(self): if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " f"got {self.rope_scaling}" ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/mft_peft_hf/src/model/code_llama/convert_llama_weights_to_hf.py b/mftcoder_accelerate/src/model/code_llama/convert_llama_weights_to_hf.py similarity index 90% rename from mft_peft_hf/src/model/code_llama/convert_llama_weights_to_hf.py rename to mftcoder_accelerate/src/model/code_llama/convert_llama_weights_to_hf.py index 03c1eb4..acc4988 100644 --- a/mft_peft_hf/src/model/code_llama/convert_llama_weights_to_hf.py +++ b/mftcoder_accelerate/src/model/code_llama/convert_llama_weights_to_hf.py @@ -53,18 +53,12 @@ come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
""" -INTERMEDIATE_SIZE_MAP = { - "7B": 11008, - "13B": 13824, - "30B": 17920, - "65B": 22016, - "70B": 28672, -} NUM_SHARDS = { "7B": 1, "7Bf": 1, "13B": 2, "13Bf": 2, + "34B": 4, "30B": 4, "65B": 8, "70B": 8, @@ -86,7 +80,11 @@ def write_json(text, path): json.dump(text, f) -def write_model(model_path, input_base_path, model_size, safe_serialization=True): +def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True): + # for backward compatibility, before you needed the repo to be called `my_repo/model_size` + if not os.path.isfile(os.path.join(input_base_path, "params.json")): + input_base_path = os.path.join(input_base_path, model_size) + os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") os.makedirs(tmp_model_path, exist_ok=True) @@ -98,8 +96,18 @@ def write_model(model_path, input_base_path, model_size, safe_serialization=True n_heads_per_shard = n_heads // num_shards dim = params["dim"] dims_per_head = dim // n_heads - base = 10000.0 + base = params.get("rope_theta", 10000.0) inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + if base > 10000.0: + max_position_embeddings = 16384 + else: + max_position_embeddings = 2048 + + tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + if tokenizer_path is not None: + tokenizer = tokenizer_class(tokenizer_path) + tokenizer.save_pretrained(model_path) + vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000 if "n_kv_heads" in params: num_key_value_heads = params["n_kv_heads"] # for GQA / MQA @@ -247,6 +255,9 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): num_hidden_layers=params["n_layers"], rms_norm_eps=params["norm_eps"], num_key_value_heads=num_key_value_heads, + vocab_size=vocab_size, + rope_theta=base, + max_position_embeddings=max_position_embeddings, ) config.save_pretrained(tmp_model_path) @@ -256,10 +267,10 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): gc.collect() print("Loading the checkpoint in a Llama model.") - model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) # Avoid saving this as part of the config. del model.config._name_or_path - + model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path) @@ -281,7 +292,7 @@ def main(): ) parser.add_argument( "--model_size", - choices=["7B", "7Bf", "13B", "13Bf", "30B", "65B", "70B", "70Bf", "tokenizer_only"], + choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"], help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. 
For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", ) parser.add_argument( @@ -290,15 +301,17 @@ def main(): ) parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") args = parser.parse_args() + spm_path = os.path.join(args.input_dir, "tokenizer.model") if args.model_size != "tokenizer_only": write_model( model_path=args.output_dir, - input_base_path=os.path.join(args.input_dir, args.model_size), + input_base_path=args.input_dir, model_size=args.model_size, safe_serialization=args.safe_serialization, + tokenizer_path=spm_path, ) - spm_path = os.path.join(args.input_dir, "tokenizer.model") - write_tokenizer(args.output_dir, spm_path) + else: + write_tokenizer(args.output_dir, spm_path) if __name__ == "__main__": diff --git a/mft_peft_hf/src/model/code_llama/modeling_llama.py b/mftcoder_accelerate/src/model/code_llama/modeling_llama.py similarity index 97% rename from mft_peft_hf/src/model/code_llama/modeling_llama.py rename to mftcoder_accelerate/src/model/code_llama/modeling_llama.py index fde3a9c..7dea342 100644 --- a/mft_peft_hf/src/model/code_llama/modeling_llama.py +++ b/mftcoder_accelerate/src/model/code_llama/modeling_llama.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch LLaMA model.""" +"""4.33.1 PyTorch LLaMA model.""" import math from typing import List, Optional, Tuple, Union @@ -247,6 +247,7 @@ def __init__(self, config: LlamaConfig): self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta self.use_xformers = config.use_xformers if (self.head_dim * self.num_heads) != self.hidden_size: @@ -262,17 +263,27 @@ def __init__(self, config: LlamaConfig): def _init_rope(self): if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") @@ -339,8 +350,9 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask(), - op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) + attn_output = 
xformers.ops.memory_efficient_attention(query_states, key_states, value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) attn_output = attn_output.contiguous().view(bsz, q_len, -1) else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) diff --git a/mft_peft_hf/src/model/llama2/tokenization_llama.py b/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama.py similarity index 50% rename from mft_peft_hf/src/model/llama2/tokenization_llama.py rename to mftcoder_accelerate/src/model/code_llama/tokenization_code_llama.py index dd1c936..0c99a02 100644 --- a/mft_peft_hf/src/model/llama2/tokenization_llama.py +++ b/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama.py @@ -1,10 +1,6 @@ # coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2023 MetaAI and the HuggingFace Inc. team. All rights reserved. # -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,15 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for LLaMA.""" +"""4.33.1 Tokenization classes for Code LLaMA.""" import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import sentencepiece as spm +from transformers.convert_slow_tokenizer import import_protobuf from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer -from transformers.utils import logging +from transformers.utils import logging, requires_backends if TYPE_CHECKING: @@ -38,14 +35,14 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-code-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-code-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "hf-internal-testing/llama-tokenizer": 2048, + "hf-internal-testing/llama-code-tokenizer": 2048, } SPIECE_UNDERLINE = "▁" @@ -53,46 +50,71 @@ B_SYS, E_SYS = "<>\n", "\n<>\n\n" # fmt: off -DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\ +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ -that your responses are socially unbiased and positive in nature. + that your responses are socially unbiased and positive in nature. 
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\ +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ correct. If you don't know the answer to a question, please don't share false information.""" # fmt: on -class LlamaTokenizer(PreTrainedTokenizer): +class CodeLlamaTokenizer(PreTrainedTokenizer): """ - Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is - no padding token in the original model. + Construct a CodeLlama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as + there is no padding token in the original model. + + The default configuration match that of + [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json) + which supports prompt infilling. Args: vocab_file (`str`): Path to the vocabulary file. - legacy (`bool`, *optional*, defaults to `True`): - Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622 - which includes fixes to properly handle tokens that appear after special tokens. A simple example: - - - `legacy=True`: - ```python - >>> from transformers import T5Tokenizer - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True) - >>> tokenizer.encode("Hello .") - [8774, 32099, 3, 5, 1] - ``` - - `legacy=False`: - ```python - >>> from transformers import T5Tokenizer - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False) - >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here - [8774, 32099, 5, 1] - ``` - Checkout the pull request and the issue [here](https://github.com/huggingface/transformers/pull/24565) for - more details. - + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + prefix_token (`str`, *optional*, defaults to `"▁
"`):
+            Prefix token used for infilling.
+        suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
+            Suffix token used for infilling.
+        middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
+            Middle token used for infilling.
+        eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
+            End of text token used for infilling.
+        fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
+            The token used to split the input between the prefix and suffix.
+        suffix_first (`bool`, *optional*, defaults to `False`):
+            Whether the input prompt and suffix should be formatted with the suffix first.
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for Llama should be used.
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
@@ -106,99 +128,184 @@ def __init__(
         unk_token="<unk>",
         bos_token="<s>",
         eos_token="</s>",
-        pad_token=None,
+        prefix_token="▁<PRE>",
+        middle_token="▁<MID>",
+        suffix_token="▁<SUF>",
+        eot_token="▁<EOT>",
+        fill_token="<FILL_ME>",
+        suffix_first=False,
         sp_model_kwargs: Optional[Dict[str, Any]] = None,
         add_bos_token=True,
         add_eos_token=False,
         clean_up_tokenization_spaces=False,
-        legacy=True,
+        additional_special_tokens=None,
+        use_default_system_prompt=False,
         **kwargs,
     ):
+        requires_backends(self, "protobuf")
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
         bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
         eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
         unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
-        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+        self.use_default_system_prompt = use_default_system_prompt
+        # mark tokens special to skip them
+        additional_special_tokens = additional_special_tokens or []
+        for token in [prefix_token, middle_token, suffix_token, eot_token]:
+            additional_special_tokens += [token] if token is not None else []
+
         super().__init__(
             bos_token=bos_token,
             eos_token=eos_token,
             unk_token=unk_token,
-            pad_token=pad_token,
             add_bos_token=add_bos_token,
             add_eos_token=add_eos_token,
+            prefix_token=prefix_token,
+            middle_token=middle_token,
+            suffix_token=suffix_token,
+            eot_token=eot_token,
+            fill_token=fill_token,
             sp_model_kwargs=self.sp_model_kwargs,
+            suffix_first=suffix_first,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            legacy=legacy,
+            additional_special_tokens=additional_special_tokens,
+            use_default_system_prompt=use_default_system_prompt,
             **kwargs,
         )
-        if legacy:
-            logger.warning_once(
-                f"You are using the legacy behaviour of the {self.__class__}. This means that tokens that come after special tokens will not be properly handled. We recommend you to"
-                " read the related pull request available at https://github.com/huggingface/transformers/pull/24565"
-            )
-        self.legacy = legacy
         self.vocab_file = vocab_file
         self.add_bos_token = add_bos_token
         self.add_eos_token = add_eos_token
-        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-        self.sp_model.Load(vocab_file)
+        self._prefix_token = prefix_token
+        self._middle_token = middle_token
+        self._suffix_token = suffix_token
+        self._eot_token = eot_token
+        self.fill_token = fill_token
+        self.suffix_first = suffix_first
+        self.sp_model = self.get_spm_processor()
 
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        state["sp_model"] = None
-        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
-        return state
+    @property
+    def unk_token_length(self):
+        return len(self.sp_model.encode(str(self.unk_token)))
+
+    def get_spm_processor(self):
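+        # Note: the serialized sentencepiece model is parsed with protobuf so `add_dummy_prefix` can be
+        # disabled; sentencepiece then never injects its own leading "▁", and the prefix space is handled
+        # explicitly in `tokenize`/`_tokenize` below.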
+        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        with open(self.vocab_file, "rb") as f:
+            sp_model = f.read()
+            model_pb2 = import_protobuf()
+            model = model_pb2.ModelProto.FromString(sp_model)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
+            sp_model = model.SerializeToString()
+            tokenizer.LoadFromSerializedProto(sp_model)
+        return tokenizer
 
-    def __setstate__(self, d):
-        self.__dict__ = d
-        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+    @property
+    def prefix_token(self):
+        return self._prefix_token
+
+    @property
+    def prefix_id(self):
+        if self._prefix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.prefix_token)
+
+    @property
+    def middle_token(self):
+        return self._middle_token
+
+    @property
+    def middle_id(self):
+        if self._middle_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.middle_token)
+
+    @property
+    def suffix_token(self):
+        return self._suffix_token
+
+    @property
+    def suffix_id(self):
+        if self._suffix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.suffix_token)
+
+    @property
+    def eot_token(self):
+        return self._eot_token
+
+    @property
+    def eot_id(self):
+        if self._eot_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.eot_token)
 
     @property
     def vocab_size(self):
         """Returns vocab size"""
         return self.sp_model.get_piece_size()
 
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
     def get_vocab(self):
         """Returns vocab as a dict"""
         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
         vocab.update(self.added_tokens_encoder)
         return vocab
 
-    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
-    def tokenize(self, text, **kwargs) -> List[str]:
-        # Replace the SPIECE_UNDERLINE with a space to make sure SPIECE_UNDERLINE is only used at
-        # the beginning of the text
-        if not self.legacy:
-            text = SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " ")
-        return super().tokenize(text, **kwargs)
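+    # Illustrative note (a rough sketch of the behaviour implemented below): when `fill_token`, e.g.
+    # "<FILL_ME>", appears in the input and no explicit suffix is passed, the prompt is split on it, so
+    # tokenize("def f(x):<FILL_ME>    return x") yields roughly
+    # [prefix_token] + tokens("def f(x):") + [suffix_token] + tokens("    return x") + [middle_token].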
+    def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> List[str]:
+        # add a prefix space to `prefix`
+        if self.fill_token is not None and self.fill_token in prefix and suffix is None:
+            prefix, suffix = prefix.split(self.fill_token)
+
+        if len(prefix) > 0:
+            prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")
+
+        if suffix is None or len(suffix) < 1:
+            tokens = super().tokenize(prefix, **kwargs)
+            if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+                tokens = tokens[1:]
+            return tokens
+
+        prefix_tokens = self._tokenize(prefix)  # prefix has an extra `SPIECE_UNDERLINE`
+
+        if None in (self.prefix_id, self.middle_id, self.suffix_id):
+            raise ValueError(
+                "The input either includes a `prefix` and a `suffix` used for the infilling task,"
+                f"  or can be split on the {self.fill_token} token, creating a suffix and prefix,"
+                " but the model does not support `infilling`."
+            )
+        suffix_tokens = self._tokenize(suffix)  # make sure CodeLlama sp model does not mess up
 
-    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
-    def _tokenize(self, text):
+        suffix_first = suffix_first if suffix_first is not None else self.suffix_first
+        if suffix_first:
+            # format as " 
 {suf}  {pre}"
+            return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
+        else:
+            # format as " 
 {pre} {suf} "
+            return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]
+
+    def _tokenize(self, text, **kwargs):
         """
         Returns a tokenized string.
 
-        Since the sentencepiece internal model always adds a SPIECE_UNDERLINE, at the beginning of the provided text,
-        we need to remove it by hand when the current text is a subsequence. This happens whenever the `self.tokenize`
-        function is called with specials tokens: the input is split on the special tokens, and each subsequence is
-        passed to `_tokenize`. Thus if a subsequence did not start with a `" "` or SPIECE_UNDERLINE, we have to remove
-        the extra `SPIECE_UNDERLINE` prepended.
+        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
+        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
         """
-        if not self.legacy:
-            is_first = text.startswith(SPIECE_UNDERLINE)
-            if is_first:
-                text = text[1:]
-
         tokens = self.sp_model.encode(text, out_type=str)
+        if not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
+        # 1. Encode string + prefix ex: "<unk> Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
 
-        if not self.legacy and not is_first and not text.startswith(" ") and tokens[0].startswith(SPIECE_UNDERLINE):
-            tokens = ([tokens[0][1:]] if len(tokens[0]) > 1 else []) + tokens[1:]
-        return tokens
-
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         return self.sp_model.piece_to_id(token)
 
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
         token = self.sp_model.IdToPiece(index)
@@ -206,23 +313,23 @@ def _convert_id_to_token(self, index):
 
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
+        # since we manually add the prefix space, we have to remove it when decoding
+        if tokens[0].startswith(SPIECE_UNDERLINE):
+            tokens[0] = tokens[0][1:]
+
         current_sub_tokens = []
         out_string = ""
-        prev_is_special = False
-        for i, token in enumerate(tokens):
+        for _, token in enumerate(tokens):
             # make sure that special tokens are not decoded using sentencepiece model
             if token in self.all_special_tokens:
-                if not prev_is_special and i != 0:
-                    out_string += " "
                 out_string += self.sp_model.decode(current_sub_tokens) + token
-                prev_is_special = True
                 current_sub_tokens = []
             else:
                 current_sub_tokens.append(token)
-                prev_is_special = False
         out_string += self.sp_model.decode(current_sub_tokens)
         return out_string
 
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
         """
         Save the vocabulary and special tokens file to a directory.
@@ -250,6 +357,7 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None)
 
         return (out_vocab_file,)
 
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
@@ -261,6 +369,7 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
 
         return output
 
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
     ) -> List[int]:
@@ -298,6 +407,7 @@ def get_special_tokens_mask(
             + eos_token_id
         )
 
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
     def create_token_type_ids_from_sequences(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
@@ -332,7 +442,7 @@ def create_token_type_ids_from_sequences(
         return output
 
     def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        """Builds the input ids for a conversation.
+        r"""Builds the input ids for a conversation.
         This is the format used in the provided examples. System prompts should be manually added at the beginning of
         the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
         ```
@@ -346,8 +456,8 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in
         >>> from transformers import Conversation
 
         >>> Conversation(
-        ...     "<>\n Only answer with emojis, and charades\n<>\n\nHow can I build a house in 10 septs?"
-        ... )
+        ...     "<>\n Complete the functions without any documentation\n<>\n\n `def remove_non_ascii(s: str) -> str:`"
+        ... )  # doctest: +IGNORE_RESULT
         ```
         Args:
             conversation (`Conversation`):
@@ -356,6 +466,21 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in
             `List[int]`:
                 Input ids for the conversation.
         """
+        if self.use_default_system_prompt:
+            if len(conversation.past_user_inputs) > 0:
+                if (
+                    not conversation.past_user_inputs[0].startswith(B_SYS)
+                    or E_SYS not in conversation.past_user_inputs[0]
+                ):
+                    conversation.past_user_inputs[0] = (
+                        B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
+                    )
+            elif conversation.new_user_input:
+                if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
+                    conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
+            else:
+                raise ValueError("Last message must be from user")
+
         dialogue = list(conversation.iter_texts())
         if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
             [not is_user for is_user, msg in dialogue[1::2]]
@@ -365,14 +490,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in
             )
 
         dialog_tokens: List[int] = []
-        if len(conversation.past_user_inputs) > 0:
-            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
-                conversation.past_user_inputs[0] = (
-                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                )
-        elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
-            dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])
-
         dialog_tokens += sum(
             [
                 [self.bos_token_id]
@@ -384,9 +501,18 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in
             ],
             [],
         )
-        if not (dialogue[-1][0]):
-            raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
         dialog_tokens += [self.bos_token_id] + self.encode(
             f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
         )
         return dialog_tokens
+
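+    # pickling support: SentencePieceProcessor objects cannot be pickled directly, so the serialized proto
+    # is stored in `__getstate__` and the processor is rebuilt from it in `__setstate__`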
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
diff --git a/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama_fast.py b/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama_fast.py
new file mode 100644
index 0000000..b492b65
--- /dev/null
+++ b/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama_fast.py
@@ -0,0 +1,428 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from tokenizers import normalizers, processors
+
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+from transformers.utils import is_sentencepiece_available, logging
+from transformers.utils.versions import require_version
+
+
+if TYPE_CHECKING:
+    from transformers.pipelines.conversational import Conversation
+
+require_version("tokenizers>=0.13.3")
+
+if is_sentencepiece_available():
+    from .tokenization_code_llama import CodeLlamaTokenizer
+else:
+    CodeLlamaTokenizer = None
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
+
+SPIECE_UNDERLINE = "▁"
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<>\n", "\n<>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    This uses notably ByteFallback and no normalization.
+
+    ```python
+    >>> from transformers import CodeLlamaTokenizerFast
+
+    >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+    >>> tokenizer.encode("Hello this is a test")
+    [1, 15043, 445, 338, 263, 1243]
+    ```
+
+    If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
+    call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
+    values of the first token and final token of an encoded sequence will not be correct). For more details, check out
+    the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
+
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods. The default configuration matches that of
+    [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
+    which supports prompt infilling.
+
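+    The snippet below is only an illustrative sketch of prompt infilling; the `codellama/CodeLlama-7b-hf`
+    checkpoint name is an example and is not referenced elsewhere in this file:
+
+    ```python
+    >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("codellama/CodeLlama-7b-hf")  # doctest: +SKIP
+    >>> # the prompt is split on the fill token and encoded as <PRE> prefix <SUF> suffix <MID>
+    >>> tokenizer("def add(a, b):\n    return <FILL_ME>\n")["input_ids"]  # doctest: +SKIP
+    ```
+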
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        tokenizer_file (`str`):
+            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+            contains everything needed to load the tokenizer.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
+            spaces.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
+            Prefix token used for infilling.
+        suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
+            Suffix token used for infilling.
+        middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
+            Middle token used for infilling.
+        eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
+            End of text token used for infilling.
+        fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
+            The token used to split the input between the prefix and suffix.
+        suffix_first (`bool`, *optional*, defaults to `False`):
+            Whether the input prompt and suffix should be formatted with the suffix first.
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for Llama should be used.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = CodeLlamaTokenizer
+    padding_side = "left"
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        clean_up_tokenization_spaces=False,
+        unk_token="",
+        bos_token="",
+        eos_token="",
+        prefix_token="▁
",
+        middle_token="▁",
+        suffix_token="▁",
+        eot_token="▁",
+        fill_token="",
+        additional_special_tokens=None,
+        add_bos_token=True,
+        add_eos_token=False,
+        use_default_system_prompt=False,
+        **kwargs,
+    ):
+        # mark tokens special to skip them
+        additional_special_tokens = additional_special_tokens or []
+        for token in [prefix_token, middle_token, suffix_token, eot_token]:
+            additional_special_tokens += [token] if token is not None else []
+        self.use_default_system_prompt = use_default_system_prompt
+
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            additional_special_tokens=additional_special_tokens,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            prefix_token=prefix_token,
+            middle_token=middle_token,
+            suffix_token=suffix_token,
+            eot_token=eot_token,
+            fill_token=fill_token,
+            use_default_system_prompt=use_default_system_prompt,
+            **kwargs,
+        )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self.update_post_processor()
+
+        self.vocab_file = vocab_file
+
+        self._prefix_token = prefix_token
+        self._middle_token = middle_token
+        self._suffix_token = suffix_token
+        self._eot_token = eot_token
+        self.fill_token = fill_token
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
+    def update_post_processor(self):
+        """
+        Updates the underlying post processor with the current `bos_token` and `eos_token`.
+        """
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+
+        single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
+        pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
+    @property
+    def prefix_token(self):
+        return self._prefix_token
+
+    @property
+    def prefix_id(self):
+        if self._prefix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.prefix_token)
+
+    @property
+    def middle_token(self):
+        return self._middle_token
+
+    @property
+    def middle_id(self):
+        if self._middle_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.middle_token)
+
+    @property
+    def suffix_token(self):
+        return self._suffix_token
+
+    @property
+    def suffix_id(self):
+        if self._suffix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.suffix_token)
+
+    @property
+    def eot_id(self):
+        if self._eot_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.eot_token)
+
+    @property
+    def eot_token(self):
+        return self._eot_token
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
+    def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
+        if reset:
+            self._tokenizer.normalizer = normalizers.Sequence(
+                [
+                    normalizers.Prepend(prepend="▁"),
+                    normalizers.Replace(pattern=" ", content="▁"),
+                ]
+            )
+            self.update_post_processor()
+
+        self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
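+        # while infilling, the usual Prepend("▁") normalization step is dropped so no extra prefix space is
+        # injected; `encode_plus` adds the single leading space to the prefix manually instead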
+        pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
+        special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
+        if suffix_first:
+            # format as " 
 {suf}  {pre}"
+            pair += [self.prefix_token, self.suffix_token, "$A", self.middle_token, "$B"]
+            special_tokens += [
+                (self.prefix_token, self.prefix_id),
+                (self.suffix_token, self.suffix_id),
+                (self.middle_token, self.middle_id),
+            ]
+        else:
+            # format as " 
 {pre} {suf} "
+            pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
+            special_tokens += [
+                (self.prefix_token, self.prefix_id),
+                (self.suffix_token, self.suffix_id),
+                (self.middle_token, self.middle_id),
+            ]
+
+        if self.add_eos_token and add_special_tokens:
+            pair += [self.eos_token]
+            special_tokens += [(self.eos_token, self.eos_token_id)]
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single="$A", pair=pair, special_tokens=special_tokens
+        )
+
+    def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
+        # hack to make sure the input is pre-processed outside of the Rust tokenizer
+        text_pair = kwargs.pop("suffix", text_pair)
+        if self.fill_token in text and text_pair is None:
+            text, text_pair = text.split(self.fill_token)
+
+        if text_pair is None or len(text_pair) < 1:
+            return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
+
+        if None in (self.prefix_id, self.middle_id, self.suffix_id):
+            raise ValueError(
+                "Then input includes a `prefix` and a `suffix` used for the infilling task,"
+                " the `prefix_id, middle_id, suffix_id` must all be initialized. Current"
+                f" values : {self.prefix_id, self.middle_id, self.suffix_id}"
+            )
+
+        self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
+        tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
+        self.set_infilling_processor(True)
+        return tokens
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences by concatenating them and adding special tokens.
+
+        A sequence has the following format, where `X` represents the sequence:
+
+        - single sequence: `bos_token_id X eos_token_id`
+        - pair of sequences: `bos_token_id X_0 X_1 eos_token_id`
+
+        Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return self.bos_token_id + token_ids_0 + self.eos_token_id
+        return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id
+
+    # Copied from transformers.models.code_llama.tokenization_code_llama.CodeLlamaTokenizer._build_conversation_input_ids
+    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
+        r"""Builds the input ids for a conversation.
+        This is the format used in the provided examples. System prompts should be manually added at the beginning of
+        the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
+        ```
+        <s>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer </s>
+        <s>[INST] Prompt [/INST] Answer </s>
+        <s>[INST] Prompt [/INST]
+        ```
+
+        If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS`, as in the following:
+        ```python
+        >>> from transformers import Conversation
+
+        >>> Conversation(
+        ...     "<>\n Complete the functions without any documentation\n<>\n\n `def remove_non_ascii(s: str) -> str:`"
+        ... )  # doctest: +IGNORE_RESULT
+        ```
+        Args:
+            conversation (`Conversation`):
+                Conversation to build input ids for.
+        Returns:
+            `List[int]`:
+                Input ids for the conversation.
+        """
+        if self.use_default_system_prompt:
+            if len(conversation.past_user_inputs) > 0:
+                if (
+                    not conversation.past_user_inputs[0].startswith(B_SYS)
+                    or E_SYS not in conversation.past_user_inputs[0]
+                ):
+                    conversation.past_user_inputs[0] = (
+                        B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
+                    )
+            elif conversation.new_user_input:
+                if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
+                    conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
+            else:
+                raise ValueError("Last message must be from user")
+
+        dialogue = list(conversation.iter_texts())
+        if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
+            [not is_user for is_user, msg in dialogue[1::2]]
+        ):
+            raise ValueError(
+                "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
+            )
+
+        dialog_tokens: List[int] = []
+        dialog_tokens += sum(
+            [
+                [self.bos_token_id]
+                + self.encode(
+                    f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
+                )
+                + [self.eos_token_id]
+                for prompt, answer in zip(dialogue[::2], dialogue[1::2])
+            ],
+            [],
+        )
+        dialog_tokens += [self.bos_token_id] + self.encode(
+            f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
+        )
+        return dialog_tokens
diff --git a/mft_peft_hf/src/model/code_llama/tokenization_llama.py b/mftcoder_accelerate/src/model/code_llama/tokenization_llama.py
similarity index 92%
rename from mft_peft_hf/src/model/code_llama/tokenization_llama.py
rename to mftcoder_accelerate/src/model/code_llama/tokenization_llama.py
index fb2cfce..98360d3 100644
--- a/mft_peft_hf/src/model/code_llama/tokenization_llama.py
+++ b/mftcoder_accelerate/src/model/code_llama/tokenization_llama.py
@@ -18,7 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tokenization classes for LLaMA."""
+"""4.33.1 Tokenization classes for LLaMA."""
 import os
 from shutil import copyfile
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -141,7 +141,7 @@ def __init__(
             logger.warning_once(
                 f"You are using the default legacy behaviour of the {self.__class__}. If you see this, DO NOT PANIC! This is"
                 " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
-                " If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it"
+                " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
                 " means, and thouroughly read the reason why this was added as explained in"
                 " https://github.com/huggingface/transformers/pull/24565"
             )
@@ -154,19 +154,25 @@ def __init__(
         self.use_default_system_prompt = use_default_system_prompt
 
         self.sp_model = self.get_spm_processor()
-        self.unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
+
+    @property
+    def unk_token_length(self):
+        return len(self.sp_model.encode(str(self.unk_token)))
 
     # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
     def get_spm_processor(self):
         tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        if self.legacy:  # no dependency on protobuf
+            tokenizer.Load(self.vocab_file)
+            return tokenizer
+
         with open(self.vocab_file, "rb") as f:
             sp_model = f.read()
-            model_pb2 = import_protobuf()
+            model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
             model = model_pb2.ModelProto.FromString(sp_model)
-            if not self.legacy:
-                normalizer_spec = model_pb2.NormalizerSpec()
-                normalizer_spec.add_dummy_prefix = False
-                model.normalizer_spec.MergeFrom(normalizer_spec)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
             sp_model = model.SerializeToString()
             tokenizer.LoadFromSerializedProto(sp_model)
         return tokenizer
@@ -194,18 +200,17 @@ def get_vocab(self):
         return vocab
 
     # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
-    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
+    def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
         """
         Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
         first token is special.
         """
-        if self.legacy:
+        if self.legacy or len(text) == 0:
             return super().tokenize(text, **kwargs)
 
-        if len(text) > 0:
-            tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
+        tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
 
-        if tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+        if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
             tokens = tokens[1:]
         return tokens
 
@@ -220,13 +225,14 @@ def _tokenize(self, text, **kwargs):
         `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
         `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
         """
-        if self.legacy:
-            return self.sp_model.encode(text, out_type=str)
-
-        unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
-        text = self.unk_token + text
         tokens = self.sp_model.encode(text, out_type=str)
-        return tokens[unk_token_length:]
+        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
+
+        # 1. Encode string + prefix ex: "<unk> Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
 
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
diff --git a/mft_peft_hf/src/model/code_llama/tokenization_llama_fast.py b/mftcoder_accelerate/src/model/code_llama/tokenization_llama_fast.py
similarity index 98%
rename from mft_peft_hf/src/model/code_llama/tokenization_llama_fast.py
rename to mftcoder_accelerate/src/model/code_llama/tokenization_llama_fast.py
index f7ad2ec..53a4b0b 100644
--- a/mft_peft_hf/src/model/code_llama/tokenization_llama_fast.py
+++ b/mftcoder_accelerate/src/model/code_llama/tokenization_llama_fast.py
@@ -128,7 +128,10 @@ def __init__(
         self.update_post_processor()
         self.use_default_system_prompt = use_default_system_prompt
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def update_post_processor(self):
         """
diff --git a/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py b/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py
new file mode 100644
index 0000000..82e0f5d
--- /dev/null
+++ b/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py
@@ -0,0 +1,206 @@
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+class DeepseekV2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate a DeepSeek
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the DeepSeek-V2.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 102400):
+            Vocabulary size of the DeepseekV2 model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`DeepseekV2Model`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        moe_intermediate_size (`int`, *optional*, defaults to 1407):
+            Dimension of the MoE representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        n_shared_experts (`int`, *optional*, defaults to None):
+            Number of shared experts, None means dense model.
+        n_routed_experts (`int`, *optional*, defaults to None):
+            Number of routed experts, None means dense model.
+        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for routed experts.
+        topk_method (`str`, *optional*, defaults to `gready`):
+            Topk method used in routed gate.
+        n_group (`int`, *optional*, defaults to None):
+            Number of groups for routed experts.
+        topk_group (`int`, *optional*, defaults to None):
+            Number of selected groups for each token (for each token, ensuring that the selected experts are only within `topk_group` groups).
+        num_experts_per_tok (`int`, *optional*, defaults to None):
+            Number of selected experts, None means dense model.
+        moe_layer_freq (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
+        first_k_dense_replace (`int`, *optional*, defaults to 0):
+            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
+                                                            \--k dense layers--/
+        norm_topk_prob (`bool`, *optional*, defaults to False):
+            Whether to normalize the weights of the routed experts.
+        scoring_func (`str`, *optional*, defaults to 'softmax'):
+            Method of computing expert weights.
+        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
+            Auxiliary loss weight coefficient.
+        seq_aux (`bool`, *optional*, defaults to True):
+            Whether to compute the auxiliary loss for each individual sample.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.
+        pretraining_tp (`int`, *optional*, defaults to 1):
+            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+            issue](https://github.com/pytorch/pytorch/issues/76232).
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+
+    ```python
+    >>> from transformers import DeepseekV2Model, DeepseekV2Config
+
+    >>> # Initializing a Deepseek-V2 style configuration
+    >>> configuration = DeepseekV2Config()
+
+    >>> # Initializing a model from the Deepseek-V2 style configuration
+    >>> model = DeepseekV2Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "deepseek_v2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=102400,
+        hidden_size=4096,
+        intermediate_size=11008,
+        moe_intermediate_size = 1407,
+        num_hidden_layers=30,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        n_shared_experts = None,
+        n_routed_experts = None,
+        ep_size = 1,
+        routed_scaling_factor = 1.0,
+        kv_lora_rank = 512,
+        q_lora_rank = 1536,
+        qk_rope_head_dim = 64,
+        v_head_dim = 128,
+        qk_nope_head_dim = 128,
+        topk_method = 'gready',
+        n_group = None,
+        topk_group = None,
+        num_experts_per_tok = None,
+        moe_layer_freq = 1,
+        first_k_dense_replace = 0,
+        norm_topk_prob = False,
+        scoring_func = 'softmax',
+        aux_loss_alpha = 0.001,
+        seq_aux = True,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=100000,
+        eos_token_id=100001,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.n_shared_experts = n_shared_experts
+        self.n_routed_experts = n_routed_experts
+        self.ep_size = ep_size
+        self.routed_scaling_factor = routed_scaling_factor
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.topk_method = topk_method
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.num_experts_per_tok = num_experts_per_tok
+        self.moe_layer_freq = moe_layer_freq
+        self.first_k_dense_replace = first_k_dense_replace
+        self.norm_topk_prob = norm_topk_prob
+        self.scoring_func = scoring_func
+        self.aux_loss_alpha = aux_loss_alpha
+        self.seq_aux = seq_aux
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
\ No newline at end of file
diff --git a/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py b/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py
new file mode 100644
index 0000000..d1d5e88
--- /dev/null
+++ b/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py
@@ -0,0 +1,1925 @@
+# coding=utf-8
+# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch DeepSeek model."""
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_attention_mask,
+    _prepare_4d_causal_attention_mask,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_13,
+)
+from transformers.utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
+    logging,
+    replace_return_docstrings,
+)
+from transformers.utils.import_utils import is_torch_fx_available
+from .configuration_deepseek import DeepseekV2Config
+import torch.distributed as dist
+import numpy as np
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
+# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+# It means that the function will not be traced through and simply appear as a node in the graph.
+if is_torch_fx_available():
+    if not is_torch_greater_or_equal_than_1_13:
+        import torch.fx
+
+    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DeepseekV2Config"
+
+
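+# Helper for the flash-attention varlen path: given a 0/1 padding mask, it returns the flattened indices of
+# the non-padded tokens, the cumulative per-sequence lengths (cu_seqlens, as consumed by
+# flash_attn_varlen_func) and the longest sequence length in the batch.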
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(
+        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
+    )
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+class DeepseekV2RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        DeepseekV2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
+
+
+class DeepseekV2RotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        )
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings,
+            device=self.inv_freq.device,
+            dtype=torch.get_default_dtype(),
+        )
+        self.max_seq_len_cached = None
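+        # resetting the cached length forces a rebuild on the first forward pass, so the cos/sin tables end
+        # up on the runtime device and dtype of the actual inputs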
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(
+            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
+        )
+
+        freqs = torch.outer(t, self.inv_freq.to(t.device))
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2
+class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
+    """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+    ):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(
+            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
+        )
+        t = t / self.scaling_factor
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2
+class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
+    """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+    ):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings)
+                - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(
+            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
+        )
+
+        freqs = torch.outer(t, self.inv_freq)
+        # The permutation differs from the RoPE paper but yields the same computation.
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+# Inverse dim formula to find dim based on number of rotations
+def yarn_find_correction_dim(
+    num_rotations, dim, base=10000, max_position_embeddings=2048
+):
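+    # A RoPE channel at pair index i has inverse frequency base**(-2*i/dim) and
+    # completes max_position_embeddings * base**(-2*i/dim) / (2*pi) rotations over
+    # the original context; solving for i at a target rotation count gives the
+    # expression below.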
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+        2 * math.log(base)
+    )
+
+
+# Find dim range bounds based on rotations
+def yarn_find_correction_range(
+    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+    low = math.floor(
+        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+    )
+    high = math.ceil(
+        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+    )
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding):
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.mscale_all_dim = mscale_all_dim
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        dim = self.dim
+
+        freq_extra = 1.0 / (
+            self.base
+            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+        freq_inter = 1.0 / (
+            self.scaling_factor
+            * self.base
+            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.original_max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
+            device=device, dtype=torch.float32
+        )
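+        # Blend original and scaled frequencies: dims below `low` (fast-rotating)
+        # keep the original frequencies (freq_extra), dims above `high` use the
+        # interpolated frequencies (freq_inter), with a linear ramp in between.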
+        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+        freqs = torch.outer(t, inv_freq)
+
+        _mscale = float(
+            yarn_get_mscale(self.scaling_factor, self.mscale)
+            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
+        )
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer(
+            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
+        )
+        self.register_buffer(
+            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
+        )
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offset position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
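+    # DeepseekV2 stores the rotary dims pair-interleaved as (..., d/2, 2); the
+    # view/transpose/reshape below reorders q and k into the half-split layout
+    # that rotate_half() expects before the rotation is applied.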
+    b, h, s, d = q.shape
+    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+    b, h, s, d = k.shape
+    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class DeepseekV2MLP(nn.Module):
+    def __init__(self, config, hidden_size=None, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = (
+            config.intermediate_size if intermediate_size is None else intermediate_size
+        )
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class MoEGate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.top_k = config.num_experts_per_tok
+        self.n_routed_experts = config.n_routed_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.scoring_func = config.scoring_func
+        self.alpha = config.aux_loss_alpha
+        self.seq_aux = config.seq_aux
+        self.topk_method = config.topk_method
+        self.n_group = config.n_group
+        self.topk_group = config.topk_group
+
+        # topk selection algorithm
+        self.norm_topk_prob = config.norm_topk_prob
+        self.gating_dim = config.hidden_size
+        self.weight = nn.Parameter(
+            torch.empty((self.n_routed_experts, self.gating_dim))
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        import torch.nn.init as init
+
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
+    def forward(self, hidden_states):
+        bsz, seq_len, h = hidden_states.shape
+        ### compute gating score
+        hidden_states = hidden_states.view(-1, h)
+        logits = F.linear(
+            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
+        )
+        if self.scoring_func == "softmax":
+            scores = logits.softmax(dim=-1, dtype=torch.float32)
+        else:
+            raise NotImplementedError(
+                f"insupportable scoring function for MoE gating: {self.scoring_func}"
+            )
+
+        ### select top-k experts
+        if self.topk_method == "greedy":
+            topk_weight, topk_idx = torch.topk(
+                scores, k=self.top_k, dim=-1, sorted=False
+            )
+        elif self.topk_method == "group_limited_greedy":
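+            # Group-limited routing: experts are partitioned into n_group groups;
+            # first keep the topk_group groups whose best expert scores highest,
+            # then pick the top_k experts only from within those groups.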
+            group_scores = (
+                scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
+            )  # [n, n_group]
+            group_idx = torch.topk(
+                group_scores, k=self.topk_group, dim=-1, sorted=False
+            )[
+                1
+            ]  # [n, top_k_group]
+            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+            score_mask = (
+                group_mask.unsqueeze(-1)
+                .expand(
+                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
+                )
+                .reshape(bsz * seq_len, -1)
+            )  # [n, e]
+            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+            topk_weight, topk_idx = torch.topk(
+                tmp_scores, k=self.top_k, dim=-1, sorted=False
+            )
+
+        ### norm gate to sum 1
+        if self.top_k > 1 and self.norm_topk_prob:
+            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weight = topk_weight / denominator
+        else:
+            topk_weight = topk_weight * self.routed_scaling_factor
+        ### expert-level computation auxiliary loss
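+        # Load-balancing auxiliary loss: penalizes the product of each expert's mean
+        # routing probability and its selection frequency so routing does not collapse
+        # onto a few experts; with `seq_aux` the statistics are computed per sequence.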
+        if self.training and self.alpha > 0.0:
+            scores_for_aux = scores
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+            if self.seq_aux:
+                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                ce = torch.zeros(
+                    bsz, self.n_routed_experts, device=hidden_states.device
+                )
+                ce.scatter_add_(
+                    1,
+                    topk_idx_for_aux_loss,
+                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
+                ).div_(seq_len * aux_topk / self.n_routed_experts)
+                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
+                    dim=1
+                ).mean() * self.alpha
+            else:
+                mask_ce = F.one_hot(
+                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
+                )
+                ce = mask_ce.float().mean(0)
+                Pi = scores_for_aux.mean(0)
+                fi = ce * self.n_routed_experts
+                aux_loss = (Pi * fi).sum() * self.alpha
+        else:
+            aux_loss = None
+        return topk_idx, topk_weight, aux_loss
+
+
+class AddAuxiliaryLoss(torch.autograd.Function):
+    """
+    A straight-through helper that attaches the auxiliary (aux) loss to the graph:
+    the forward output is unchanged, while the aux loss contributes its gradient
+    during backpropagation.
+    """
+
+    @staticmethod
+    def forward(ctx, x, loss):
+        assert loss.numel() == 1
+        ctx.dtype = loss.dtype
+        ctx.required_aux_loss = loss.requires_grad
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_loss = None
+        if ctx.required_aux_loss:
+            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
+        return grad_output, grad_loss
+
+
+class DeepseekV2MoE(nn.Module):
+    """
+    A mixture-of-experts module that also contains always-active shared experts.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.num_experts_per_tok = config.num_experts_per_tok
+
+        if hasattr(config, "ep_size") and config.ep_size > 1:
+            assert config.ep_size == dist.get_world_size()
+            self.ep_size = config.ep_size
+            self.experts_per_rank = config.n_routed_experts // config.ep_size
+            self.ep_rank = dist.get_rank()
+            self.experts = nn.ModuleList(
+                [
+                    (
+                        DeepseekV2MLP(
+                            config, intermediate_size=config.moe_intermediate_size
+                        )
+                        if i >= self.ep_rank * self.experts_per_rank
+                        and i < (self.ep_rank + 1) * self.experts_per_rank
+                        else None
+                    )
+                    for i in range(config.n_routed_experts)
+                ]
+            )
+        else:
+            self.ep_size = 1
+            self.experts_per_rank = config.n_routed_experts
+            self.ep_rank = 0
+            self.experts = nn.ModuleList(
+                [
+                    DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)
+                    for i in range(config.n_routed_experts)
+                ]
+            )
+        self.gate = MoEGate(config)
+        if config.n_shared_experts is not None:
+            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            self.shared_experts = DeepseekV2MLP(
+                config=config, intermediate_size=intermediate_size
+            )
+
+    def forward(self, hidden_states):
+        # save dtype before computation
+        input_dtype = hidden_states.dtype
+        identity = hidden_states
+        orig_shape = hidden_states.shape
+        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        flat_topk_idx = topk_idx.view(-1)
+        if self.training:
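+            # Training path: every token is duplicated num_experts_per_tok times,
+            # each copy runs through its assigned expert, and the expert outputs
+            # are recombined as a gate-weighted sum.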
+            hidden_states = hidden_states.repeat_interleave(
+                self.num_experts_per_tok, dim=0
+            )
+            y = torch.empty_like(hidden_states)
+            for i, expert in enumerate(self.experts):
+                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+            y = y.view(*orig_shape)
+            y = AddAuxiliaryLoss.apply(y, aux_loss)
+        else:
+            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
+        if self.config.n_shared_experts is not None:
+            y = y + self.shared_experts(identity)
+        # keep dtype same after moe forward
+        return y.to(input_dtype)
+
+    @torch.no_grad()
+    def moe_infer(self, x, topk_ids, topk_weight):
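+        # Inference path: tokens are sorted by their assigned expert so each expert
+        # processes one contiguous slice. With expert parallelism (ep_size > 1),
+        # tokens are exchanged across ranks via all_to_all before and after the
+        # expert computation, then scattered back and combined with the gate weights.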
+        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
+        cnts.scatter_(1, topk_ids, 1)
+        tokens_per_expert = cnts.sum(dim=0)
+        idxs = topk_ids.view(-1).argsort()
+        sorted_tokens = x[idxs // topk_ids.shape[1]]
+        sorted_tokens_shape = sorted_tokens.shape
+        if self.ep_size > 1:
+            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
+            tokens_per_expert_group = tokens_per_expert.new_empty(
+                tokens_per_expert.shape[0]
+            )
+            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
+            output_splits = (
+                tokens_per_expert_group.view(self.ep_size, -1)
+                .sum(1)
+                .cpu()
+                .numpy()
+                .tolist()
+            )
+            gathered_tokens = sorted_tokens.new_empty(
+                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
+            )
+            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
+            dist.all_to_all(
+                list(gathered_tokens.split(output_splits)),
+                list(sorted_tokens.split(input_split_sizes)),
+            )
+            tokens_per_expert_post_gather = tokens_per_expert_group.view(
+                self.ep_size, self.experts_per_rank
+            ).sum(dim=0)
+            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
+            s = 0
+            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
+                gatherd_idxs[s : s + k] = i % self.experts_per_rank
+                s += k
+            gatherd_idxs = gatherd_idxs.argsort()
+            sorted_tokens = gathered_tokens[gatherd_idxs]
+            tokens_per_expert = tokens_per_expert_post_gather
+        tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+        outputs = []
+        start_idx = 0
+        for i, num_tokens in enumerate(tokens_per_expert):
+            end_idx = start_idx + num_tokens
+            if num_tokens == 0:
+                continue
+            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
+            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+            expert_out = expert(tokens_for_this_expert)
+            outputs.append(expert_out)
+            start_idx = end_idx
+
+        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+        if self.ep_size > 1:
+            new_x = torch.empty_like(outs)
+            new_x[gatherd_idxs] = outs
+            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
+            dist.all_to_all(
+                list(gathered_tokens.split(input_split_sizes)),
+                list(new_x.split(output_splits)),
+            )
+            outs = gathered_tokens
+
+        new_x = torch.empty_like(outs)
+        new_x[idxs] = outs
+        final_out = (
+            new_x.view(*topk_ids.shape, -1)
+            .type(topk_weight.dtype)
+            .mul_(topk_weight.unsqueeze(dim=-1))
+            .sum(dim=1)
+            .type(new_x.dtype)
+        )
+        return final_out
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2
+class DeepseekV2Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.q_lora_rank = config.q_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.kv_lora_rank = config.kv_lora_rank
+        self.v_head_dim = config.v_head_dim
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+
+        self.is_causal = True
+
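+        # Low-rank attention projections: queries optionally go through a
+        # q_lora_rank bottleneck, while keys/values are compressed to kv_lora_rank
+        # dims plus a single shared RoPE head (qk_rope_head_dim) that is broadcast
+        # across all attention heads at score time.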
+        if self.q_lora_rank is None:
+            self.q_proj = nn.Linear(
+                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
+            )
+        else:
+            self.q_a_proj = nn.Linear(
+                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
+            )
+            self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
+            self.q_b_proj = nn.Linear(
+                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
+            )
+
+        self.kv_a_proj_with_mqa = nn.Linear(
+            self.hidden_size,
+            config.kv_lora_rank + config.qk_rope_head_dim,
+            bias=config.attention_bias,
+        )
+        self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
+        self.kv_b_proj = nn.Linear(
+            config.kv_lora_rank,
+            self.num_heads
+            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
+            bias=False,
+        )
+
+        self.o_proj = nn.Linear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=config.attention_bias,
+        )
+        self._init_rope()
+
+        self.softmax_scale = self.q_head_dim ** (-0.5)
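+        # When YaRN rope scaling defines `mscale_all_dim`, the attention logits are
+        # additionally rescaled by yarn_get_mscale(factor, mscale_all_dim)**2.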
+        if self.config.rope_scaling is not None:
+            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = self.config.rope_scaling["factor"]
+            if mscale_all_dim:
+                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                self.softmax_scale = self.softmax_scale * mscale * mscale
+
+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            self.rotary_emb = DeepseekV2RotaryEmbedding(
+                self.qk_rope_head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
+                    self.qk_rope_head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            elif scaling_type == "dynamic":
+                self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
+                    self.qk_rope_head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            elif scaling_type == "yarn":
+                kwargs = {
+                    key: self.config.rope_scaling[key]
+                    for key in [
+                        "original_max_position_embeddings",
+                        "beta_fast",
+                        "beta_slow",
+                        "mscale",
+                        "mscale_all_dim",
+                    ]
+                    if key in self.config.rope_scaling
+                }
+                self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
+                    self.qk_rope_head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return (
+            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .transpose(1, 2)
+        )
+
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+        kv_seq_len = value_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
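+        # Reassemble full-width queries/keys: the first qk_nope_head_dim dims carry
+        # content (no positional encoding), the last qk_rope_head_dim dims carry
+        # RoPE; the single k_pe head is broadcast across all num_heads.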
+        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        attn_weights = (
+            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
+        )
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+        assert attention_mask is not None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.attention_dropout, training=self.training
+        )
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2
+class DeepseekV2FlashAttention2(DeepseekV2Attention):
+    """
+    DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # DeepseekV2FlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            # print(f"dtype of hidden_states: {hidden_states.dtype}")
+            # print(f"dtype of q_proj: {self.q_proj.weight.dtype}")
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim
+        # therefore we just need to keep the original shape
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .transpose(1, 2)
+        )
+
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+        kv_seq_len = value_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+
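+        # flash-attn is called with a uniform head dim, so the value states are
+        # zero-padded up to q_head_dim here and the extra columns are sliced off
+        # from the attention output again below.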
+        if self.q_head_dim != self.v_head_dim:
+            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (DeepseekV2RMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            else:
+                target_dtype = (
+                    self.q_proj.weight.dtype
+                    if self.q_lora_rank is None
+                    else self.q_a_proj.weight.dtype
+                )
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            softmax_scale=self.softmax_scale,
+        )
+        if self.q_head_dim != self.v_head_dim:
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+        attn_output = attn_output.reshape(
+            bsz, q_len, self.num_heads * self.v_head_dim
+        ).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+    ):
+        """
+        Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
+        the input is first unpadded, attention is computed on the packed sequences, and the output is padded back.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            (
+                query_states,
+                key_states,
+                value_states,
+                indices_q,
+                cu_seq_lens,
+                max_seq_lens,
+            ) = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(
+                attn_output_unpad, indices_q, batch_size, query_length
+            )
+        else:
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+        return attn_output
+
+    def _upad_input(
+        self, query_layer, key_layer, value_layer, attention_mask, query_length
+    ):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
+                indices_k,
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                query_layer, attention_mask
+            )
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+ATTENTION_CLASSES = {
+    "eager": DeepseekV2Attention,
+    "flash_attention_2": DeepseekV2FlashAttention2,
+}
+
+
+class DeepseekV2DecoderLayer(nn.Module):
+    def __init__(self, config: DeepseekV2Config, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
+            config=config, layer_idx=layer_idx
+        )
+
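+        # The first `first_k_dense_replace` layers use a dense MLP; afterwards every
+        # `moe_layer_freq`-th layer uses a DeepseekV2MoE block instead.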
+        self.mlp = (
+            DeepseekV2MoE(config)
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0
+            )
+            else DeepseekV2MLP(config)
+        )
+        self.input_layernorm = DeepseekV2RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = DeepseekV2RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        residual = hidden_states
+        # print(f"1. dtype of residual: {residual.dtype}")
+
+        hidden_states = self.input_layernorm(hidden_states)
+        # print(f"2. dtype of hidden_states before attn: {hidden_states.dtype}")
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        # print(f"3. dtype of hidden_states after attn: {hidden_states.dtype}")
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        # print(f"4. dtype of hidden_states after post layernorm: {hidden_states.dtype}")
+        hidden_states = self.mlp(hidden_states)
+        # print(f"5. dtype of hidden_states after mlp: {hidden_states.dtype}")
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+DeepseekV2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`DeepseekV2Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
+    DeepseekV2_START_DOCSTRING,
+)
+class DeepseekV2PreTrainedModel(PreTrainedModel):
+    config_class = DeepseekV2Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["DeepseekV2DecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_cache_class = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+DeepseekV2_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            Two formats are allowed:
+            - a [`~cache_utils.Cache`] instance;
+            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+            cache format.
+
+            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+            legacy cache format will be returned.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
+    DeepseekV2_START_DOCSTRING,
+)
+class DeepseekV2Model(DeepseekV2PreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]
+
+    Args:
+        config: DeepseekV2Config
+    """
+
+    def __init__(self, config: DeepseekV2Config):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
+        self.layers = nn.ModuleList(
+            [
+                DeepseekV2DecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape[:2]
+        elif inputs_embeds is not None:
+            batch_size, seq_length = inputs_embeds.shape[:2]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
+                )
+                use_cache = False
+
+        past_key_values_length = 0
+        if use_cache:
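+            # Accept both the legacy tuple-of-tuples cache and the Cache API; legacy
+            # inputs are wrapped in a DynamicCache and converted back when returned.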
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = (
+                attention_mask
+                if (attention_mask is not None and 0 in attention_mask)
+                else None
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = None
+        if use_cache:
+            next_cache = (
+                next_decoder_cache.to_legacy_cache()
+                if use_legacy_cache
+                else next_decoder_cache
+            )
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = DeepseekV2Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+    )
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
+
+        >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs,
+    ):
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+            # input)
+            if (
+                attention_mask is not None
+                and attention_mask.shape[1] > input_ids.shape[1]
+            ):
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+@add_start_docstrings(
+    """
+    The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).
+
+    [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    DeepseekV2_START_DOCSTRING,
+)
+class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = DeepseekV2Model(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError(
+                "Cannot handle batch sizes > 1 if no padding token is defined."
+            )
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (
+                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                ).to(logits.device)
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[
+            torch.arange(batch_size, device=logits.device), sequence_lengths
+        ]
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(
+                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
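
For reference, the last-token pooling that the `DeepseekV2ForSequenceClassification` docstring describes can be sketched in isolation as follows; the tensors, `pad_token_id` value, and label count below are toy values for illustration, not part of the patched file:

```python
import torch

# Toy batch: two right-padded sequences, pad_token_id = 0 (illustrative values).
pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 6, 8, 2]])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels) from the score head

# Position of the last non-pad token: index of the first pad token minus one.
# For a row with no padding, argmax over an all-zero mask is 0, so the index
# becomes -1 and negative indexing still selects the final position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1

pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3]) -> one logit vector per sequence
```

As in the Hugging Face implementation, this assumes right padding; left-padded inputs would need a different index computation.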
diff --git a/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py b/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py
new file mode 100644
index 0000000..d243771
--- /dev/null
+++ b/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py
@@ -0,0 +1,38 @@
+from typing import List, Optional, Union
+
+
+from transformers.models.llama import LlamaTokenizerFast
+
+
+class DeepseekTokenizerFast(LlamaTokenizerFast):
+
+    def convert_ids_to_tokens(
+        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+    ) -> Union[str, List[str]]:
+        """
+        Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
+        added tokens.
+
+        Args:
+            ids (`int` or `List[int]`):
+                The token id (or token ids) to convert to tokens.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+
+        Returns:
+            `str` or `List[str]`: The decoded token(s).
+        """
+        if isinstance(ids, int):
+            return self._convert_id_to_token(ids)
+        tokens = []
+        for index in ids:
+            index = int(index)
+            if skip_special_tokens and index in self.all_special_ids:
+                continue
+            token = self._tokenizer.id_to_token(index)
+            tokens.append(token if token is not None else "")
+        return tokens
+    
+    def _convert_id_to_token(self, index: int) -> Optional[str]:
+        token = self._tokenizer.id_to_token(int(index))
+        return token if token is not None else ""
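
A hypothetical usage sketch of the override above: `DeepseekTokenizerFast` behaves like `LlamaTokenizerFast` except that ids with no vocabulary entry decode to an empty string instead of `None`. The checkpoint path here is a placeholder, and loading through `trust_remote_code` assumes the checkpoint ships this tokenizer class:

```python
from transformers import AutoTokenizer

# Placeholder path; substitute a DeepSeek-V2 checkpoint directory that includes
# tokenization_deepseek_fast.py and load it with trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained("path/to/deepseek-v2", trust_remote_code=True)

ids = tokenizer("def add(a, b):", add_special_tokens=True).input_ids
tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
print(tokens)  # special tokens (e.g. BOS) are dropped; unknown ids come back as ""
```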
diff --git a/mft_peft_hf/src/model/gpt_bigcode/__init__.py b/mftcoder_accelerate/src/model/gpt_bigcode/__init__.py
similarity index 100%
rename from mft_peft_hf/src/model/gpt_bigcode/__init__.py
rename to mftcoder_accelerate/src/model/gpt_bigcode/__init__.py
diff --git a/mft_peft_hf/src/model/gpt_bigcode/configuration_gpt_bigcode.py b/mftcoder_accelerate/src/model/gpt_bigcode/configuration_gpt_bigcode.py
similarity index 100%
rename from mft_peft_hf/src/model/gpt_bigcode/configuration_gpt_bigcode.py
rename to mftcoder_accelerate/src/model/gpt_bigcode/configuration_gpt_bigcode.py
diff --git a/mft_peft_hf/src/model/gpt_bigcode/modeling_gpt_bigcode.py b/mftcoder_accelerate/src/model/gpt_bigcode/modeling_gpt_bigcode.py
similarity index 100%
rename from mft_peft_hf/src/model/gpt_bigcode/modeling_gpt_bigcode.py
rename to mftcoder_accelerate/src/model/gpt_bigcode/modeling_gpt_bigcode.py
diff --git a/mft_atorch/model/gpt_neox/__init__.py b/mftcoder_accelerate/src/model/gpt_neox/__init__.py
similarity index 100%
rename from mft_atorch/model/gpt_neox/__init__.py
rename to mftcoder_accelerate/src/model/gpt_neox/__init__.py
diff --git a/mft_peft_hf/src/model/gpt_neox/config.json b/mftcoder_accelerate/src/model/gpt_neox/config.json
similarity index 100%
rename from mft_peft_hf/src/model/gpt_neox/config.json
rename to mftcoder_accelerate/src/model/gpt_neox/config.json
diff --git a/mft_peft_hf/src/model/gpt_neox/configuration_gpt_neox.py b/mftcoder_accelerate/src/model/gpt_neox/configuration_gpt_neox.py
similarity index 100%
rename from mft_peft_hf/src/model/gpt_neox/configuration_gpt_neox.py
rename to mftcoder_accelerate/src/model/gpt_neox/configuration_gpt_neox.py
diff --git a/mft_atorch/model/gpt_neox/generation_config.json b/mftcoder_accelerate/src/model/gpt_neox/generation_config.json
similarity index 100%
rename from mft_atorch/model/gpt_neox/generation_config.json
rename to mftcoder_accelerate/src/model/gpt_neox/generation_config.json
diff --git a/mft_peft_hf/src/model/gpt_neox/modeling_gpt_neox.py b/mftcoder_accelerate/src/model/gpt_neox/modeling_gpt_neox.py
similarity index 100%
rename from mft_peft_hf/src/model/gpt_neox/modeling_gpt_neox.py
rename to mftcoder_accelerate/src/model/gpt_neox/modeling_gpt_neox.py
diff --git a/mft_peft_hf/src/model/gpt_neox/tokenization_gpt_neox_fast.py b/mftcoder_accelerate/src/model/gpt_neox/tokenization_gpt_neox_fast.py
similarity index 100%
rename from mft_peft_hf/src/model/gpt_neox/tokenization_gpt_neox_fast.py
rename to mftcoder_accelerate/src/model/gpt_neox/tokenization_gpt_neox_fast.py
diff --git a/mftcoder_accelerate/src/model/phi/configuration_mixformer_sequential.py b/mftcoder_accelerate/src/model/phi/configuration_mixformer_sequential.py
new file mode 100644
index 0000000..8cc2d51
--- /dev/null
+++ b/mftcoder_accelerate/src/model/phi/configuration_mixformer_sequential.py
@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import math
+from typing import Any, Dict, List, Optional, Union
+
+from transformers import PretrainedConfig
+
+
+class MixFormerSequentialConfig(PretrainedConfig):
+    """MixFormer (sequential for DeepSpeed) configuration."""
+
+    model_type = "mixformer-sequential"
+
+    attribute_map = {
+        "max_position_embeddings": "n_positions",
+        "hidden_size": "n_embd",
+        "num_attention_heads": "n_head",
+        "num_hidden_layers": "n_layer",
+    }
+
+    def __init__(
+        self,
+        vocab_size: Optional[int] = 50304,
+        n_positions: Optional[int] = 2048,
+        n_embd: Optional[int] = 1024,
+        n_layer: Optional[int] = 20,
+        n_inner: Optional[int] = None,
+        n_head: Optional[int] = 16,
+        rotary_dim: Optional[int] = 32,
+        activation_function: Optional[str] = "gelu_new",
+        embd_pdrop: Optional[float] = 0.0,
+        resid_pdrop: Optional[float] = 0.0,
+        layer_norm_epsilon: Optional[float] = 1e-5,
+        initializer_range: Optional[float] = 0.02,
+        tie_word_embeddings: Optional[bool] = False,
+        pad_vocab_size_multiple: Optional[int] = 64,
+        **kwargs
+    ) -> None:
+        self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+        self.n_positions = n_positions
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_inner = n_inner
+        self.n_head = n_head
+        self.rotary_dim = min(rotary_dim, n_embd // n_head)
+        self.activation_function = activation_function
+        self.embd_pdrop = embd_pdrop
+        self.resid_pdrop = resid_pdrop
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_range = initializer_range
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
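
A small sketch of what the constructor above does with `vocab_size`, using illustrative numbers; the `attribute_map` aliasing itself is handled generically by `transformers.PretrainedConfig`:

```python
import math

# vocab_size is rounded up to a multiple of pad_vocab_size_multiple (default 64),
# e.g. a raw vocabulary of 50295 tokens becomes 50304.
vocab_size, pad_multiple = 50295, 64
padded = int(math.ceil(vocab_size / pad_multiple) * pad_multiple)
print(padded)  # 50304

# Via attribute_map, HF-style names resolve to the GPT-2-style fields:
# config.hidden_size -> config.n_embd, config.num_hidden_layers -> config.n_layer, etc.
```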
diff --git a/mftcoder_accelerate/src/model/phi/modeling_mixformer_sequential.py b/mftcoder_accelerate/src/model/phi/modeling_mixformer_sequential.py
new file mode 100644
index 0000000..4b26a55
--- /dev/null
+++ b/mftcoder_accelerate/src/model/phi/modeling_mixformer_sequential.py
@@ -0,0 +1,764 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+#
+# BSD 3-Clause License
+# 
+# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
+# All rights reserved.
+# 
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# 
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+# 
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+# 
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+# 
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import annotations
+
+import math
+import copy
+from typing import Any, Dict, Optional, Tuple, Union
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn as nn
+
+from einops import rearrange
+from transformers.activations import ACT2FN
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .configuration_mixformer_sequential import MixFormerSequentialConfig
+
+import xformers.ops
+
+@dataclass
+class InferenceParams:
+    """Inference parameters passed to model to efficiently calculate
+    and store context during inference.
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
+    Args:
+        max_sequence_len: Maximum sequence length.
+        max_batch_size: Maximum batch size.
+        sequence_len_offset: Sequence length offset.
+        batch_size_offset: Batch size offset.
+        key_value_memory_dict: Key value memory dictionary.
+        fused_ft_kernel: Whether to use fused kernel for fast inference.
+        lengths_per_sample: Lengths per sample.
+    """
+
+    max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
+
+    max_batch_size: int = field(metadata={"help": "Maximum batch size."})
+
+    sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
+
+    batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
+
+    key_value_memory_dict: Dict[str, Any] = field(
+        default_factory=dict, metadata={"help": "Key value memory dictionary."}
+    )
+
+    fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
+
+    lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
+
+
+class Embedding(nn.Module):
+    """Token embedding with dropout."""
+
+    def __init__(self, config: PretrainedConfig) -> None:
+        super().__init__()
+
+        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        self.drop = nn.Dropout(config.embd_pdrop)
+
+    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+
+        hidden_states = self.wte(input_ids)
+        hidden_states = self.drop(hidden_states)
+
+        return hidden_states
+
+
+class RotaryEmbedding(nn.Module):
+    """Rotary embeddings.
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
+    
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        base: int = 10000,
+        scale_base: Optional[float] = None,
+        device: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        if scale_base is not None:
+            raise NotImplementedError
+
+        # Generate and save the inverse frequency buffer (non-trainable)
+        self.dim = dim
+        self.base = base
+        self.scale_base = scale_base
+        self.device = device
+
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale)
+
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
+
+    def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        seqlen = x.shape[1] + seqlen_offset
+
+        # Re-generate the inverse frequency buffer if it's not fp32
+        # (for instance if model.half() was called)
+        if self.inv_freq.dtype != torch.float32:
+            self.inv_freq = 1.0 / (
+                self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
+            )
+
+        if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
+
+            # Don't do einsum, it converts fp32 to fp16
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(x.dtype)
+                self._sin_cached = torch.sin(freqs).to(x.dtype)
+            else:
+                power = (
+                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+                ) / self.scale_base
+                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
+
+    def _apply_rotary_emb_qkv(
+        self,
+        qkv: torch.FloatTensor,
+        sin: torch.FloatTensor,
+        cos: torch.FloatTensor,
+        sin_k: Optional[torch.FloatTensor] = None,
+        cos_k: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        _, seqlen, three, _, headdim = qkv.shape
+        assert three == 3
+
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+
+        cos_k = cos if cos_k is None else cos_k
+        sin_k = sin if sin_k is None else sin_k
+        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
+
+        q_rot = qkv[:, :, 0, :, :rotary_dim]
+        q_pass = qkv[:, :, 0, :, rotary_dim:]
+
+        k_rot = qkv[:, :, 1, :, :rotary_dim]
+        k_pass = qkv[:, :, 1, :, rotary_dim:]
+
+        # Splits the queries and keys in half
+        q1, q2 = q_rot.chunk(2, dim=-1)
+        k1, k2 = k_rot.chunk(2, dim=-1)
+        c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
+
+        # Casts to fp32 are necessary to prevent fp16 overflow issues
+        q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
+
+        # Computes the new keys and queries, recasting to original dtype
+        q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
+        k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
+
+        return torch.cat(
+            [
+                torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
+                torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
+                qkv[:, :, 2:3, :, :],
+            ],
+            axis=2,
+        )
+
+    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        # `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
+        self._update_cos_sin_cache(qkv, seqlen_offset)
+        return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
+
+
+class MLP(nn.Module):
+    """Multi-Layer Perceptron.
+    Reference:
+        Attention Is All You Need.
+        https://arxiv.org/pdf/1706.03762.pdf.
+    """
+
+    def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None:
+        super().__init__()
+
+        act_fn = config.activation_function if act_fn is None else act_fn
+        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
+
+        n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
+        n_inner = n_inner if n_inner is not None else 4 * config.n_embd
+
+        self.fc1 = nn.Linear(config.n_embd, n_inner)
+        self.fc2 = nn.Linear(n_inner, config.n_embd)
+        self.act = ACT2FN[act_fn]
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+
+        return hidden_states
+
+
+class SelfAttention(nn.Module):
+    """Self-attention layer (compatible with PyTorch).
+    
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
+    """
+
+    def __init__(
+        self,
+        causal: bool = True,
+        softmax_scale: Optional[float] = None,
+        attention_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.drop = nn.Dropout(attention_dropout)
+
+    def forward(
+        self,
+        qkv: torch.FloatTensor,
+        causal: bool = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        causal = self.causal if causal is None else causal
+        batch_size, seq_len = qkv.shape[0], qkv.shape[1]
+        q, k, v = qkv.unbind(dim=2)
+
+        # flash attention
+        output = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask(),
+                                                         op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp)
+
+        # softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+        # scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+
+        # if attention_mask is not None:
+        #     padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
+        #     padding_mask.masked_fill_(attention_mask, 0.0)
+
+        #     scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+
+        # if causal:
+        #     causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
+        #     scores = scores + causal_mask.to(dtype=scores.dtype)
+
+        # attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        # attention = self.drop(attention)
+
+        # output = torch.einsum("bhts,bshd->bthd", attention, v)
+
+        return output
+
+
+class CrossAttention(nn.Module):
+    """Cross-attention layer (compatible with PyTorch).
+    
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
+    
+    """
+
+    def __init__(
+        self,
+        causal: bool = True,
+        softmax_scale: Optional[float] = None,
+        attention_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.drop = nn.Dropout(attention_dropout)
+
+    def forward(
+        self,
+        q: torch.FloatTensor,
+        kv: torch.FloatTensor,
+        causal: bool = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        causal = self.causal if causal is None else causal
+        batch_size, seq_len_q = q.shape[0], q.shape[1]
+        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+
+        seq_len_k = kv.shape[1]
+        k, v = kv.unbind(dim=2)
+
+        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+
+        if attention_mask is not None:
+            padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
+            padding_mask.masked_fill_(attention_mask, 0.0)
+
+            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+
+        if causal:
+            causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
+            scores = scores + causal_mask.to(dtype=scores.dtype)
+
+        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        attention = self.drop(attention)
+
+        output = torch.einsum("bhts,bshd->bthd", attention, v)
+
+        return output
+
+
+def find_mha_dims(
+    config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
+) -> Tuple[int, int]:
+    """Validate and return the number of heads and head dimension for multi-head attention.
+    Args:
+        config: Model configuration.
+        n_head: Number of heads.
+        head_dim: Head dimension.
+    Returns:
+        Number of heads and head dimension.
+    """
+
+    assert all(
+        hasattr(config, attr) for attr in ["n_embd", "n_head"]
+    ), "`config` must have `n_embd` and `n_head` attributes."
+
+    if head_dim is None:
+        assert (
+            config.n_embd % config.n_head == 0
+        ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
+
+    if n_head is None and head_dim is None:
+        head_dim = config.n_embd // config.n_head
+        n_head = config.n_head
+    elif n_head is None or head_dim is None:
+        raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
+
+    return n_head, head_dim
+
+
+def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
+    """Update the key-value cache for inference.
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
+    Args:
+        kv: Key-value tensor.
+        inference_params: Inference parameters.
+        layer_idx: Layer index.
+    Returns:
+        Updated key-value tensor.
+    """
+
+    num_heads, head_dim = kv.shape[-2:]
+
+    if layer_idx not in inference_params.key_value_memory_dict:
+        kv_cache = torch.empty(
+            inference_params.max_batch_size,
+            inference_params.max_sequence_len,
+            2,
+            num_heads,
+            head_dim,
+            dtype=kv.dtype,
+            device=kv.device,
+        )
+        inference_params.key_value_memory_dict[layer_idx] = kv_cache
+    else:
+        if not inference_params.fused_ft_kernel:
+            kv_cache = inference_params.key_value_memory_dict[layer_idx]
+        else:
+            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
+            kv_cache = None
+
+    batch_start = inference_params.batch_size_offset
+    batch_end = batch_start + kv.shape[0]
+    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
+
+    sequence_start = inference_params.sequence_len_offset
+    sequence_end = sequence_start + kv.shape[1]
+    assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
+
+    if not inference_params.fused_ft_kernel:
+        assert kv_cache is not None
+
+        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+
+        return kv
+
+    assert inference_params.sequence_len_offset == 0
+    assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
+
+    packsize = 4 if kv.dtype == torch.float32 else 8
+
+    if kv_cache is not None:
+        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
+        v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
+        inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
+    else:
+        k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
+            kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
+        )
+        v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
+
+    return kv
+
+
+class MHA(nn.Module):
+    """Multi-head attention layer."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[str] = None,
+        rotary_dim: Optional[int] = None,
+        rotary_emb_scale_base: Optional[float] = None,
+        n_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        bias: bool = True,
+        causal: bool = True,
+        softmax_scale: Optional[float] = None,
+        dropout: float = 0.0,
+        layer_idx: Optional[int] = None,
+        return_residual: bool = False,
+        checkpointing: bool = False,
+    ) -> None:
+        super().__init__()
+
+        # Rotary embedding
+        self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
+        if self.rotary_emb_dim > 0:
+            rotary_kwargs = {"device": device}
+            if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
+                rotary_kwargs["scale_base"] = rotary_emb_scale_base
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
+        
+        # MLP
+        self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
+        op_size = self.n_head * self.head_dim
+        hidden_size = config.n_embd
+
+        self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
+        self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
+
+        # Attention
+        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+
+        self.layer_idx = layer_idx
+        self.return_residual = return_residual
+        # self.checkpointing = checkpointing
+        self.checkpointing = True
+
+    def forward(
+        self,
+        x: torch.FloatTensor,
+        past_key_values: Optional[InferenceParams] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        max_seqlen: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        qkv = self.Wqkv(x)
+        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+
+        seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
+        if self.rotary_emb_dim > 0:
+            qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+
+        if past_key_values is not None:
+            kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
+
+        if attention_mask is not None:
+            attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
+            attention_mask = attention_mask.bool().to(qkv.device)
+
+        attention_kwargs = {"attention_mask": attention_mask}
+
+        if past_key_values is None or seqlen_offset == 0:
+            if self.checkpointing:
+                # attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
+                attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv)
+            else:
+                attn_output = self.inner_attn(qkv, **attention_kwargs)
+        else:
+            q = qkv[:, :, 0]
+            causal = None if past_key_values.sequence_len_offset == 0 else False
+            attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
+
+        output = rearrange(attn_output, "... h d -> ... (h d)")
+        output = self.out_proj(output)
+
+        return output if not self.return_residual else (output, x)
+
+
+class ParallelBlock(nn.Module):
+    """Parallel block.
+    This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        block_idx: Optional[int] = None,
+        checkpointing: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.block_idx = block_idx
+
+        self.mixer = MHA(config, layer_idx=block_idx, checkpointing=checkpointing)
+        self.mlp = MLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        hidden_states = self.ln(hidden_states)
+
+        attn_outputs = self.mixer(hidden_states, past_key_values=past_key_values, attention_mask=attention_mask)
+        if isinstance(attn_outputs, tuple):
+            attn_outputs = attn_outputs[0]
+
+        attn_outputs = self.resid_dropout(attn_outputs)
+        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
+
+        hidden_states = attn_outputs + feed_forward_hidden_states + residual
+
+        return hidden_states
+
+
+class CausalLMHead(nn.Module):
+    """Causal Language Modeling head.
+    Reference:
+        Improving Language Understanding by Generative Pre-Training.
+        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
+    """
+
+    def __init__(self, config: PretrainedConfig) -> None:
+        super().__init__()
+
+        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.linear = nn.Linear(config.n_embd, config.vocab_size)
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+        hidden_states = self.ln(hidden_states)
+        logits = self.linear(hidden_states).to(torch.float32)
+
+        return logits
+
+
+class CausalLMLoss(nn.Module):
+    """Causal Language Modeling loss.
+    Reference:
+        Improving Language Understanding by Generative Pre-Training.
+        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
+    """
+
+    def __init__(self, shift_labels: bool = True) -> None:
+        super().__init__()
+
+        self.shift_labels = shift_labels
+        self.loss_fct = nn.CrossEntropyLoss()
+
+    def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
+        if self.shift_labels:
+            logits = logits[..., :-1, :].contiguous()
+            labels = labels[..., 1:].contiguous()
+
+        loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+        return loss
+
+
+class MixFormerSequentialPreTrainedModel(PreTrainedModel):
+    """MixFormer (sequential for DeepSpeed) pre-trained model."""
+
+    config_class = MixFormerSequentialConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+
+    def __init__(self, *inputs, **kwargs) -> None:
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(self, module: nn.Module) -> None:
+        if isinstance(module, (nn.Linear,)):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        if attention_mask is not None and torch.any(~attention_mask.bool()):
+            total_seq_len = torch.sum(attention_mask, dim=1)
+            max_seq_len = torch.max(total_seq_len)
+
+            total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
+            cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
+            attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
+        else:
+            attention_mask = None
+
+        if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
+            past_key_values = InferenceParams(
+                max_batch_size=input_ids.shape[0],
+                max_sequence_len=self.config.n_positions,
+                sequence_len_offset=0,
+                batch_size_offset=0,
+                fused_ft_kernel=False,
+                key_value_memory_dict={},
+            )
+        else:
+            # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
+            past_key_values.sequence_len_offset = len(input_ids[0]) - 1
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "attention_mask": attention_mask,
+        }
+    
+    # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, PreTrainedModel):
+            module.gradient_checkpointing = value
+
+
+class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
+    """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
+
+    _keys_to_ignore_on_load_missing = [""]
+    _keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
+    _no_split_modules = ["ParallelBlock"]
+
+    def __init__(self, config: MixFormerSequentialConfig) -> None:
+        super().__init__(config)
+
+        modules = [Embedding(config)]
+        modules += [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
+        modules.append(CausalLMHead(config))
+
+        self.layers = nn.Sequential(*modules)
+        self.loss = CausalLMLoss()
+
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.layers[0].wte
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+        self.layers[0].wte = new_embeddings
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.layers[-1].linear
+
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+        self.layers[-1].linear = new_embeddings
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        if attention_mask is not None and self.training:
+            print("`attention_mask` is not supported during training. Using it might lead to unexpected results.")
+
+        if past_key_values is None and attention_mask is None:
+            lm_logits = self.layers(input_ids)
+        else:
+            hidden_layer = self.layers[0](input_ids)
+            for module in self.layers[1:-1]:
+                hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
+            lm_logits = self.layers[-1](hidden_layer)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss(lm_logits, labels)
+
+        return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
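
The `ParallelBlock.forward` above feeds one layer-normed input to both the attention mixer and the MLP and sums the two branches with the residual (the GPT-J/CodeGen layout), rather than chaining them sequentially. A minimal standalone sketch of that dataflow, with plain linear layers standing in for `MHA` and `MLP`:

```python
import torch
import torch.nn as nn

n_embd = 8
ln = nn.LayerNorm(n_embd)
mixer = nn.Linear(n_embd, n_embd)   # stand-in for the MHA mixer
mlp = nn.Linear(n_embd, n_embd)     # stand-in for the MLP
drop = nn.Dropout(0.0)              # resid_pdrop

x = torch.randn(2, 4, n_embd)       # (batch, seq, hidden)
h = ln(x)                           # single LayerNorm shared by both branches
out = drop(mixer(h)) + drop(mlp(h)) + x
print(out.shape)                    # torch.Size([2, 4, 8])
```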
diff --git a/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp
new file mode 100644
index 0000000..8458a9b
--- /dev/null
+++ b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp
@@ -0,0 +1,198 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
+
+// adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_256.cpp
+void vecquant8matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant8matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant8matmul_batched_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_batched_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant8matmul_batched_column_compression_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_column_compression(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant4matmul_batched_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_batched(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_batched_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant4matmul_batched_column_compression_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_batched_column_compression(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant8matmul_batched_old_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_old(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_batched_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+void vecquant4matmul_batched_old_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_batched_old(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_batched_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant8matmul_batched_column_compression_old_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_column_compression_old(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant4matmul_batched_column_compression_old_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_batched_column_compression_old(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+
+void vecquant8matmul_batched_faster_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_faster(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_batched_faster_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+void vecquant8matmul_batched_faster_old_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_faster_old(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_batched_faster_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant8matmul_batched_column_compression_faster_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_column_compression_faster(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_batched_column_compression_faster_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+void vecquant8matmul_batched_column_compression_faster_old_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_column_compression_faster_old(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_batched_column_compression_faster_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant8matmul_batched", &vecquant8matmul_batched, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant8matmul_batched_old", &vecquant8matmul_batched_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant8matmul_batched_faster", &vecquant8matmul_batched_faster, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant8matmul_batched_faster_old", &vecquant8matmul_batched_faster_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant4matmul_batched_old", &vecquant4matmul_batched_old, "Vector 4-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant8matmul_batched_column_compression", &vecquant8matmul_batched_column_compression, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+  m.def("vecquant8matmul_batched_column_compression_old", &vecquant8matmul_batched_column_compression_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+  m.def("vecquant8matmul_batched_column_compression_faster", &vecquant8matmul_batched_column_compression_faster, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+  m.def("vecquant8matmul_batched_column_compression_faster_old", &vecquant8matmul_batched_column_compression_faster_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+  m.def("vecquant4matmul_batched_column_compression_old", &vecquant4matmul_batched_column_compression_old, "Vector old 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+  m.def("vecquant4matmul_batched", &vecquant4matmul_batched, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant4matmul_batched_column_compression", &vecquant4matmul_batched_column_compression, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+}
diff --git a/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu
new file mode 100644
index 0000000..b7932cd
--- /dev/null
+++ b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu
@@ -0,0 +1,1708 @@
+#define _CRT_SECURE_NO_WARNINGS
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <stdint.h>
+
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
+// adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu
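+// On architectures without native fp16 atomicAdd (pre-SM70) or under ROCm, emulate it
+// with a 32-bit compare-and-swap loop on the aligned word containing the target half.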
+__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+    unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
+        hsum += val;
+        old = reinterpret_cast<size_t>(address) & 2
+                 ? (old & 0xffff) | (hsum << 16)
+                 : (old & 0xffff0000) | hsum;
+        old = atomicCAS(address_as_ui, assumed, old);
+
+    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+    } while (assumed != old);
+}
+__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) {
+    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        __half_raw hsum;
+        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+        half tmpres = __hadd(hsum, val);
+        hsum = __half_raw(tmpres);
+        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+        old = atomicCAS(address_as_ui, assumed, old);
+    } while (assumed != old);
+}
+#endif
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    const       int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+);
+
+
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+);
+
+__global__ void VecQuant8BatchMatMulKernel_faster(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+);
+
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster_old(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width
+);
+
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+);
+
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width
+);
+
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+const int BLOCKWIDTH  = 128;
+const int BLOCKHEIGHT8 =  32;
+const int BLOCKHEIGHT4 =  16;
+const int BLOCKHEIGHT_OLD4 =  128;
+//const int BLOCKHEIGHT_OLD8 =  128;
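+// BLOCKHEIGHT8/BLOCKHEIGHT4 are BLOCKWIDTH divided by the packing factor:
+// one 32-bit word stores four 8-bit or eight 4-bit quantized weights.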
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+__device__ inline int as_int(int i) {
+  return *reinterpret_cast<int*>(&i);
+}
+
+void vecquant8matmul_batched_column_compression_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int height = vec.size(3);
+  int width = mat.size(3) * 4;
+
+  dim3 blocks(
+    (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant8matmul_batched_cuda", ([&] {
+      VecQuant8BatchMatMulColumnCompressionKernel<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
+        batch, heads, vec_row, height, width
+      );
+    })
+  );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+) {
+  int weight_total = batch * heads * height * width / 4;
+  int input_total = batch * heads * vec_row * height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  // h is index of height with step being BLOCKWIDTH
+  int h = BLOCKWIDTH * blockIdx.x;
+  // w is index of width with step being 1
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= height) {
+    return;
+  }
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int k;
+  scalar_t w_tmp;
+
+  float weight[BLOCKWIDTH];
+
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h + k < height; ++k){
+        int i_w = (w / 4);
+        int w_bit = (w % 4) * 8;
+
+        int w_index = (batch_shift * height + h + k) * width / 4 + i_w;
+        if (w_index >= weight_total || w >= width) {
+          weight[k] = 0;
+        } else {
+          scalar_t scale = scales[batch_shift * height + h + k];
+          scalar_t zero = zeros[batch_shift * height + h + k];
+          w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xFF);
+          weight[k] = scale * (w_tmp - zero);
+        }
+      }
+
+      scalar_t res;
+      for (int vr = 0; vr < vec_row; ++vr){
+          res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = vec[vec_index];
+        } else {
+            blockvec[tid] = 0;
+        }
+
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h + k < height; ++k){
+          // res is the dot product of BLOCKWIDTH elements (part of width)
+            res += weight[k] * blockvec[k];
+        }
+        // add res to the final result, final matrix shape: (batch, vec_row, width)
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], res);
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+void vecquant8matmul_batched_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int vec_height = vec.size(3);
+  int height = mat.size(2);
+  int width = mat.size(3);
+  int zero_width = zeros.size(2);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant8matmul_batched_cuda", ([&] {
+      VecQuant8BatchMatMulKernel<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
+        batch, heads, vec_row, vec_height, height, width, zero_width
+      );
+    })
+  );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+) {
+  int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * vec_height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  // h is index of height with step being BLOCKHEIGHT8
+  int h = BLOCKHEIGHT8 * blockIdx.x;
+  // w is index of width with step being 1
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= vec_height) {
+    return;
+  }
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  // i is index of mat of block first row
+  int i = width * h + w;
+  // if (i >= width * height) {
+  //   return;
+  // }
+  int k;
+  scalar_t w_tmp;
+
+  int z_w = w / 4;
+  int z_mod = (w % 4) * 8;
+
+  float weight[BLOCKWIDTH];
+
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h * 4 + k < vec_height; ++k){
+        int k_w = (k / 4);
+        int k_bit = (k % 4) * 8;
+
+        int w_index = batch_shift * height * width + i + (k_w * width);
+        if (w_index >= weight_total || w >= width) {
+          weight[k] = 0;
+        } else {
+          scalar_t scale = scales[batch_shift * width + w];
+          scalar_t zero;
+          if (zero_width == width) {
+            zero = zeros[batch_shift * width + w];
+          } else {
+            zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+          }
+          w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xFF);
+          weight[k] = scale * (w_tmp - zero);
+        }
+      }
+
+      scalar_t res;
+      for (int vr = 0; vr < vec_row; ++vr){
+          res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = vec[vec_index];
+        } else {
+            blockvec[tid] = 0;
+        }
+
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h * 4 + k < vec_height; ++k){
+          // res is the dot product of BLOCKWIDTH elements (part of width)
+            res += weight[k] * blockvec[k];
+        }
+        // add res to the final result, final matrix shape: (batch, vec_row, width)
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], res);
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+void vecquant8matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant8matmul_cuda", ([&] {
+      VecQuant8MatMulKernel<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(), g_idx.data_ptr<int>(),
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    const       int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+) {
+  int h = BLOCKHEIGHT8 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = h * 4;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+
+  int z_w = w / 4;
+  int z_mod = (w % 4) * 8;
+
+  float weight[BLOCKWIDTH];
+
+  for (k = 0; k <  BLOCKWIDTH; ++k){
+    int k_w = (k / 4);
+    int k_bit = (k % 4) * 8;
+
+      g = as_int(g_idx[g_h + k]);
+      scalar_t scale = scales[g * width + w];
+      scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+
+      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
+
+    weight[k] = scale * (w_tmp - zero);
+  }
+
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){
+      res = 0;
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+    for (k = 0; k <  BLOCKWIDTH; ++k){
+      res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}
+
+
+
+void vecquant4matmul_batched_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int vec_height = vec.size(3);
+  int height = mat.size(2);
+  int width = mat.size(3);
+  int zero_width = zeros.size(2);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant4matmul_batched_cuda", ([&] {
+      VecQuant4BatchMatMulKernel<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
+        batch, heads, vec_row, vec_height, height, width, zero_width
+      );
+    })
+  );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+) {
+  int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * vec_height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  // h is index of height with step being BLOCKHEIGHT4
+  int h = BLOCKHEIGHT4 * blockIdx.x;
+  // w is index of width with step being 1
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= vec_height) {
+    return;
+  }
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  // i is index of mat of block first row
+  int i = width * h + w;
+  int k;
+  scalar_t w_tmp;
+
+  int z_w = w / 8;
+  int z_mod = (w % 8) * 4;
+
+  float weight[BLOCKWIDTH];
+
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h * 8 + k < vec_height; ++k){
+        int k_w = (k / 8);
+        int k_bit = (k % 8) * 4;
+
+        int w_index = batch_shift * height * width + i + (k_w * width);
+        if (w_index >= weight_total || w >= width) {
+          weight[k] = 0;
+        } else {
+          scalar_t scale = scales[batch_shift * width + w];
+          scalar_t zero;
+          if (zero_width == width) {
+            zero = zeros[batch_shift * width + w];
+          } else {
+            zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xF));
+          }
+          w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
+          weight[k] = scale * (w_tmp - zero);
+        }
+      }
+
+      scalar_t res;
+      for (int vr = 0; vr < vec_row; ++vr){
+          res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = vec[vec_index];
+        } else {
+            blockvec[tid] = 0;
+        }
+
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h * 8 + k < vec_height; ++k){
+          // res is the dot product of BLOCKWIDTH elements (part of width)
+            res += weight[k] * blockvec[k];
+        }
+        // add res to the final result, final matrix shape: (batch, vec_row, width)
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], res);
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+
+void vecquant4matmul_batched_column_compression_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int height = vec.size(3);
+  int width = mat.size(3) * 8;
+
+  dim3 blocks(
+    (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant4matmul_batched_cuda", ([&] {
+      VecQuant4BatchMatMulColumnCompressionKernel<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
+        batch, heads, vec_row, height, width
+      );
+    })
+  );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+) {
+  int weight_total = batch * heads * height * width / 8;
+  int input_total = batch * heads * vec_row * height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  // h is index of height with step being BLOCKWIDTH
+  int h = BLOCKWIDTH * blockIdx.x;
+  // w is index of width with step being 1
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= height) {
+    return;
+  }
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int k;
+  scalar_t w_tmp;
+
+  float weight[BLOCKWIDTH];
+
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h + k < height; ++k){
+        int i_w = (w / 8);
+        int w_bit = (w % 8) * 4;
+
+        int w_index = (batch_shift * height + h + k) * width / 8 + i_w;
+        if (w_index >= weight_total || w >= width) {
+          weight[k] = 0;
+        } else {
+          scalar_t scale = scales[batch_shift * height + h + k];
+          scalar_t zero = zeros[batch_shift * height + h + k];
+          w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xF);
+          weight[k] = scale * (w_tmp - zero);
+        }
+      }
+
+      scalar_t res;
+      for (int vr = 0; vr < vec_row; ++vr){
+          res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = vec[vec_index];
+        } else {
+            blockvec[tid] = 0;
+        }
+
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h + k < height; ++k){
+          // res is the dot product of BLOCKWIDTH elements (part of width)
+            res += weight[k] * blockvec[k];
+        }
+        // add res to the final result, final matrix shape: (batch, vec_row, width)
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], res);
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+void vecquant8matmul_batched_old_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int vec_height = vec.size(3);
+  int height = mat.size(2);
+  int width = mat.size(3);
+  int zero_width = zeros.size(2);
+
+  dim3 blocks(
+    (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant8matmul_batched_old_cuda", ([&] {
+      VecQuant8BatchMatMulKernel_old<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<uint8_t>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
+        batch, heads, vec_row, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+) {
+  int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * vec_height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  // h is index of height with step being BLOCKHEIGHT8
+  int h = BLOCKWIDTH * blockIdx.x;
+  // w is index of width with step being 1
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= vec_height) {
+    return;
+  }
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  // i is index of mat of block first row
+  int i = width * h + w;
+  int k;
+  scalar_t w_tmp;
+
+  float weight[BLOCKWIDTH];
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h + k < vec_height; ++k){
+        int k_w = k;
+        int w_index = batch_shift * height * width + i + (k_w * width);
+        if (w_index >= weight_total || w >= width) {
+          weight[k] = 0;
+        } else {
+          scalar_t scale = scales[batch_shift * width + w];
+          scalar_t zero = zeros[batch_shift * width + w];
+          w_tmp = as_unsigned(mat[w_index]);
+          weight[k] = scale * (w_tmp - zero);
+        }
+      }
+
+      scalar_t res;
+      for (int vr = 0; vr < vec_row; ++vr){
+          res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = vec[vec_index];
+        } else {
+            blockvec[tid] = 0;
+        }
+
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h + k < vec_height; ++k){
+          // res is the dot product of BLOCKWIDTH elements (part of width)
+            res += weight[k] * blockvec[k];
+        }
+        // add res to the final result, final matrix shape: (batch, vec_row, width)
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], res);
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+
+void vecquant8matmul_batched_faster_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int vec_height = vec.size(3);
+  int height = mat.size(2);
+  int width = mat.size(3);
+  int zero_width = zeros.size(2);
+
+  dim3 blocks(
+    (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant8BatchMatMulKernel_faster<<<blocks, threads>>>(
+    (half*) vec.data_ptr(),
+    (uint8_t*) mat.data_ptr(),
+    (half*) mul.data_ptr(),
+    (half*) scales.data_ptr(),
+    (half*) zeros.data_ptr(),
+    batch, heads, vec_row, vec_height, height, width, zero_width
+  );
+}
+
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+) {
+  //int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * vec_height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  int h = BLOCKWIDTH * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= height) {
+    return;
+  }
+
+  __shared__ float blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int k;
+  float w_tmp;
+
+  float weight[BLOCKWIDTH];
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h + k < vec_height; ++k){
+        int k_w = k;
+        int w_index = batch_shift * height * width + i + (k_w * width);
+        float scale = __half2float(scales[batch_shift * width + w]);
+        float zero = __half2float(zeros[batch_shift * width + w]);
+        w_tmp = as_unsigned(mat[w_index]);
+        weight[k] = scale *(w_tmp-zero);
+      }
+
+      float res;
+      for (int vr = 0; vr < vec_row; ++vr){
+        res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = __half2float(vec[vec_index]);
+        } else {
+            blockvec[tid] = 0;
+        }
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h + k < vec_height; ++k){
+            float temp_res = weight[k]*blockvec[k];
+            res += temp_res;
+        }
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], __float2half(res));
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+
+
+void vecquant8matmul_batched_column_compression_faster_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int height = vec.size(3);
+  int width = mat.size(3);
+
+  dim3 blocks(
+    (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant8BatchMatMulColumnCompressionKernel_faster<<<blocks, threads>>>(
+    (half*) vec.data_ptr(),
+    (uint8_t*) mat.data_ptr(),
+    (half*) mul.data_ptr(),
+    (half*) scales.data_ptr(),
+    (half*) zeros.data_ptr(),
+    batch, heads, vec_row, height, width
+  );
+
+}
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+) {
+  //int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  int h = BLOCKWIDTH * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= height) {
+    return;
+  }
+
+  __shared__ float blockvec[BLOCKWIDTH];
+  int k;
+  float w_tmp;
+  float weight[BLOCKWIDTH];
+
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH; ++k){
+        int w_index = (batch_shift * height + h + k) * width  + w;
+        float scale = __half2float(scales[batch_shift * height + h + k]);
+        float zero = __half2float(zeros[batch_shift * height + h + k]);
+        w_tmp = mat[w_index];
+        weight[k] = scale * (w_tmp-zero);
+      }
+
+      float res;
+      for (int vr = 0; vr < vec_row; ++vr){
+        res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = __half2float(vec[vec_index]);
+        } else {
+            blockvec[tid] = 0;
+        }
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH; ++k){
+            res += weight[k]*blockvec[k];
+        }
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], __float2half(res));
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+
+void vecquant8matmul_batched_column_compression_old_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int height = vec.size(3);
+  int width = mat.size(3);
+
+  dim3 blocks(
+    (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant8matmul_batched_column_compression_old_cuda", ([&] {
+      VecQuant8BatchMatMulColumnCompressionKernel_old<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<uint8_t>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
+        batch, heads, vec_row, height, width
+      );
+    })
+  );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+) {
+  int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  // h is index of height with step being BLOCKWIDTH
+  int h = BLOCKWIDTH * blockIdx.x;
+  // w is index of width with step being 1
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= height) {
+    return;
+  }
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int k;
+  scalar_t w_tmp;
+
+  float weight[BLOCKWIDTH];
+
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h + k < height; ++k){
+        int w_index = (batch_shift * height + h + k) * width  + w;
+        if (w_index >= weight_total || w >= width) {
+          weight[k] = 0;
+        } else {
+          scalar_t scale = scales[batch_shift * height + h + k];
+          scalar_t zero = zeros[batch_shift * height + h + k];
+          w_tmp = mat[w_index];
+          weight[k] = scale * (w_tmp - zero);
+        }
+      }
+
+      scalar_t res;
+      for (int vr = 0; vr < vec_row; ++vr){
+          res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = vec[vec_index];
+        } else {
+            blockvec[tid] = 0;
+        }
+
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h + k < height; ++k){
+          // res is the dot product of BLOCKWIDTH elements (part of width)
+            res += weight[k] * blockvec[k];
+        }
+        // add res to the final result, final matrix shape: (batch, vec_row, width)
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], res);
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+void vecquant4matmul_batched_old_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int vec_height = vec.size(3);
+  int height = mat.size(2);
+  int width = mat.size(3);
+  int zero_width = zeros.size(2);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant4matmul_batched_old_cuda", ([&] {
+      VecQuant4BatchMatMulKernel_old<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<uint8_t>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
+        batch, heads, vec_row, vec_height, height, width, zero_width
+      );
+    })
+  );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+) {
+  int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * vec_height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  // h is index of height with step being BLOCKHEIGHT_OLD4
+  int h = BLOCKHEIGHT_OLD4 * blockIdx.x;
+  // w is index of width with step being 1
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= vec_height) {
+    return;
+  }
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  // i is index of mat of block first row
+  int i = width * h + w;
+  int k;
+  scalar_t w_tmp;
+
+  float weight[BLOCKWIDTH];
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h*2 + k < vec_height; ++k){
+        int k_w = (k / 2);
+        int k_bit = (k % 2) * 4;
+        int w_index = batch_shift * height * width + i + (k_w * width);
+        if (w_index >= weight_total || w >= width) {
+          weight[k] = 0;
+        } else {
+          scalar_t scale = scales[batch_shift * width + w];
+          scalar_t zero = zeros[batch_shift * width + w];
+          w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
+          weight[k] = scale * (w_tmp - zero);
+        }
+      }
+
+      scalar_t res;
+      for (int vr = 0; vr < vec_row; ++vr){
+          res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = vec[vec_index];
+        } else {
+            blockvec[tid] = 0;
+        }
+
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h*2 + k < vec_height; ++k){
+          // res is the dot product of BLOCKWIDTH elements (part of width)
+            res += weight[k] * blockvec[k];
+        }
+        // add res to the final result, final matrix shape: (batch, vec_row, width)
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], res);
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+
+
+
+void vecquant4matmul_batched_column_compression_old_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int height = vec.size(3);
+  int width = mat.size(3);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant4matmul_batched_column_compression_old_cuda", ([&] {
+      VecQuant4BatchMatMulColumnCompressionKernel_old<<<blocks, threads>>>(
+        vec.data_ptr<scalar_t>(), mat.data_ptr<uint8_t>(), mul.data_ptr<scalar_t>(),
+        scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
+        batch, heads, vec_row, height, width
+      );
+    })
+  );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+) {
+  int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  // h is index of height with step being BLOCKWIDTH
+  int h = BLOCKHEIGHT_OLD4 * blockIdx.x;
+  // w is index of width with step being 1
+  int w = BLOCKWIDTH * blockIdx.y + tid;
+  if (w >= width && tid >= height) {
+    return;
+  }
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int k;
+  scalar_t w_tmp;
+
+  float weight[BLOCKWIDTH];
+
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      for (k = 0; k <  BLOCKWIDTH && h*2 + k < height; ++k){
+        int k_w = (k / 2);
+        int k_bit = (k % 2) * 4;
+        int w_index = (batch_shift * height + h + k) * width  + k_w;
+        if (w_index >= weight_total || w >= width) {
+          weight[k] = 0;
+        } else {
+          scalar_t scale = scales[batch_shift * height + h + k];
+          scalar_t zero = zeros[batch_shift * height + h + k];
+          w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
+          weight[k] = scale * (w_tmp - zero);
+        }
+      }
+
+      scalar_t res;
+      for (int vr = 0; vr < vec_row; ++vr){
+          res = 0;
+        int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+        if (vec_index < input_total) {
+            blockvec[tid] = vec[vec_index];
+        } else {
+            blockvec[tid] = 0;
+        }
+
+        __syncthreads();
+          for (k = 0; k <  BLOCKWIDTH && h*2 + k < height; ++k){
+          // res is the dot product of BLOCKWIDTH elements (part of width)
+            res += weight[k] * blockvec[k];
+        }
+        // add res to the final result, final matrix shape: (batch, vec_row, width)
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (out_index < out_total) {
+            atomicAdd(&mul[out_index], res);
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+
+
+
+void vecquant8matmul_batched_faster_old_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2);
+  int vec_height = vec.size(3);
+  int height = mat.size(2);
+  int width = mat.size(3);
+
+  dim3 blocks(
+    (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant8BatchMatMulKernel_faster_old<<<blocks, threads>>>(
+    (half*) vec.data_ptr(),
+    (uint8_t*) mat.data_ptr(),
+    (half*) mul.data_ptr(),
+    (half*) scales.data_ptr(),
+    (half*) zeros.data_ptr(),
+    batch, heads, vec_row, vec_height, height, width
+  );
+}
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster_old(
+    const  half* __restrict__ vec,
+    const  uint8_t* __restrict__ mat,
+           half* __restrict__ mul,
+    const  half* __restrict__ scales,
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width
+) {
+ int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * vec_height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  const int BLOCKWIDTH_half = BLOCKWIDTH/2;
+
+  int h = BLOCKWIDTH * blockIdx.x; //head_dim, dim=-1
+  int w = BLOCKWIDTH * blockIdx.y + tid; //seq-len, +0-256 ,dim=-2
+  /*
+  if (w >= width && tid >= vec_height) {
+    return;
+  }
+  */
+  __shared__ half blockvec[BLOCKWIDTH]; //256
+  int i = width * h + w;
+  int k;
+
+  half w_tmp1 = __float2half(0);
+  half w_tmp2 = __float2half(0);
+
+  half2 weight[BLOCKWIDTH_half];
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      //int zero_index = batch_shift;
+      for (k = 0; k <  BLOCKWIDTH_half; ++k){
+        int w_index1 = batch_shift * height * width + i + (2 * k * width); // [batch,head,h+k, w]
+        int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width);
+        int zero_index = batch_shift * width + w; // [batch,head, w]
+        if (w_index1 >= weight_total || w >= width || (2 * k + h) >= height) {
+          weight[k] = __float2half2_rn(0);
+        } else {
+            float zero_f=__half2float(zeros[zero_index]);
+            float scale_f= __half2float(scales[zero_index]);
+            if (w_index2 >= weight_total){
+              w_tmp1 = __float2half((as_unsigned(mat[w_index1]) -zero_f)*scale_f);
+              w_tmp2 = __float2half(0);
+              weight[k] = __halves2half2(w_tmp1,w_tmp2);
+              //printf("zero_index is %d w is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,w,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k]));
+            }else{
+              w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1]));
+              w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2]));
+
+              //weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero,zero)),__halves2half2(scale,scale));
+              weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f)));
+              //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k]));
+            }
+        }
+      }
+
+
+      for (int vr = 0; vr < vec_row; ++vr){
+        float res=0;
+        int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+        if (vec_index < input_total) {
+            //blockvec[tid] = __half2float(vec[vec_index]);// [batch, head, vr, tid(seq_len dim+)]
+            blockvec[tid] = vec[vec_index];
+            //printf("width is %d height is %d h is %d w is %d vec_index is %d out_index is %d vec_row is %d vec_height is %d,vr is %d tid is %d blockvec is %f\n",width,height, h,w,vec_index,out_index,vec_row,vec_height,vr,tid,blockvec[tid]);
+        } else {
+            blockvec[tid] = __float2half(0);
+        }
+        __syncthreads();
+        if (out_index < out_total) {
+          for (k = 0; k <  BLOCKWIDTH_half; ++k){
+            half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1]));
+            res += __low2float(res2) + __high2float(res2);
+          }
+          atomicAdd(&mul[out_index], __float2half(res));
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+
+void vecquant8matmul_batched_column_compression_faster_old_cuda(
+  torch::Tensor vec,  // [batch,heads, seq_q, seq_v]
+  torch::Tensor mat, // [batch,heads, seq_v, head_dim]
+  torch::Tensor mul,  // [batch,heads, seq_q,head_dim]
+  torch::Tensor scales, // [batch,heads, head_dim]
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int heads = vec.size(1);
+  int vec_row = vec.size(2); //ql
+  int height = mat.size(2); //vl
+  int width = mat.size(3); //head_dim
+
+  dim3 blocks(
+    (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant8BatchMatMulColumnCompressionKernel_faster_old<<<blocks, threads>>>(
+    (half*) vec.data_ptr(),
+    (uint8_t*) mat.data_ptr(),
+    (half*) mul.data_ptr(),
+    (half*) scales.data_ptr(),
+    (half*) zeros.data_ptr(),
+    batch, heads, vec_row, height, width
+  );
+
+}
+
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old(
+    const  half* __restrict__ vec,  // [batch,heads, seq_q, seq_v]
+    const  uint8_t* __restrict__ mat, // [batch,heads, seq_v, head_dim]
+           half* __restrict__ mul, // [batch,heads, seq_q,head_dim]
+    const  half* __restrict__ scales, // [batch,heads, seq_v]
+    const  half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row, //seq_q
+    int height, //seq_v
+    int width //head_dim
+) {
+  int weight_total = batch * heads * height * width;
+  int input_total = batch * heads * vec_row * height;
+  int out_total = batch * heads * vec_row * width;
+  int tid = threadIdx.x;
+  int h = BLOCKWIDTH * blockIdx.x; // vl
+  int w = BLOCKWIDTH * blockIdx.y + tid; //head_dim + block
+  if (w >= width && tid >= height) {
+    return;
+  }
+  __shared__ half blockvec[BLOCKWIDTH];
+  int k;
+  half w_tmp1 = __float2half(0);
+  half w_tmp2 = __float2half(0);
+  int i = width * h + w;
+  const int BLOCKWIDTH_half = BLOCKWIDTH/2;
+  half2 weight[BLOCKWIDTH_half];
+
+  for (int b = 0; b < batch; ++b){
+    for (int head = 0; head < heads; ++head){
+      int batch_shift = b * heads + head;
+      //int zero_index = batch_shift;
+      for (k = 0; k <  BLOCKWIDTH_half; ++k){
+        int w_index1 = batch_shift * height * width + i + (2 * k) * width; // [batch,head, h+k, w]
+        int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width);
+        int zero_index1 = batch_shift * height + h + 2*k; // [batch,head, w]
+        int zero_index2 = batch_shift * height + h + 2*k+1; // [batch,head, w]
+
+        if (w_index1 >= weight_total || (2 * k + h)>=height) {
+          weight[k]=__float2half2_rn(0);
+        } else{
+            //int zero_index = batch_shift + h; // [batch,head, w]
+            //float scale_f1 = __half2float(scales[zero_index1]);
+            //float zero_f1 =  __half2float(zeros[zero_index1]);
+            if (w_index2>=weight_total){
+              w_tmp1 = __float2half((as_unsigned(mat[w_index1]) - __half2float(zeros[zero_index1]))* __half2float(scales[zero_index1]));
+              w_tmp2 = __float2half(0);
+              weight[k] = __halves2half2(w_tmp1,w_tmp2);
+              //printf("zero_index is %d k is %d w is %d head is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,k,w,head,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k]));
+            }else{
+              w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1]));
+              w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2]));
+              half zero1=zeros[zero_index1];
+              half zero2=zeros[zero_index2];
+              half scale1=scales[zero_index1];
+              half scale2=scales[zero_index2];
+              weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero1,zero2)),__halves2half2(scale1,scale2));
+              //weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f)));
+              //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k]));
+            }
+          }
+       }
+
+
+      for (int vr = 0; vr < vec_row; ++vr){
+        float res=0;
+        int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+        int out_index = (batch_shift * vec_row + vr) * width + w;
+
+        if (vec_index < input_total) {
+            //blockvec[tid] = __half2float(vec[vec_index]);
+            blockvec[tid] = vec[vec_index];
+            //printf("vec_index is %d out_index is %d vec_row is %d ,vr is %d tid is %d blockvec is %f\n",vec_index,out_index,vec_row,vr,tid,blockvec[tid]);
+        } else {
+            blockvec[tid] = __float2half(0);
+            //blockvec[tid] = 0;
+        }
+        __syncthreads();
+        if (out_index < out_total) {
+            for (k = 0; k <  BLOCKWIDTH_half; ++k){
+                half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1]));
+                res += __low2float(res2) + __high2float(res2);
+            }
+            atomicAdd(&mul[out_index], __float2half(res));
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
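
For readers sanity-checking the fused kernels above, here is a minimal, non-fused PyTorch sketch of what `vecquant8matmul_batched_column_compression_faster_old_cuda` computes, assuming the tensor layouts and the per-row affine dequantization `(q - zero) * scale` as indexed inside the kernel (the helper name is ours, not part of the patch):

```python
import torch

def dequant_matmul_reference(vec, mat_q, scales, zeros):
    """Non-fused reference for the column-compression kernel above.

    vec:    [batch, heads, seq_q, seq_v]     attention weights
    mat_q:  [batch, heads, seq_v, head_dim]  uint8 quantized V cache
    scales: [batch, heads, seq_v]            per-row scale
    zeros:  [batch, heads, seq_v]            per-row zero point
    """
    # Affine dequantization, broadcast over head_dim, then a batched matmul.
    mat = (mat_q.float() - zeros.unsqueeze(-1).float()) * scales.unsqueeze(-1).float()
    return torch.matmul(vec.float(), mat)  # [batch, heads, seq_q, head_dim]
```
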
diff --git a/mft_peft_hf/src/model/qwen/configuration_qwen.py b/mftcoder_accelerate/src/model/qwen/configuration_qwen.py
similarity index 53%
rename from mft_peft_hf/src/model/qwen/configuration_qwen.py
rename to mftcoder_accelerate/src/model/qwen/configuration_qwen.py
index 502ef6c..f8fe2cb 100644
--- a/mft_peft_hf/src/model/qwen/configuration_qwen.py
+++ b/mftcoder_accelerate/src/model/qwen/configuration_qwen.py
@@ -9,61 +9,49 @@
 class QWenConfig(PretrainedConfig):
     model_type = "qwen"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {
-        "hidden_size": "n_embd",
-        "num_attention_heads": "n_head",
-        "max_position_embeddings": "n_positions",
-        "num_hidden_layers": "n_layer",
-    }
 
     def __init__(
         self,
-        vocab_size=151851,
-        n_embd=4096,
-        n_layer=32,
-        n_head=32,
-        n_inner=None,
-        embd_pdrop=0.0,
-        attn_pdrop=0.0,
-        layer_norm_epsilon=1e-5,
+        vocab_size=151936,
+        hidden_size=4096,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        emb_dropout_prob=0.0,
+        attn_dropout_prob=0.0,
+        layer_norm_epsilon=1e-6,
         initializer_range=0.02,
+        max_position_embeddings=8192,
         scale_attn_weights=True,
         use_cache=True,
-        eos_token_id=151643,
-        apply_residual_connection_post_layernorm=False,
         bf16=False,
         fp16=False,
         fp32=False,
         kv_channels=128,
         rotary_pct=1.0,
         rotary_emb_base=10000,
-        use_dynamic_ntk=False,
-        use_logn_attn=False,
-        use_flash_attn=True,
-        ffn_hidden_size=22016,
+        use_dynamic_ntk=True,
+        use_logn_attn=True,
+        use_flash_attn="auto",
+        intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
+        use_cache_quantization=False,
+        use_cache_kernel=False,
+        softmax_in_fp32=False,
         **kwargs,
     ):
-        self.eos_token_id = eos_token_id
-        super().__init__(
-            eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
-        )
-
         self.vocab_size = vocab_size
-        self.n_embd = n_embd
-        self.n_layer = n_layer
-        self.n_head = n_head
-        self.n_inner = n_inner
-        self.embd_pdrop = embd_pdrop
-        self.attn_pdrop = attn_pdrop
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.emb_dropout_prob = emb_dropout_prob
+        self.attn_dropout_prob = attn_dropout_prob
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
-        self.apply_residual_connection_post_layernorm = (
-            apply_residual_connection_post_layernorm
-        )
+        self.max_position_embeddings = max_position_embeddings
         self.bf16 = bf16
         self.fp16 = fp16
         self.fp32 = fp32
@@ -73,6 +61,11 @@ def __init__(
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
-        self.ffn_hidden_size = ffn_hidden_size
         self.no_bias = no_bias
-        self.tie_word_embeddings = tie_word_embeddings
+        self.use_cache_quantization = use_cache_quantization
+        self.use_cache_kernel = use_cache_kernel
+        self.softmax_in_fp32 = softmax_in_fp32
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs
+        )
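
As a quick orientation to the renamed fields, a hypothetical construction of the updated config could look like the sketch below; the import path depends on how `mftcoder_accelerate/src` sits on your `PYTHONPATH`, and the values simply echo the defaults above:

```python
from model.qwen.configuration_qwen import QWenConfig  # path is an assumption

config = QWenConfig(
    vocab_size=151936,
    hidden_size=4096,              # formerly n_embd
    num_hidden_layers=32,          # formerly n_layer
    num_attention_heads=32,        # formerly n_head
    intermediate_size=22016,       # formerly ffn_hidden_size
    max_position_embeddings=8192,
    use_flash_attn="auto",
    use_cache_quantization=False,  # new KV-cache quantization switch
)
print(config.hidden_size, config.intermediate_size)
```
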
diff --git a/mftcoder_accelerate/src/model/qwen/cpp_kernels.py b/mftcoder_accelerate/src/model/qwen/cpp_kernels.py
new file mode 100644
index 0000000..d9cee70
--- /dev/null
+++ b/mftcoder_accelerate/src/model/qwen/cpp_kernels.py
@@ -0,0 +1,55 @@
+from torch.utils import cpp_extension
+import pathlib
+import os
+import subprocess
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
+
+# Check if cuda 11 is installed for compute capability 8.0
+cc_flag = []
+_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+if int(bare_metal_major) >= 11:
+    cc_flag.append('-gencode')
+    cc_flag.append('arch=compute_80,code=sm_80')
+    if int(bare_metal_minor) >= 7:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_90,code=sm_90')
+
+# Build path
+srcpath = pathlib.Path(__file__).parent.absolute()
+buildpath = srcpath / 'build'
+_create_build_dir(buildpath)
+
+def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+    return cpp_extension.load(
+        name=name,
+        sources=sources,
+        build_directory=buildpath,
+        extra_cflags=['-O3', ],
+        extra_cuda_cflags=['-O3',
+                           '-gencode', 'arch=compute_70,code=sm_70',
+                           '--use_fast_math'] + extra_cuda_flags + cc_flag,
+        verbose=1
+    )
+
+extra_flags = []
+
+cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp",
+           "./cache_autogptq_cuda_kernel_256.cu"]
+cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags)
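
Note that importing this module triggers a JIT `nvcc` build via `torch.utils.cpp_extension.load`, which is why the `QWenAttention` changes further below wrap the import in a try/except. A minimal sketch of that guard, with an assumed import path:

```python
# Sketch only: fail soft if nvcc or the CUDA headers are unavailable at import time.
try:
    from model.qwen.cpp_kernels import cache_autogptq_cuda_256  # assumed path
except Exception as exc:  # build failures can surface as varied exception types
    cache_autogptq_cuda_256 = None
    print(f"KV-cache CUDA kernels unavailable, falling back to torch ops: {exc}")
```
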
diff --git a/mft_peft_hf/src/model/qwen/modeling_qwen.py b/mftcoder_accelerate/src/model/qwen/modeling_qwen.py
similarity index 62%
rename from mft_peft_hf/src/model/qwen/modeling_qwen.py
rename to mftcoder_accelerate/src/model/qwen/modeling_qwen.py
index 238fe91..45c0d16 100644
--- a/mft_peft_hf/src/model/qwen/modeling_qwen.py
+++ b/mftcoder_accelerate/src/model/qwen/modeling_qwen.py
@@ -3,14 +3,16 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import importlib
 import math
+import pathlib
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from torch.cuda.amp import autocast
+import warnings
 
 from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
@@ -35,6 +37,8 @@
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
+SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
+
 
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
@@ -66,13 +70,18 @@
 向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
 """
 
+_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\
+We detected that flash attention support is activated, but the model computation is running on CPU. Please make sure that your input data has been placed on the GPU. If you actually want to run CPU computation, please follow the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained).
+检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。
+"""
+
 apply_rotary_emb_func = None
 rms_norm = None
 flash_attn_unpadded_func = None
-
+flash_attn_func = None
 
 def _import_flash_attn():
-    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
+    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func
     try:
         from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
         apply_rotary_emb_func = __apply_rotary_emb_func
@@ -93,20 +102,49 @@ def _import_flash_attn():
 
     try:
         import flash_attn
+        _flash_attn_func = None
         if not hasattr(flash_attn, '__version__'):
             from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
         else:
             if int(flash_attn.__version__.split(".")[0]) >= 2:
+                if int(flash_attn.__version__.split(".")[1]) >= 1:
+                    from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
                 from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
             else:
                 from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
         flash_attn_unpadded_func = __flash_attn_unpadded_func
+        flash_attn_func = _flash_attn_func
     except ImportError:
         logger.warn(
             "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
             "https://github.com/Dao-AILab/flash-attention"
         )
 
+def quantize_cache_v(fdata, bits, qmax, qmin):
+    # b, s, head, h-dim->b, head, s, h-dim
+    qtype = torch.uint8
+    device = fdata.device
+    shape = fdata.shape
+
+    fdata_cal = torch.flatten(fdata, 2)
+    fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
+    fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
+    # Compute params
+    if qmax.device != fmax.device:
+        qmax = qmax.to(device)
+        qmin = qmin.to(device)
+    scale = (fmax - fmin) / (qmax - qmin)
+    zero = qmin - fmin / scale
+    scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+    zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+    # Quantize
+    res_data = fdata / scale + zero
+    qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
+    return qdata.contiguous(), scale, zero
+
+def dequantize_cache_torch(qdata, scale, zero):
+    data = scale * (qdata - zero)
+    return data
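
A quick round-trip check of these two helpers (a sketch; it assumes `modeling_qwen` is importable from your path and uses a small random tensor already in the `[batch, heads, seq, head_dim]` layout expected after the permute in `QWenAttention.forward`):

```python
import torch
from model.qwen.modeling_qwen import quantize_cache_v, dequantize_cache_torch  # assumed path

kv = torch.randn(1, 32, 16, 128)  # float32, matching the default cache_dtype
qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=torch.float32)
qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=torch.float32)

qdata, scale, zero = quantize_cache_v(kv, bits=8, qmax=qmax, qmin=qmin)
recovered = dequantize_cache_torch(qdata, scale, zero)
print(qdata.dtype, (recovered - kv).abs().max())  # torch.uint8, small quantization error
```
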
 
 class FlashSelfAttention(torch.nn.Module):
     def __init__(
@@ -126,11 +164,33 @@ def __init__(
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, q, k, v):
+    def unpad_input(self, hidden_states, attention_mask):
+        valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
+        seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
+        indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+        hidden_states = hidden_states[indices]
+        return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
+
+    def pad_input(self, hidden_states, indices, batch, seqlen):
+        output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
+                             dtype=hidden_states.dtype)
+        output[indices] = hidden_states
+        return rearrange(output, '(b s) ... -> b s ...', b=batch)
+
+    def forward(self, q, k, v, attention_mask=None):
         assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
         assert all((i.is_cuda for i in (q, k, v)))
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = k.shape[1]
+        seqlen_out = seqlen_q
+
+        if flash_attn_func is not None and batch_size == 1:
+            dropout_p = self.dropout_p if self.training else 0
+            output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal)
+            return output
+
         q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
         cu_seqlens_q = torch.arange(
             0,
@@ -140,13 +200,14 @@ def forward(self, q, k, v):
             device=q.device,
         )
 
-        if self.training:
-            assert seqlen_k == seqlen_q
-
-            is_causal = self.causal
-            cu_seqlens_k = cu_seqlens_q
+        if batch_size > 1 and attention_mask is not None:
+            k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
+            if q.size(0) == v.size(0):
+                q = q[indices_k]
+                cu_seqlens_q = cu_seqlens_k
+                seqlen_q = seqlen_k
+            v = v[indices_k]
         else:
-            is_causal = seqlen_q == seqlen_k
             cu_seqlens_k = torch.arange(
                 0,
                 (batch_size + 1) * seqlen_k,
@@ -154,7 +215,15 @@ def forward(self, q, k, v):
                 dtype=torch.int32,
                 device=q.device,
             )
-            self.dropout_p = 0
+
+        if self.training:
+            assert seqlen_k == seqlen_q
+            is_causal = self.causal
+            dropout_p = self.dropout_p
+        else:
+            is_causal = seqlen_q == seqlen_k
+            dropout_p = 0
+
         output = flash_attn_unpadded_func(
             q,
             k,
@@ -163,30 +232,23 @@ def forward(self, q, k, v):
             cu_seqlens_k,
             seqlen_q,
             seqlen_k,
-            self.dropout_p,
+            dropout_p,
             softmax_scale=self.softmax_scale,
             causal=is_causal,
         )
-
-        output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
+        if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
+            output = self.pad_input(output, indices_k, batch_size, seqlen_out)
+        else:
+            new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
+            output = output.view(new_shape)
         return output
 
 
 class QWenAttention(nn.Module):
-    def __init__(self, config, layer_number=None):
+    def __init__(self, config):
         super().__init__()
 
-        max_positions = config.max_position_embeddings
-        self.register_buffer(
-            "bias",
-            torch.tril(
-                torch.ones((max_positions, max_positions), dtype=torch.bool)
-            ).view(1, 1, max_positions, max_positions),
-            persistent=False,
-        )
         self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
-        self.layer_number = max(1, layer_number)
-        self.params_dtype = config.params_dtype
         self.seq_length = config.seq_length
 
         self.hidden_size = config.hidden_size
@@ -197,8 +259,6 @@ def __init__(self, config, layer_number=None):
         self.use_flash_attn = config.use_flash_attn
         self.scale_attn_weights = True
 
-        self.layer_idx = None
-
         self.projection_size = config.kv_channels * config.num_attention_heads
 
         assert self.projection_size % config.num_attention_heads == 0
@@ -219,25 +279,10 @@ def __init__(self, config, layer_number=None):
             and not self.is_fp32
         ):
             self.core_attention_flash = FlashSelfAttention(
-                causal=True, attention_dropout=config.attn_pdrop
+                causal=True, attention_dropout=config.attn_dropout_prob
             )
-
         self.bf16 = config.bf16
 
-        if config.rotary_pct == 1.0:
-            self.rotary_ndims = None
-        else:
-            assert config.rotary_pct < 1
-            self.rotary_ndims = int(
-                self.hidden_size_per_attention_head * config.rotary_pct
-            )
-        dim = (
-            self.rotary_ndims
-            if self.rotary_ndims is not None
-            else self.hidden_size_per_attention_head
-        )
-        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
-
         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn
 
@@ -245,104 +290,104 @@ def __init__(self, config, layer_number=None):
             math.log(i, self.seq_length) if i > self.seq_length else 1
             for i in range(1, 32768)
         ]
-        self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
-        self._ntk_cached = 1.0
-
-        self.attn_dropout = nn.Dropout(config.attn_pdrop)
-
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        logn_tensor = torch.tensor(logn_list)[None, :, None, None]
+        self.register_buffer("logn_tensor", logn_tensor, persistent=False)
+
+        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
+        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
+        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
+        self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
+        cache_dtype = torch.float
+        if self.bf16:
+            cache_dtype=torch.bfloat16
+        elif config.fp16:
+            cache_dtype = torch.float16
+        self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
+        self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
+
+        if config.use_cache_quantization and config.use_cache_kernel:
+            # pre-check that the support files exist
+            module_root = pathlib.Path(__file__).parent
+            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
+            if any(not (module_root/src).is_file() for src in src_files):
+                warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
+                self.cache_kernels = None
+            else:
+                try:
+                    from .cpp_kernels import cache_autogptq_cuda_256
+                    self.cache_kernels = cache_autogptq_cuda_256
+                except ImportError:
+                    warnings.warn("Failed to import KV cache kernels.")
+                    self.cache_kernels = None
+
+    def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
+        device = query.device
+        if self.use_cache_quantization:
+            qk, qk_scale, qk_zero = key
+            if self.use_cache_kernel and self.cache_kernels is not None:
+                shape = query.shape[:-1] + (qk.shape[-2],)
+                attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
+                self.cache_kernels.vecquant8matmul_batched_faster_old(
+                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
+                    qk.transpose(-1, -2).contiguous(),
+                    attn_weights,
+                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
+                    qk_zero.contiguous() if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
+                # attn_weights = attn_weights.to(query.dtype).contiguous()
+            else:
+                key = dequantize_cache_torch(qk, qk_scale, qk_zero)
+                attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        else:
+            attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [],
-                value.size(-1) ** 0.5,
-                dtype=attn_weights.dtype,
-                device=attn_weights.device,
-            )
+            if self.use_cache_quantization:
+                size_temp = value[0].size(-1)
+            else:
+                size_temp = value.size(-1)
+            attn_weights = attn_weights / (size_temp ** 0.5)
 
-        query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[
-            :, :, key_length - query_length : key_length, :key_length
-        ]
         mask_value = torch.finfo(attn_weights.dtype).min
-        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
-            attn_weights.device
-        )
-        attn_weights = torch.where(
-            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
-        )
-
-        if attention_mask is not None:
-            # Apply the attention mask
-            attn_weights = attn_weights + attention_mask
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-        attn_output = attn_output.transpose(1, 2)
-
-        return attn_output, attn_weights
-
-    def _upcast_and_reordered_attn(
-        self, query, key, value, attention_mask=None, head_mask=None
-    ):
-        bsz, num_heads, q_seq_len, dk = query.size()
-        _, _, k_seq_len, _ = key.size()
-
-        attn_weights = torch.empty(
-            bsz * num_heads,
-            q_seq_len,
-            k_seq_len,
-            dtype=torch.float32,
-            device=query.device,
-        )
-
-        scale_factor = 1.0
-        if self.scale_attn_weights:
-            scale_factor /= float(value.size(-1)) ** 0.5
-
-        with autocast(enabled=False):
-            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
-                -1, dk, k_seq_len
-            )
-            attn_weights = torch.baddbmm(
-                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
+        if causal_mask is not None:
+            attn_weights = torch.where(
+                causal_mask, attn_weights.to(attn_weights.dtype), mask_value
             )
-            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-
-        query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[
-            :, :, key_length - query_length : key_length, :key_length
-        ]
-        mask_value = torch.finfo(attn_weights.dtype).min
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
-            attn_weights.device
-        )
-        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
 
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
 
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        if self.softmax_in_fp32:
+            attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
+        else:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
-        if attn_weights.dtype != torch.float32:
-            raise RuntimeError(
-                "Error with upcasting, attn_weights does not have dtype torch.float32"
-            )
-        attn_weights = attn_weights.type(value.dtype)
+        attn_weights = attn_weights.type(query.dtype)
         attn_weights = self.attn_dropout(attn_weights)
 
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
-        attn_output = torch.matmul(attn_weights, value)
+        if self.use_cache_quantization:
+            qv, qv_scale, qv_zero = value
+            if self.use_cache_kernel and self.cache_kernels is not None:
+                shape = attn_weights.shape[:-1] + (query.shape[-1],)
+                attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
+                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
+                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
+                    qv.contiguous(),  # dtype: int32
+                    attn_output,
+                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
+                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
+                if attn_output.dtype != query.dtype:
+                    attn_output = attn_output.to(query.dtype)
+                    attn_weights = attn_weights.to(query.dtype)
+            else:
+                value = dequantize_cache_torch(qv, qv_scale, qv_zero)
+                attn_output = torch.matmul(attn_weights, value)
+        else:
+            attn_output = torch.matmul(attn_weights, value)
+
+        attn_output = attn_output.transpose(1, 2)
 
         return attn_output, attn_weights
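
The reworked `_attn` interleaves the quantized and unquantized paths; as a readability aid, here is a plain-PyTorch sketch of the unquantized branch (no head mask, no dropout). It is a reference only, not the module's method:

```python
import math
import torch
import torch.nn.functional as F

def eager_attention_reference(query, key, value, causal_mask=None, attention_mask=None):
    """query/key/value: [batch, heads, seq, head_dim];
    causal_mask: bool [1, 1, q_len, k_len], True where attention is allowed;
    attention_mask: additive float mask broadcastable to the score shape."""
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(value.size(-1))
    if causal_mask is not None:
        scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
    if attention_mask is not None:
        scores = scores + attention_mask
    probs = F.softmax(scores, dim=-1).type(query.dtype)
    # [batch, seq, heads, head_dim], matching the transpose(1, 2) above
    return torch.matmul(probs, value).transpose(1, 2)
```
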
 
@@ -359,6 +404,7 @@ def _merge_heads(self, tensor, num_heads, attn_head_size):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -367,64 +413,80 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
     ):
-
         mixed_x_layer = self.c_attn(hidden_states)
+
         query, key, value = mixed_x_layer.split(self.split_size, dim=2)
 
         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)
 
-        kv_seq_len = hidden_states.size()[1]
-        if layer_past:
-            # layer past[0] shape: bs * seq_len * head_num * dim
-            kv_seq_len += layer_past[0].shape[1]
-        if (
-            self.use_dynamic_ntk
-            and kv_seq_len == hidden_states.size()[1]
-            and not self.training
-        ):
-            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
-            ntk_alpha = 2 ** math.ceil(context_value) - 1
-            ntk_alpha = max(ntk_alpha, 1)
-            self._ntk_cached = ntk_alpha
-        else:
-            ntk_alpha = self._ntk_cached
-        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
-            hidden_states.device
-        )
-
-        if rotary_pos_emb is not None:
-            if isinstance(rotary_pos_emb, tuple):
-                rotary_pos_emb = rotary_pos_emb
-            else:
+        if rotary_pos_emb_list is not None:
+            cur_len = query.shape[1]
+            if len(rotary_pos_emb_list) == 1:
+                rotary_pos_emb = rotary_pos_emb_list[0]
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
                 rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query = apply_rotary_pos_emb(query, q_pos_emb)
+                key = apply_rotary_pos_emb(key, k_pos_emb)
+            else:
+                query_list = []
+                key_list = []
+                for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
+                    rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                    rotary_pos_emb = (rotary_pos_emb,) * 2
+                    q_pos_emb, k_pos_emb = rotary_pos_emb
+                    # Slice the pos emb for current inference
+                    query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
+                    key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+                query = torch.cat(query_list, dim=0)
+                key = torch.cat(key_list, dim=0)
+
+        if self.use_cache_quantization:
+            key = quantize_cache_v(key.permute(0, 2, 1, 3),
+                                       bits=8,
+                                       qmin=self.cache_qmin,
+                                       qmax=self.cache_qmax)
+            value = quantize_cache_v(value.permute(0, 2, 1, 3),
+                                         bits=8,
+                                         qmin=self.cache_qmin,
+                                         qmax=self.cache_qmax)
 
-        if rotary_pos_emb is not None:
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            cur_len = query.shape[1]
-            q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
-            k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
 
         if layer_past is not None:
             past_key, past_value = layer_past[0], layer_past[1]
-            key = torch.cat((past_key, key), dim=1)
-            value = torch.cat((past_value, value), dim=1)
+            if self.use_cache_quantization:
+                # use_cache_quantization:
+                # present=((q_key,key_scale,key_zero_point),
+                #          (q_value,value_scale,value_zero_point))
+                key = (torch.cat((past_key[0], key[0]), dim=2),
+                       torch.cat((past_key[1], key[1]), dim=2),
+                       torch.cat((past_key[2], key[2]), dim=2))
+                value = (torch.cat((past_value[0], value[0]), dim=2),
+                         torch.cat((past_value[1], value[1]), dim=2),
+                         torch.cat((past_value[2], value[2]), dim=2))
+            else:
+                # not use_cache_quantization:
+                # present=(key,value)
+                key = torch.cat((past_key, key), dim=1)
+                value = torch.cat((past_value, value), dim=1)
 
         if use_cache:
             present = (key, value)
         else:
             present = None
 
-        if self.use_logn_attn and not self.training:
-            if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
-                self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
-            seq_start = key.size(1) - query.size(1)
-            seq_end = key.size(1)
-            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
+        key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+        if key_size > self.seq_length and self.use_logn_attn and not self.training:
+            if self.use_cache_quantization:
+                seq_start = key[0].size(2) - query.size(1)
+                seq_end = key[0].size(2)
+            else:
+                seq_start = key.size(1) - query.size(1)
+                seq_end = key.size(1)
+            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
             query = query * logn_tensor.expand_as(query)
 
         if (
@@ -434,23 +496,49 @@ def forward(
             and query.is_cuda
         ):
             q, k, v = query, key, value
-            context_layer = self.core_attention_flash(q, k, v)
-
-            context_layer = rearrange(
-                context_layer, "b s h d -> b s (h d)"
-            ).contiguous()
+            attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
         else:
+            key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+            if query.size(1) == key_size:
+                causal_mask = torch.tril(
+                    torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+                ).view(1, 1, key_size, key_size)
+            else:
+                causal_mask = None
             query = query.permute(0, 2, 1, 3)
-            key = key.permute(0, 2, 1, 3)
-            value = value.permute(0, 2, 1, 3)
-            attn_output, attn_weight = self._attn(
-                query, key, value, attention_mask, head_mask
-            )
-            context_layer = self._merge_heads(
-                attn_output, self.num_heads, self.head_dim
-            )
+            if not self.use_cache_quantization:
+                key = key.permute(0, 2, 1, 3)
+                value = value.permute(0, 2, 1, 3)
+            if (
+                causal_mask is None
+                and self.use_flash_attn
+                and flash_attn_unpadded_func is not None
+                and not self.is_fp32
+                and not query.is_cuda
+            ):
+                raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
+
+            if not self.use_cache_quantization and SUPPORT_TORCH2:
+                if attention_mask is not None:
+                    attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
+                    if causal_mask is not None:
+                        attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
+                else:
+                    attention_mask = causal_mask
+                attn_output = F.scaled_dot_product_attention(
+                    query, key, value, attn_mask=attention_mask
+                ).transpose(1, 2)
+                attn_weight = None
+            else:
+                attn_output, attn_weight = self._attn(
+                    query, key, value, causal_mask, attention_mask, head_mask
+                )
+        context_layer = self._merge_heads(
+            attn_output, self.num_heads, self.head_dim
+        )
 
         attn_output = self.c_proj(context_layer)
+
         outputs = (attn_output, present)
         if output_attentions:
             if (
@@ -459,6 +547,8 @@ def forward(
                 and not self.is_fp32
             ):
                 raise ValueError("Cannot output attentions while using flash-attn")
+            elif not self.use_cache_quantization and SUPPORT_TORCH2:
+                raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
             else:
                 outputs += (attn_weight,)
 
@@ -469,12 +559,12 @@ class QWenMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.w1 = nn.Linear(
-            config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
+            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
         )
         self.w2 = nn.Linear(
-            config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
+            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
         )
-        ff_dim_in = config.ffn_hidden_size // 2
+        ff_dim_in = config.intermediate_size // 2
         self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
 
     def forward(self, hidden_states):
@@ -486,24 +576,16 @@ def forward(self, hidden_states):
 
 
 class QWenBlock(nn.Module):
-    def __init__(self, config, layer_idx=None, num_expert=1):
+    def __init__(self, config):
         super().__init__()
-        self.num_expert = num_expert
-        self.layer_number = layer_idx
-        self.apply_residual_connection_post_layernorm = (
-            config.apply_residual_connection_post_layernorm
-        )
         hidden_size = config.hidden_size
-        self.apply_residual_connection_post_layernorm = (
-            config.apply_residual_connection_post_layernorm
-        )
         self.bf16 = config.bf16
 
         self.ln_1 = RMSNorm(
             hidden_size,
             eps=config.layer_norm_epsilon,
         )
-        self.attn = QWenAttention(config, layer_number=layer_idx)
+        self.attn = QWenAttention(config)
         self.ln_2 = RMSNorm(
             hidden_size,
             eps=config.layer_norm_epsilon,
@@ -514,6 +596,7 @@ def __init__(self, config, layer_idx=None, num_expert=1):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -526,6 +609,7 @@ def forward(
 
         attn_outputs = self.attn(
             layernorm_output,
+            rotary_pos_emb_list,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -536,19 +620,12 @@ def forward(
 
         outputs = attn_outputs[1:]
 
-        if self.apply_residual_connection_post_layernorm:
-            residual = layernorm_output
-        else:
-            residual = hidden_states
+        residual = hidden_states
         layernorm_input = attn_output + residual
 
         layernorm_output = self.ln_2(layernorm_input)
 
-        if self.apply_residual_connection_post_layernorm:
-            residual = layernorm_output
-        else:
-            residual = layernorm_input
-
+        residual = layernorm_input
         mlp_output = self.mlp(layernorm_output)
         hidden_states = residual + mlp_output
 
@@ -566,6 +643,7 @@ class QWenPreTrainedModel(PreTrainedModel):
     is_parallelizable = False
     supports_gradient_checkpointing = True
     _no_split_modules = ["QWenBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -589,7 +667,7 @@ def _init_weights(self, module):
                     mean=0.0,
                     std=(
                         self.config.initializer_range
-                        / math.sqrt(2 * self.config.n_layer)
+                        / math.sqrt(2 * self.config.num_hidden_layers)
                     ),
                 )
 
@@ -603,31 +681,40 @@ class QWenModel(QWenPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.vocab_size = config.padded_vocab_size
+        self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
+        self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False
 
-        max_sequence_length = config.max_position_embeddings
-        self.position_embedding_type = config.pos_emb
         self.gradient_checkpointing = False
+        self.use_dynamic_ntk = config.use_dynamic_ntk
+        self.seq_length = config.seq_length
+
+        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
 
-        if self.position_embedding_type == "learned":
-            self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)
-            self.init_method(self.position_embeddings.weight)
-            self._position_embeddings_key = "position_embeddings"
-            self.init_method(self.position_embeddings.weight)
+        self.drop = nn.Dropout(config.emb_dropout_prob)
+
+        if config.rotary_pct == 1.0:
+            self.rotary_ndims = None
         else:
-            self.wpe = None
-            self._position_embeddings_key = ""
+            assert config.rotary_pct < 1
+            self.rotary_ndims = int(
+                config.kv_channels * config.rotary_pct
+            )
+        dim = (
+            self.rotary_ndims
+            if self.rotary_ndims is not None
+            else config.kv_channels
+        )
+        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
 
-        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
+        self.use_flash_attn = config.use_flash_attn
+        self.is_fp32 = not (config.bf16 or config.fp16)
 
-        self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList(
             [
                 QWenBlock(
-                    config,
-                    layer_idx=i,
+                    config
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -645,6 +732,12 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, new_embeddings):
         self.wte = new_embeddings
 
+    def get_ntk_alpha(self, true_seq_len):
+        context_value = math.log(true_seq_len / self.seq_length, 2) + 1
+        ntk_alpha = 2 ** math.ceil(context_value) - 1
+        ntk_alpha = max(ntk_alpha, 1)
+        return ntk_alpha
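
For intuition, the dynamic-NTK factor returned by `get_ntk_alpha` stays at 1 inside the trained window and grows beyond it; a standalone sketch follows (the `seq_length=8192` is illustrative only, the model reads it from `config.seq_length`):

```python
import math

def get_ntk_alpha(true_seq_len, seq_length=8192):
    # Mirrors QWenModel.get_ntk_alpha above.
    context_value = math.log(true_seq_len / seq_length, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1)

print(get_ntk_alpha(4096))   # 1  (within the trained window)
print(get_ntk_alpha(16384))  # 3
print(get_ntk_alpha(32768))  # 7
```
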
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -701,8 +794,10 @@ def forward(
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
-            past_length = past_key_values[0][0].size(-2)
-
+            if self.use_cache_quantization:
+                past_length = past_key_values[0][0][0].size(2)
+            else:
+                past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(
                 past_length,
@@ -719,17 +814,41 @@ def forward(
             attention_mask = attention_mask[:, None, None, :]
             attention_mask = attention_mask.to(dtype=self.dtype)
             attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-            # attention_mask中mask掉的部分是-inf, 看到的部分是0
 
         encoder_attention_mask = None
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         hidden_states = inputs_embeds
-        if self.wpe is not None:
-            position_embeds = self.wpe(position_ids)
-            hidden_states = hidden_states + position_embeds
+
+        kv_seq_len = hidden_states.size()[1]
+        if past_key_values[0] is not None:
+            # past key values[0][0] shape: bs * seq_len * head_num * dim
+            if self.use_cache_quantization:
+                kv_seq_len += past_key_values[0][0][0].shape[2]
+            else:
+                kv_seq_len += past_key_values[0][0].shape[1]
+
+        if self.training or not self.use_dynamic_ntk:
+            ntk_alpha_list = [1.0]
+        elif kv_seq_len != hidden_states.size()[1]:
+            ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
+        else:
+            ntk_alpha_list = []
+            if attention_mask is not None and kv_seq_len > self.seq_length:
+                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
+                for i in range(hidden_states.size()[0]):
+                    true_seq_len = true_seq_lens[i].item()
+                    ntk_alpha = self.get_ntk_alpha(true_seq_len)
+                    ntk_alpha_list.append(ntk_alpha)
+            else:
+                ntk_alpha = self.get_ntk_alpha(kv_seq_len)
+                ntk_alpha_list.append(ntk_alpha)
+        self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
+        rotary_pos_emb_list = [
+            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+        ]
 
         hidden_states = self.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
@@ -761,6 +880,7 @@ def custom_forward(*inputs):
                 outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
+                    rotary_pos_emb_list,
                     None,
                     attention_mask,
                     head_mask[i],
@@ -771,6 +891,7 @@ def custom_forward(*inputs):
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
+                    rotary_pos_emb_list=rotary_pos_emb_list,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,
@@ -781,13 +902,16 @@ def custom_forward(*inputs):
 
             hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (outputs[2 if output_attentions else 1],)
+                presents = presents + (outputs[1],)
 
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[1],)
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
             return tuple(
@@ -839,7 +963,7 @@ def __init__(self, config):
                 logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
             elif SUPPORT_FP16:
                 logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
-        
+
         if config.use_flash_attn == "auto":
             if config.bf16 or config.fp16:
                 logger.warn("Try importing flash-attention for faster inference...")
@@ -853,7 +977,7 @@ def __init__(self, config):
             _import_flash_attn()
 
         self.transformer = QWenModel(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         if config.bf16:
             self.transformer.bfloat16()
@@ -872,22 +996,13 @@ def set_output_embeddings(self, new_embeddings):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
     ):
-        token_type_ids = kwargs.get("token_type_ids", None)
         if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
-            if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
 
-        if attention_mask is not None and position_ids is None:
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+        if input_ids.size(0) == 1:
+            attention_mask = None
         else:
-            position_ids = None
+            attention_mask = kwargs.get("attention_mask", None)
 
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
@@ -898,9 +1013,7 @@ def prepare_inputs_for_generation(
             {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
-                "position_ids": position_ids,
                 "attention_mask": attention_mask,
-                "token_type_ids": token_type_ids,
             }
         )
         return model_inputs
@@ -987,35 +1100,45 @@ def chat(
         query: str,
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
-        append_history: bool = True,
         stream: Optional[bool] = _SENTINEL,
         stop_words_ids: Optional[List[List[int]]] = None,
+        generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
+        generation_config = generation_config if generation_config is not None else self.generation_config
+
         assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
-        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
+        else:
+            # make a copy of the user's input so that it is left untouched
+            history = copy.deepcopy(history)
+
         if stop_words_ids is None:
             stop_words_ids = []
 
+        max_window_size = kwargs.get('max_window_size', None)
+        if max_window_size is None:
+            max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
             tokenizer,
             query,
             history=history,
             system=system,
-            max_window_size=6144,
-            chat_format=self.generation_config.chat_format,
+            max_window_size=max_window_size,
+            chat_format=generation_config.chat_format,
         )
 
         stop_words_ids.extend(get_stop_words_ids(
-            self.generation_config.chat_format, tokenizer
+            generation_config.chat_format, tokenizer
         ))
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
                     input_ids,
-                    stop_words_ids = stop_words_ids,
-                    return_dict_in_generate = False,
+                    stop_words_ids=stop_words_ids,
+                    return_dict_in_generate=False,
+                    generation_config=generation_config,
                     **kwargs,
                 )
 
@@ -1024,13 +1147,16 @@ def chat(
             tokenizer,
             raw_text_len=len(raw_text),
             context_length=len(context_tokens),
-            chat_format=self.generation_config.chat_format,
+            chat_format=generation_config.chat_format,
             verbose=False,
             errors='replace'
         )
 
-        if append_history:
-            history.append((query, response))
+        # As history is a copy of the user's input, we can always return the
+        # new turn to the user. Separating input history from output history
+        # also enables the user to implement more complex history management.
+        history.append((query, response))
 
         return response, history
 
@@ -1042,30 +1168,35 @@ def chat_stream(
             system: str = "You are a helpful assistant.",
             stop_words_ids: Optional[List[List[int]]] = None,
             logits_processor: Optional[LogitsProcessorList] = None,
+            generation_config: Optional[GenerationConfig] = None,
             **kwargs,
     ) -> Generator[str, Any, None]:
-        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        generation_config = generation_config if generation_config is not None else self.generation_config
+        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
             stop_words_ids = []
 
+        max_window_size = kwargs.get('max_window_size', None)
+        if max_window_size is None:
+            max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
             tokenizer,
             query,
             history=history,
             system=system,
-            max_window_size=6144,
-            chat_format=self.generation_config.chat_format,
+            max_window_size=max_window_size,
+            chat_format=generation_config.chat_format,
         )
 
         stop_words_ids.extend(get_stop_words_ids(
-            self.generation_config.chat_format, tokenizer
+            generation_config.chat_format, tokenizer
         ))
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
                 stop_words_ids=stop_words_ids,
-                eos_token_id=self.generation_config.eos_token_id,
+                eos_token_id=generation_config.eos_token_id,
             )
             if logits_processor is None:
                 logits_processor = LogitsProcessorList([stop_words_logits_processor])
@@ -1076,7 +1207,8 @@ def chat_stream(
         from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
         self.__class__.generate_stream = NewGenerationMixin.generate
         self.__class__.sample_stream = NewGenerationMixin.sample_stream
-        stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
+        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
+
         def stream_generator():
             outputs = []
             for token in self.generate_stream(
@@ -1105,17 +1237,19 @@ def generate(
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
+        generation_config = generation_config if generation_config is not None else self.generation_config
+
         # Process stop_words_ids.
         stop_words_ids = kwargs.pop("stop_words_ids", None)
         if stop_words_ids is None and generation_config is not None:
             stop_words_ids = getattr(generation_config, "stop_words_ids", None)
         if stop_words_ids is None:
-            stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
+            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
 
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
                 stop_words_ids=stop_words_ids,
-                eos_token_id=self.generation_config.eos_token_id,
+                eos_token_id=generation_config.eos_token_id,
             )
             if logits_processor is None:
                 logits_processor = LogitsProcessorList([stop_words_logits_processor])
@@ -1140,16 +1274,17 @@ def __init__(self, dim, base=10000):
         super().__init__()
         self.dim = dim
         self.base = base
-        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         if importlib.util.find_spec("einops") is None:
             raise RuntimeError("einops is required for Rotary Embedding")
 
         self._rotary_pos_emb_cache = None
         self._seq_len_cached = 0
         self._ntk_alpha_cached = 1.0
+        self._ntk_alpha_cached_list = [1.0]
 
-    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
-        seqlen = max_seq_len + offset
+    def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
         if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
             base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
             self.inv_freq = 1.0 / (
@@ -1163,14 +1298,19 @@ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
             self._ntk_alpha_cached = ntk_alpha
             seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
             freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
+
             emb = torch.cat((freqs, freqs), dim=-1)
             from einops import rearrange
 
-            self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d")
+            emb = rearrange(emb, "n d -> 1 n 1 d")
 
-    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
-        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
-        return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
+            cos, sin = emb.cos(), emb.sin()
+            self._rotary_pos_emb_cache = [cos, sin]
+
+    def forward(self, max_seq_len, ntk_alpha=1.0):
+        self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha)
+        cos, sin = self._rotary_pos_emb_cache
+        return [cos[:, :max_seq_len], sin[:, :max_seq_len]]
 
 
 def _rotate_half(x):
@@ -1182,20 +1322,28 @@ def _rotate_half(x):
 
 
 def apply_rotary_pos_emb(t, freqs):
-    if apply_rotary_emb_func is not None:
-        t_ = t.float()
-        freqs = freqs.squeeze(0).squeeze(1)
-        cos = freqs[:, : freqs.shape[-1] // 2].cos()
-        sin = freqs[:, : freqs.shape[-1] // 2].sin()
-        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
-        return output
+    """ Apply rotary embedding to the first rotary_dim of the iput
+
+    Arguments:
+      t (tensor(batch_size, seq_len, n_head, head_dim)):
+        the input embedding/hidden states
+      freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
+        the cached cos/sin position embeddings
+    """
+    rot_dim = freqs[0].shape[-1]
+    cos, sin = freqs
+    t_float = t.float()
+    if apply_rotary_emb_func is not None and t.is_cuda:
+        # apply_rotary_emb in flash_attn requires cos/sin to be of
+        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
+        # to the first rotary_dim of the input
+        cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
+        sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]
+        return apply_rotary_emb_func(t_float, cos, sin).type_as(t)
     else:
-        rot_dim = freqs.shape[-1]
-        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
-        t_ = t_.float()
-        t_pass_ = t_pass_.float()
-        t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
-        return torch.cat((t_, t_pass_), dim=-1).type_as(t)
+        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+        t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
+        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
 
 class RMSNorm(torch.nn.Module):
diff --git a/mft_peft_hf/src/model/qwen/qwen_generation_utils.py b/mftcoder_accelerate/src/model/qwen/qwen_generation_utils.py
similarity index 100%
rename from mft_peft_hf/src/model/qwen/qwen_generation_utils.py
rename to mftcoder_accelerate/src/model/qwen/qwen_generation_utils.py
diff --git a/mft_peft_hf/src/model/qwen/tokenization_qwen.py b/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py
similarity index 76%
rename from mft_peft_hf/src/model/qwen/tokenization_qwen.py
rename to mftcoder_accelerate/src/model/qwen/tokenization_qwen.py
index 4a66a09..2a526d6 100644
--- a/mft_peft_hf/src/model/qwen/tokenization_qwen.py
+++ b/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py
@@ -27,11 +27,22 @@
 # regular texts, the surface forms of special tokens need to be
 # as different as possible to minimize the impact
 EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
-SPECIAL_TOKENS = (
-    ENDOFTEXT,
-    IMSTART,
-    IMEND,
-) + EXTRAS
+# changed to use actual index to avoid misconfiguration with vocabulary expansion
+SPECIAL_START_ID = 151643
+SPECIAL_TOKENS = tuple(
+    enumerate(
+        (
+            (
+                ENDOFTEXT,
+                IMSTART,
+                IMEND,
+            )
+            + EXTRAS
+        ),
+        start=SPECIAL_START_ID,
+    )
+)
+SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
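+# For reference, the enumeration above anchors the ids explicitly, e.g.
+# SPECIAL_TOKENS[0] == (151643, ENDOFTEXT), SPECIAL_TOKENS[1] == (151644, IMSTART),
+# SPECIAL_TOKENS[2] == (151645, IMEND), followed by the <|extra_i|> tokens.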
 
 
 def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
@@ -42,6 +53,7 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
         for token, rank in (line.split() for line in contents.splitlines() if line)
     }
 
+
 class QWenTokenizer(PreTrainedTokenizer):
     """QWen tokenizer."""
 
@@ -51,20 +63,35 @@ def __init__(
         self,
         vocab_file,
         errors="replace",
+        extra_vocab_file=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
-        self.errors = errors  # how to handle errors in decoding
+        # how to handle errors in decoding UTF-8 byte sequences
+        # use "ignore" if you are doing streaming inference
+        self.errors = errors
 
-        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
+        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: Dict[bytes, int]
         self.special_tokens = {
             token: index
-            for index, token in enumerate(
-                SPECIAL_TOKENS, start=len(self.mergeable_ranks)
-            )
+            for index, token in SPECIAL_TOKENS
         }
 
+        # try load extra vocab from file
+        if extra_vocab_file is not None:
+            used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
+            extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
+            for token, index in extra_mergeable_ranks.items():
+                if token in self.mergeable_ranks:
+                    logger.info(f"extra token {token} exists, skipping")
+                    continue
+                if index in used_ids:
+                    logger.info(f'the index {index} for extra token {token} exists, skipping')
+                    continue
+                self.mergeable_ranks[token] = index
+            # the index may be sparse after this, but don't worry, tiktoken.Encoding will handle it
+
         enc = tiktoken.Encoding(
             "Qwen",
             pat_str=PAT_STR,
@@ -86,6 +113,23 @@ def __init__(
         self.im_start_id = self.special_tokens[IMSTART]
         self.im_end_id = self.special_tokens[IMEND]
 
+    def __getstate__(self):
+        # for pickle lovers
+        state = self.__dict__.copy()
+        del state["tokenizer"]
+        return state
+
+    def __setstate__(self, state):
+        # tokenizer is not python native; don't pass it; rebuild it
+        self.__dict__.update(state)
+        enc = tiktoken.Encoding(
+            "Qwen",
+            pat_str=PAT_STR,
+            mergeable_ranks=self.mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        self.tokenizer = enc
+
     def __len__(self) -> int:
         return self.tokenizer.n_vocab
 
@@ -108,13 +152,17 @@ def convert_tokens_to_ids(
                 ids.append(self.mergeable_ranks.get(token))
         return ids
 
-    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
+    def _add_tokens(
+        self,
+        new_tokens: Union[List[str], List[AddedToken]],
+        special_tokens: bool = False,
+    ) -> int:
         if not special_tokens and new_tokens:
-            raise ValueError('Adding regular tokens is not supported')
+            raise ValueError("Adding regular tokens is not supported")
         for token in new_tokens:
             surface_form = token.content if isinstance(token, AddedToken) else token
-            if surface_form not in SPECIAL_TOKENS:
-                raise ValueError('Adding unknown special tokens is not supported')
+            if surface_form not in SPECIAL_TOKENS_SET:
+                raise ValueError("Adding unknown special tokens is not supported")
         return 0
 
     def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
diff --git a/mftcoder_accelerate/src/model/qwen/tokenizer_config.json b/mftcoder_accelerate/src/model/qwen/tokenizer_config.json
new file mode 100644
index 0000000..9c37cac
--- /dev/null
+++ b/mftcoder_accelerate/src/model/qwen/tokenizer_config.json
@@ -0,0 +1,10 @@
+{
+  "model_max_length": 8192,
+  "tokenizer_class": "QWenTokenizer",
+  "auto_map": {
+    "AutoTokenizer": [
+      "tokenization_qwen.QWenTokenizer",
+      null
+      ]
+  }
+}
diff --git a/mftcoder_accelerate/src/mpt/mpt_accelerate.py b/mftcoder_accelerate/src/mpt/mpt_accelerate.py
new file mode 100644
index 0000000..5d187c9
--- /dev/null
+++ b/mftcoder_accelerate/src/mpt/mpt_accelerate.py
@@ -0,0 +1,494 @@
+"""
+# @author Chaoyu Chen
+# @date 2024/6/1
+# @module mpt_accelerate.py
+
+Accelerate + DeepSpeed + Full-parameter + Multi-task + Pre-training/Continue Training/Finetuning
+
+Entry
+"""
+
+import os
+import sys
+import argparse
+import math
+import logging
+import json
+import time
+from tqdm.auto import tqdm
+import transformers
+import numpy as np
+import torch
+from torch import nn
+from dataclasses import dataclass
+from datasets import Dataset
+import datasets
+from torch.utils.data import DataLoader
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    get_linear_schedule_with_warmup,
+    set_seed,
+    BitsAndBytesConfig,
+    get_scheduler,
+)
+
+from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration
+from accelerate.logging import get_logger
+from datetime import timedelta
+from accelerate.utils import InitProcessGroupKwargs
+from transformers.optimization import Adafactor
+
+# insert src as import path
+current_path = os.path.abspath(__file__)
+parent_dir = os.path.dirname(os.path.dirname(current_path))
+sys.path.insert(0, parent_dir)
+
+from tokenizer import build_tokenizer
+from data.multi_task_dataset import load_dataset_from_jsonl, compile_helper
+from data.data_utils import load_dataset_from_bin
+from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK
+from mpt.mpt_trainer import MptTrainer
+from mpt.mpt_arguments import MptTrainArgs
+from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS
+
+
+logger = get_logger(__name__)
+
+
+def get_task_mask(args, task_id):
+    task_num = len(TASK2ID)
+    task_mask = torch.zeros(task_id.shape[0], task_num)
+    task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1
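+    # e.g., with 3 tasks and task_id = [[0], [2]], task_mask becomes
+    # [[1., 0., 0.],
+    #  [0., 0., 1.]]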
+
+    return task_mask
+
+
+def get_attention_mask_and_position_ids(data):
+    """Build masks and position id for left to right model."""
+
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    attention_mask = torch.ones((batch_size, seq_length), device=data.device)
+
+    # Position ids.
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data).clone()
+
+    return attention_mask, position_ids
+
+
+@dataclass
+class DataCollatorForMFTDataset(object):
+    args: None
+
+    def __call__(self, instances):
+        (input_ids, loss_mask, weights, task_id) = tuple(
+            [instance.get(key, None) for instance in instances]
+            for key in ("input_ids", "loss_mask", "weight", "task_id")
+        )
+
+        result_batch = {}
+        """
+        outputs = model(
+                    input_ids=batch['input_ids'],
+                    attention_mask=batch['attention_mask'],
+                    # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']),
+                    # labels=(batch['labels'], batch['loss_mask']),
+                    position_ids=batch['position_ids'])
+        """
+
+        # if loss_mask is not None:
+        loss_mask = torch.tensor(np.array(loss_mask)).long()
+        last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1)
+        if self.args.use_dynamic_padding:
+            # get last non-padding position
+            max_pos = last_one_pos.max().item() + 1
+        else:
+            max_pos = loss_mask.shape[-1]
+
+        if self.args.tokenize_mode == "sst" and self.args.padding_mode == "pack":
+            # compatible with sst + pack tokenization: the last position is dirty data and needs to be dropped
+            result_batch["loss_mask"] = loss_mask.float()[:, 1 : max_pos - 1].contiguous()
+            input_ids = torch.tensor(np.array(input_ids)).long()
+            result_batch["input_ids"] = input_ids[:, : max_pos - 2].contiguous()
+            result_batch["labels"] = input_ids[:, 1 : max_pos - 1].contiguous()
+        else:
+            result_batch["loss_mask"] = loss_mask.float()[:, 1:max_pos].contiguous()
+            input_ids = torch.tensor(np.array(input_ids)).long()
+            # print(f"shape of input_ids: {input_ids.shape}")
+            result_batch["input_ids"] = input_ids[:, : max_pos - 1].contiguous()
+            result_batch["labels"] = input_ids[:, 1:max_pos].contiguous()
+
+        # Get the masks and position ids.
+
+        # if you want to be compatible with non-gpt models, something you can do here
+        if self.args.model_type in ["antglm"]:
+            (result_batch["attention_mask"], result_batch["position_ids"]) = get_attention_mask_and_position_ids(
+                data=result_batch["input_ids"]
+            )
+        elif self.args.model_type in ["mixtral", "mtx-qwen2", "qwen2_moe"]:
+            batch_size, seq_length = result_batch["input_ids"].shape
+            # bsz * seq_length
+            range_tensor = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1)
+            # attention_mask for padding tokens
+            attention_mask = (range_tensor <= last_one_pos.reshape(batch_size, 1)).long()
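+            # e.g., seq_length = 5 with last non-padding position 2 -> attention_mask [1, 1, 1, 0, 0]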
+            result_batch["attention_mask"], result_batch["position_ids"] = attention_mask, None
+        else:
+            # For decoder-only models, transformers will create them.
+            result_batch["attention_mask"], result_batch["position_ids"] = None, None
+
+        if task_id is not None:
+            task_id = torch.tensor(np.array(task_id))
+            result_batch["task_mask"] = get_task_mask(self.args, task_id)  # bsz * task_num
+            result_batch["task_id"] = task_id
+
+        return result_batch
+
+
+def pprint_args(args, accelerator):
+    # compute the max string length over all keys
+    max_key_length = max(len(str(key)) for key in vars(args).keys())
+
+    message = ""
+    message += "====" * 60 + "\n"
+    message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n"
+    message += "====" * 60 + "\n"
+    accelerator.print(message)
+    accelerator.print("GPU: {}".format(torch.cuda.current_device()))
+
+
+def prepare_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train_config", type=str, default=None)
+
+    parser.add_argument("--data_paths", type=str, default=None)
+    parser.add_argument("--output_dir", type=str, default=None)
+    parser.add_argument("--tb_dir", type=str, default=None)
+    parser.add_argument("--pretrained_model_path", type=str, default=None)
+    parser.add_argument("--micro_batch_size", type=int, default=None)
+    parser.add_argument("--model_type", type=str, default=None)
+    parser.add_argument("--distributed_type", type=str, default="deepspeed")
+
+    parsed = parser.parse_args()
+    # get json configs
+    with open(parsed.train_config, "r") as f:
+        train_config = json.load(f)
+
+    # parse args from config.json
+    # args = argparse.Namespace(**train_config)
+    args = MptTrainArgs(**train_config)
+
+    # override args by cli arguments
+    if parsed.data_paths:
+        args.data_paths = parsed.data_paths
+    if parsed.output_dir:
+        args.output_dir = parsed.output_dir
+    if parsed.tb_dir:
+        args.tb_dir = parsed.tb_dir
+    if parsed.pretrained_model_path:
+        args.pretrained_model_path = parsed.pretrained_model_path
+        args.vocab_file = parsed.pretrained_model_path
+    if parsed.micro_batch_size:
+        args.per_device_train_batch_size = parsed.micro_batch_size
+        args.per_device_eval_batch_size = parsed.micro_batch_size
+    if parsed.model_type:
+        args.model_type = parsed.model_type
+
+    args.distributed_type = parsed.distributed_type
+
+    # refactor args
+
+    args.vocab_file = args.pretrained_model_path
+
+    args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(","))) + "]"
+
+    # generate TASK2ID, ID2TASK
+    generate_task_id(args.data_paths)
+
+    if args.weighted_loss_mode == "coba":
+        args.task_weights = [1.0] * len(ID2TASK)
+    elif args.task_weights is not None:
+        args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")]
+        assert len(args.task_weights) == len(ID2TASK), "length of task_weights must equal the number of data_paths"
+    else:
+        args.task_weights = [1.0] * len(ID2TASK)
+
+    return args
+
+
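+# A minimal sketch of how this entry point is typically launched (config file names
+# and paths below are placeholders, not shipped files):
+#   accelerate launch --config_file <accelerate_config>.yaml mpt_accelerate.py \
+#       --train_config <mpt_train_config>.json --distributed_type deepspeed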
+def main():
+    t0 = time.time()
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    os.environ["HF_HUB_OFFLINE"] = "false"
+    # get input args, set TASK2ID, ID2TASK, refactor args
+    args = prepare_args()
+
+    # fix randomness
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # define accelerator
+    init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds))
+    
+    if args.distributed_type and args.distributed_type.lower() == "fsdp":
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+            # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
+            limit_all_gathers=True,
+            sync_module_states=True,
+            use_orig_params=True,
+            cpu_offload=False,
+        )
+        accelerator = Accelerator(
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            fsdp_plugin=fsdp_plugin,
+            dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True),
+            kwargs_handlers=[init_process_kwargs],
+        )
+    else:
+        accelerator = Accelerator(
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True),
+            kwargs_handlers=[init_process_kwargs],
+        )
+
+    # print key infos
+    accelerator.print("In mft_accelerate.py, sys path:", sys.path)
+    accelerator.print(f"transformers.__version__: {transformers.__version__}")
+
+    # get world_size
+    args.world_size = accelerator.num_processes
+
+    # backup args
+    pprint_args(args, accelerator)
+    if accelerator.is_main_process:
+        if not os.path.exists(args.output_dir):
+            os.makedirs(args.output_dir)
+        with open(os.path.join(args.output_dir, "args.json"), "w") as f:
+            json.dump(args.dict(), f, indent=2)
+
+    # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest
+    latest = None
+    if os.path.exists(os.path.join(args.output_dir, "latest")):
+        with open(os.path.join(args.output_dir, "latest"), "r") as fl:
+            latest = json.load(fl)
+        accelerator.print(f"[INFO] Existing latest: {latest}")
+
+    if args.auto_resume and args.resume_from_checkpoint is None and latest:
+        args.resume_from_checkpoint = latest["latest_ckpt"]
+
+    # logger
+    logging.basicConfig(
+        format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+        # compile Cpp helper
+        compile_helper()
+        time.sleep(10)
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # get global_rank and local rank for current process
+    global_rank = accelerator.process_index
+    local_rank = accelerator.local_process_index
+    print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}")
+
+    # TASK2ID, ID2TASK
+    # generate_task_id(args.data_paths)
+
+    # multi task blendable dataset(sharded)
+    if args.load_raw_dataset:
+        print_rank_0("> load raw jsonl dataset")
+        train_dataset, valid_dataset = load_dataset_from_jsonl(
+            args=args, shard_data=True, world_size=args.world_size, global_rank=global_rank, local_rank=local_rank
+        )
+    else:
+        print_rank_0("> load tokenized bin dataset, refer to gpt_neox indexed dataset")
+        train_dataset, valid_dataset, _ = load_dataset_from_bin(args=args)
+
+    t1 = time.time()
+    logger.info(f"dataset loading time: {t1 - t0:.4f}")
+
+    # cuda memory
+    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
+    max_memory = f"{free_in_GB - 2}GB"
+    n_gpus = torch.cuda.device_count()
+    max_memory = {i: max_memory for i in range(n_gpus)}
+    accelerator.print("max memory: ", max_memory, n_gpus)
+
+    # # whether to add new special tokens
+    # num_added_toks = tokenizer.tokenizer.add_special_tokens(["", ""])
+    # accelerator.print("We have added", num_added_toks, "tokens")
+    # accelerator.print(f"role marker tokens {tokenizer.convert_tokens_to_ids('')} {tokenizer.convert_tokens_to_ids('')}, resized tokenizer_size: {len(tokenizer)}")
+
+    # creating model
+    ModelClass = MODEL_TYPES[args.model_type]
+    if args.model_type in SUPPORT_IN_TRANSFORMERS:
+        accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers")
+        model = ModelClass.from_pretrained(
+            args.pretrained_model_path,
+            attn_implementation=args.attn_implementation,
+            torch_dtype=torch.bfloat16,
+        )
+    else:
+        accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code")  
+        model = ModelClass.from_pretrained(
+            args.pretrained_model_path,
+            torch_dtype=torch.bfloat16,
+        )
+
+    # build a tokenizer for possible resizing or saving
+    tokenizer = build_tokenizer(args)
+    # Note: resize_token_embeddings expects to receive the full size of the new vocabulary,
+    # i.e. the length of the tokenizer.
+    # if new special tokens are added, the input and output embeddings need to be resized
+    # model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
+
+    model.gradient_checkpointing_enable()
+
+    if args.saving_limit is None or not isinstance(args.saving_limit, int) or args.saving_limit < 1:
+        # saving_limit is set automatically if needed
+        args.saving_limit = 2
+        accelerator.print(
+            "[WARNING]saving_limit must be a integer greater than 1 in Full-Parameters Training, we set it to 2"
+        )
+
+    t2 = time.time()
+    if accelerator.is_main_process:
+        logging.info(f"model loading time: {t2 - t1:.4f}")
+
+    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
+    if hasattr(model.config, "use_logn_attn"):
+        model.config.use_logn_attn = False  # special for qwen model
+    # load balance for moe training
+    if hasattr(model.config, "output_router_logits"):
+        model.config.output_router_logits = True
+    model_config = model.config
+    accelerator.print(model.config)
+
+    # dataloader
+    train_dataloader = DataLoader(
+        train_dataset,
+        shuffle=True,
+        collate_fn=DataCollatorForMFTDataset(args),
+        batch_size=args.per_device_train_batch_size,
+        pin_memory=True,
+        drop_last=True,
+    )
+    if valid_dataset:
+        valid_dataloader = DataLoader(
+            valid_dataset,
+            collate_fn=DataCollatorForMFTDataset(args),
+            batch_size=args.per_device_eval_batch_size,
+            pin_memory=True,
+            drop_last=True,
+        )
+    else:
+        valid_dataloader = None
+
+    # optimizer
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        accelerator.print("DISTRIBUTED TRAINING USING DEEPSPEED")
+        # from deepspeed.ops.adam import FusedAdam as Adam
+        # adam_optimizer = Adam
+        adam_optimizer = torch.optim.AdamW
+    elif accelerator.distributed_type == DistributedType.FSDP:
+        accelerator.print("DISTRIBUTED TRAINING USING FSDP")
+        model = accelerator.prepare(model)
+        adam_optimizer = torch.optim.AdamW
+    else:
+        raise ValueError("Only support DeepSpeed and FSDP")
+
+    optimizer = adam_optimizer(
+        model.parameters(),
+        weight_decay=args.weight_decay,
+        lr=args.learning_rate,
+        betas=(0.9, 0.999),
+    )
+    # for group in optimizer.param_groups:
+    #     group.setdefault("initial_lr", group["lr"])
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+    if isinstance(args.num_warmup_steps, float) and args.num_warmup_steps < 1.0:
+        args.num_warmup_steps = int(args.max_train_steps * args.num_warmup_steps) // accelerator.num_processes
+    accelerator.print(f"num_warmup_steps: {args.num_warmup_steps}")
+    lr_scheduler = get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        # scheduler_specific_kwargs={"last_epoch": scheduler_last_ep}
+    )
+    # prepare all
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        if valid_dataloader:
+            (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare(
+                model, train_dataloader, valid_dataloader, optimizer, lr_scheduler
+            )
+        else:
+            (model, train_dataloader, optimizer, lr_scheduler) = accelerator.prepare(
+                model, train_dataloader, optimizer, lr_scheduler
+            )
+
+    # prepare all except model, which is prepared before
+    elif accelerator.distributed_type == DistributedType.FSDP:
+        if valid_dataloader:
+            (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare(
+                optimizer, train_dataloader, valid_dataloader, lr_scheduler
+            )
+        else:
+            (optimizer, train_dataloader, lr_scheduler) = accelerator.prepare(
+                optimizer, train_dataloader, lr_scheduler
+            )
+    print(model.device)
+    accelerator.print(model)
+    # accelerator.print(model.config)
+
+    # Recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterward we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # zero 3 flag
+    is_ds_zero_3 = False
+    if getattr(accelerator.state, "deepspeed_plugin", None):
+        is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3
+        accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}")
+    elif getattr(accelerator.state, "fsdp_plugin", None):
+        accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}")
+
+    trainer = MptTrainer(
+        accelerator=accelerator,
+        model=model,
+        model_config=model_config,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        tokenizer=tokenizer,
+        num_update_steps_per_epoch=num_update_steps_per_epoch,
+        total_train_dataset_size=len(train_dataset),
+        args=args,
+    )
+    trainer.accelerate_train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mftcoder_accelerate/src/mpt/mpt_arguments.py b/mftcoder_accelerate/src/mpt/mpt_arguments.py
new file mode 100644
index 0000000..8045421
--- /dev/null
+++ b/mftcoder_accelerate/src/mpt/mpt_arguments.py
@@ -0,0 +1,161 @@
+"""
+# @author Chaoyu Chen
+# @date 2024/6/1
+
+MPT training arguments
+"""
+
+from dataclasses import dataclass, asdict
+from typing import List, Union
+
+
+@dataclass
+class MptTrainArgs:
+    # train data paths on shared FS
+    data_paths: Union[str, List[str]]
+
+    # output dir for saving adaptors in peft or full ckpts in full-parameter training
+    output_dir: str
+
+    # tensorboard dir for saving tensorboard logs
+    tb_dir: str
+
+    # pretrained_model_path, on which is the model you want to train
+    pretrained_model_path: str
+
+    # model type of pretrained_model_path, supporting llama|qwen|starcoder|baichuan|chatglm2, etc.
+    model_type: str
+
+    # load from raw jsonl file or tokenized binary file
+    load_raw_dataset: bool = True
+
+    # weights of loss calculation for each task, None means equal weights
+    task_weights: Union[None, str] = None
+
+    # weights of data sampling, leave it None
+    data_weights: Union[None, str] = None
+
+    # hf loading model low_cpu_mem_usage
+    low_cpu_mem_usage: bool = True
+
+    # train/valid/test split
+    data_split: str = "98,2,0"
+
+    # padding or pack or concat
+    padding_mode: str = "padding"
+
+    # sft or sst
+    tokenize_mode: str = "sft"
+
+    # case3 or case4
+    weighted_loss_mode: str = "case3"
+
+    # micro train batch size
+    per_device_train_batch_size: int = 8
+
+    # micro eval batch size, always same as micro train batch size
+    per_device_eval_batch_size: int = 8
+
+    # HF AutoTokenizer is supported, maybe more types
+    tokenizer_type: str = "AutoTokenizer"
+
+    # initial lr
+    learning_rate: float = 5e-5
+
+    # minimum lr
+    min_lr: float = 5e-6
+
+    # weight decay
+    weight_decay: float = 0.01
+
+    # gradient_accumulation_steps
+    gradient_accumulation_steps: int = 1
+
+    # lr_scheduler_type
+    lr_scheduler_type: str = "cosine"
+
+    # num_warmup_steps
+    num_warmup_steps: Union[int, float] = 0.05
+
+    # num_train_epochs
+    num_train_epochs: int = 4
+
+    # seed for reproducing
+    seed: int = 1234
+
+    # seq_length, context length
+    seq_length: int = 4096
+
+    # path of adaptor which is resumed from, None for not resuming training
+    resume_from_checkpoint: Union[None, str] = None
+
+    # auto resume from latest ckpt if job restarted
+    auto_resume: bool = True
+
+    # num of steps for logging training loss
+    log_interval: int = 10
+
+    # num of steps for saving ckpt
+    checkpointing_steps: int = 100
+
+    # num of steps for evaluation(eval_loss), better same as checkpointing steps
+    evaluation_steps: int = 100
+
+    # max train steps, if None, depends on num_train_epochs
+    max_train_steps: Union[None, int] = None
+
+    # if checkpointing every epoch, maybe True in sst
+    epoch_checkpointing: bool = False
+
+    # save transformers model(safetensors)
+    save_transformers_model: bool = False
+
+    # shuffle before train/valid split
+    shuffle_before_split: bool = True
+
+    # DDP random sampler
+    use_random_sampler: bool = True
+
+    # early stop when eval loss has not improved over the past early_stopping_stall_num evaluation points
+    early_stopping: bool = True
+    early_stopping_stall_num: int = 5
+
+    # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota.
+    saving_limit: Union[None, int] = None
+
+    # if dynamic padding
+    use_dynamic_padding: bool = True
+
+    # warm-up steps for CoBa, recommended to be the number of valid batches
+    coba_warmup_steps: int = 100
+    # history length of sampled valid losses used to fit the slope curve in CoBa, recommended range [2*coba_warmup_steps, 5*coba_warmup_steps]
+    coba_history_length: int = 200
+    # temperature for divergence factor in CoBa
+    coba_tau: int = 5
+    # iteration interval for updating per-task train weights in CoBa
+    coba_update_interval: int = 1
+    # number of mini valid batches sampled at each update interval
+    coba_sample_valid_num: int = 1
+
+    # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2}
+    attn_implementation: str = "flash_attention_2"
+
+    # role markers, which are prompt template before each role: system, user and assistant
+    # role_markers: {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"}
+    role_markers: Union[None, dict] = None
+
+    distributed_type: Union[None, str] = None
+
+    init_timeout_seconds: Union[None, int] = 3600
+
+    # legacy, leave them
+    use_xformers: bool = True
+    trust_remote_code: bool = True
+    weight_by_num_documents: bool = True
+    make_vocab_size_divisible_by: int = 32
+    model_parallel_size: int = 1
+    use_slow_tokenizer: bool = False
+    world_size: int = 8
+
+    def dict(self):
+        return {k: str(v) for k, v in asdict(self).items()}
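+
+
+# A minimal, illustrative train_config.json consumed by prepare_args() in
+# mpt_accelerate.py (all values below are placeholders):
+# {
+#     "data_paths": "[/path/to/task1,/path/to/task2]",
+#     "output_dir": "/path/to/output",
+#     "tb_dir": "/path/to/tensorboard",
+#     "pretrained_model_path": "/path/to/base_model",
+#     "model_type": "qwen",
+#     "weighted_loss_mode": "coba",
+#     "per_device_train_batch_size": 2,
+#     "seq_length": 4096
+# }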
diff --git a/mftcoder_accelerate/src/mpt/mpt_trainer.py b/mftcoder_accelerate/src/mpt/mpt_trainer.py
new file mode 100644
index 0000000..b5e2da8
--- /dev/null
+++ b/mftcoder_accelerate/src/mpt/mpt_trainer.py
@@ -0,0 +1,606 @@
+"""
+# @author qumu
+# @date 2024/6/6
+# @module mpt_trainer.py
+
+MPT/MCT/MFT Full-parameter Trainer
+"""
+
+import gc
+import os
+import sys
+import threading
+import argparse
+import math
+import logging
+import json
+import time
+import transformers
+import numpy as np
+import psutil
+import shutil
+import torch
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+from typing import List, Optional, Tuple, Union
+from tqdm.auto import tqdm
+from accelerate.logging import get_logger
+from accelerate import Accelerator
+from transformers import set_seed
+
+# sys.path.append("..")
+from utils.common_utils import generate_task_id, TASK2ID, ID2TASK
+from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func
+
+logger = get_logger(__name__)
+
+
+def copy_tokenizer_files(model_path: str, files_list: List[str], save_path: str):
+    # create path if not exist
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    # copy each file in files_list to save_path
+    for filename in files_list:
+        src_file = os.path.join(model_path, filename)
+
+        # copy only if src exists
+        if os.path.exists(src_file):
+            dest_file = os.path.join(save_path, filename)
+
+            # copy
+            shutil.copy(src_file, dest_file)
+            print(f"Copied {filename} to {save_path}")
+        else:
+            print(f"File {filename} does not exist in {mode_path}")
+
+
+def check_existing_ckpts(output_dir):
+    prefix = "step_"
+
+    if not os.path.exists(output_dir):
+        return []
+    # list all files and dirs
+    contents = os.listdir(output_dir)
+
+    # find dirs starts with "step_"
+    matching_folders = [
+        folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix)
+    ]
+
+    return matching_folders
+
+
+def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps):
+    """
+    extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training
+    """
+    # Extract `epoch_{i}` or `step_{i}`
+    training_difference = os.path.splitext(path)[0]
+
+    if "epoch" in training_difference:
+        starting_epoch = int(training_difference.replace("epoch_", ""))
+        resume_step = None
+        completed_steps = starting_epoch * num_update_steps_per_epoch
+        logger.info(f"Resume from exact Epoch {starting_epoch}: completed_steps {completed_steps}")
+    else:
+        # need to multiply `gradient_accumulation_steps` to reflect real steps
+        completed_steps = int(training_difference.replace("step_", ""))
+        starting_epoch = completed_steps // num_update_steps_per_epoch
+        resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps
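+        # e.g., path "step_250" with num_update_steps_per_epoch = 100 and
+        # gradient_accumulation_steps = 2 -> starting_epoch 2, completed_steps 250, resume_step 100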
+        logger.info(f"Resume from Epoch {starting_epoch} + step {resume_step}: completed_steps {completed_steps}")
+
+    return starting_epoch, completed_steps, resume_step
+
+
+def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
+    for key, value in log_dict.items():
+        summary_writer.add_scalar(f"{key}", value, completed_steps)
+
+
+def delete_ckpts_over_limits(output_dir, saving_limit, best_step):
+    """delete ckpts more than saving_limits except for the best_step ckpt"""
+    existing_ckpts = check_existing_ckpts(output_dir)
+    logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}")
+    # sort step numbers in ascending order
+    ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts])
+    # delete the oldest steps except for the best step at present
+    if len(ckpt_steps) > saving_limit:
+        deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step]
+        # print(deletable_steps[:len(ckpt_steps) - saving_limit])
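+        # e.g., ckpt_steps [100, 200, 300, 400] with saving_limit 2 and best_step 300
+        # -> deletable_steps [100, 200, 400]; the two oldest (100, 200) are removed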
+        for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]:
+            shutil.rmtree(os.path.join(output_dir, f"step_{del_step}"))
+            logger.info(f"Removed ckpt step_{del_step}")
+
+
+class MptTrainer:
+    """
+    Multitask Pre-train/Continue-train Trainer with Full-parameters training.
+    """
+
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        model,
+        model_config,
+        train_dataloader,
+        valid_dataloader,
+        optimizer,
+        lr_scheduler,
+        tokenizer,
+        num_update_steps_per_epoch,
+        total_train_dataset_size,
+        args,
+    ):
+        self.accelerator = accelerator
+        self.model = model
+        # hf model config
+        self.model_config = model_config
+        self.train_dataloader = train_dataloader
+        self.valid_dataloader = valid_dataloader
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.tokenizer = tokenizer
+        self.num_update_steps_per_epoch = num_update_steps_per_epoch
+        self.total_train_dataset_size = total_train_dataset_size
+        # training arguments
+        self.args = args
+        # tensorboard writer
+        self.summary_writer = SummaryWriter(log_dir=args.tb_dir)
+
+    def print(self, msg: str):
+        """
+        accelerator print, default on main process
+        Args:
+            msg:
+
+        Returns:
+
+        """
+        self.accelerator.print(msg)
+
+    def touch(self, batch, num_tokens=10):
+        """touch first and last tokens and labels for debugging usage"""
+        self.print(
+            f"step 1 batch shape: {batch['input_ids'].shape},\n"
+            f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}"
+            f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}"
+        )
+        self.print(f"first {num_tokens} input_ids and loss_mask")
+        for pt in range(1):
+            self.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
+            self.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
+
+    @staticmethod
+    def format_tensor(tensor, n):
+        return list(map(lambda x: round(x, n), tensor.tolist()))
+
+    def accelerate_saving_states(self, output_dir: str, completed_steps: int):
+        """
+        Saving lora adaptor or full checkpoint using accelerator
+        Args:
+            output_dir: exact dir for saving ckpt
+            completed_steps:
+
+        Returns:
+
+        """
+        self.accelerator.wait_for_everyone()
+        logger.info(f"[CHECKPOINT] Saving checkpoint states")
+        self.accelerator.save_state(output_dir)
+        self.accelerator.wait_for_everyone()
+
+        # save safetensors for direct inference if needed
+        if self.args.save_transformers_model:
+            logger.info(f"[CHECKPOINT] Saving transformers(hf) model", main_process_only=True)
+            unwrapped_model = self.accelerator.unwrap_model(self.model)
+            # self.print(f"unwrapped model type {type(unwrapped_model)}")
+            unwrapped_model.save_pretrained(
+                output_dir,
+                is_main_process=self.accelerator.is_main_process,
+                save_function=self.accelerator.save,
+                state_dict=self.accelerator.get_state_dict(self.model),
+            )
+            self.accelerator.wait_for_everyone()
+
+        # save tokenizer and clean up the buggy dummy ckpt.
+        if self.accelerator.is_main_process:
+            if self.args.model_type.lower() == "deepseek":
+                copy_tokenizer_files(
+                    self.args.pretrained_model_path, ["tokenizer.json", "tokenizer_config.json"], output_dir
+                )
+            else:
+                self.tokenizer.save_pretrained(output_dir)
+
+            sf = os.path.join(output_dir, "model.safetensors")
+            index_file = os.path.join(output_dir, "model.safetensors.index.json")
+            if os.path.isfile(sf) and os.path.isfile(index_file):
+                self.print(f"Remove bug dummy ckpt {sf}")
+                os.remove(sf)
+
+        # save latest info
+        if self.accelerator.is_main_process:
+            latest = {
+                "latest_ckpt": output_dir,
+                "lr": self.optimizer.param_groups[0]["lr"],
+            }
+            with open(os.path.join(self.args.output_dir, "latest"), "w") as f:
+                json.dump(latest, f, indent=2)
+
+            logger.info(
+                f"[CHECKPOINT][complete_steps={completed_steps}], states {output_dir} saved, latest: {latest}",
+                main_process_only=True,
+            )
+        self.accelerator.wait_for_everyone()
+
+    def accelerate_monitor(
+        self,
+        reduce_loss,
+        reduce_task_loss,
+        reduce_task_exist,
+        completed_steps,
+        coba_status=None,
+    ):
+        """
+        gather reduce_loss and reduce_task_loss from all N devices.
+        train logging and tensorboarding.
+        """
+        # gather reduce_loss and reduce_task_loss from all N devices
+        reduce_losses = self.accelerator.gather(reduce_loss).detach().float()
+        reduce_task_losses = self.accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK))
+        reduce_task_exists = self.accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK))
+
+        # get train loss and per-task train loss
+        train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps)
+        # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (self.args.log_interval * self.args.gradient_accumulation_steps)
+        train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0)
+
+        # logging and writing tensorboard
+        logger.info(
+            f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]"
+            f"[train_task_loss={self.format_tensor(train_task_loss, 4)}]"
+            f"[gather shape={list(reduce_losses.shape)}]"
+            f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]",
+            main_process_only=True,
+        )
+        if coba_status is not None:
+            if completed_steps > coba_status.coba_warmup_steps:
+                coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum(
+                    coba_status.log_per_task_weight
+                )
+            else:
+                coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK)
+            logger.info(
+                f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True
+            )
+        train_log_dict = {"Loss/train": train_loss}
+        for i in range(len(ID2TASK)):
+            train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i]
+            if coba_status is not None:
+                train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item()
+
+        if self.accelerator.is_main_process:
+            write_tensorboard(self.summary_writer, train_log_dict, completed_steps)
+
+        if coba_status is not None:
+            coba_status.log_per_task_weight = torch.zeros(len(ID2TASK))
+
+    def accelerate_evaluate(
+        self,
+        completed_steps,
+        step,
+        min_eval_loss,
+        stall_num,
+        best_step,
+    ):
+        """
+        evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices.
+        eval logging and tensorboarding.
+        """
+        losses = []
+        accumulated_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+        accumulated_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+        for valid_step, valid_batch in enumerate(self.valid_dataloader):
+            with torch.no_grad():
+                outputs = self.model(
+                    input_ids=valid_batch["input_ids"],
+                    attention_mask=valid_batch["attention_mask"],
+                    position_ids=valid_batch["position_ids"],
+                    return_dict=True,
+                )
+
+                loss, task_loss, _ = loss_func_mft(
+                    outputs=outputs,
+                    labels=valid_batch["labels"],
+                    task_mask=valid_batch["task_mask"],
+                    task_id=valid_batch["task_id"],
+                    weighted_loss_mode=self.args.weighted_loss_mode,
+                    loss_mask=valid_batch["loss_mask"],
+                    task_weights=self.args.task_weights,
+                )
+
+                losses.append(self.accelerator.gather(loss.repeat(self.args.per_device_eval_batch_size)))
+                accumulated_task_loss += task_loss.detach().float()
+                accumulated_task_exist += (task_loss != 0.0).detach().float()
+
+        self.accelerator.wait_for_everyone()
+        valid_batch_num = len(losses)
+        gathered_size = losses[0].shape
+        losses = torch.cat(losses)
+        # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK))
+        task_losses = self.accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK))
+        task_exists = self.accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK))
+
+        try:
+            eval_loss = torch.mean(losses)
+            # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num
+            eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0)
+            if eval_loss <= min_eval_loss:
+                min_eval_loss = eval_loss
+                stall_num = 0
+                best_step = completed_steps
+            else:
+                stall_num += 1
+            perplexity = math.exp(eval_loss)
+        except OverflowError:
+            perplexity = float("inf")
+
+        logger.info(
+            f"[EVAL][completed_steps={completed_steps}]"
+            f"[eval_loss={eval_loss:.6f}][eval_task_loss={self.format_tensor(eval_task_loss, 4)}]"
+            f"[perplexity={perplexity:.4f}][valid_batch_num={valid_batch_num}]"
+            f"[gather_size={list(gathered_size)}]",
+            main_process_only=True,
+        )
+        eval_log_dict = {
+            "Loss/valid": eval_loss,
+            "Perplexity/valid": perplexity,
+            "Epochs": round(completed_steps / self.num_update_steps_per_epoch, 2),
+        }
+        for i in range(len(ID2TASK)):
+            eval_log_dict[f"{ID2TASK[i]}_loss/valid"] = eval_task_loss[i]
+
+        if self.accelerator.is_main_process:
+            write_tensorboard(self.summary_writer, eval_log_dict, completed_steps)
+
+        return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step
+
+    def accelerate_train(self):
+        # Train!
+        if self.args.seed is not None:
+            set_seed(self.args.seed)
+
+        global_batch_size = (
+            self.args.per_device_train_batch_size
+            * self.accelerator.num_processes
+            * self.args.gradient_accumulation_steps
+        )
+        logger.info("************************************** Running training ****************************************")
+        logger.info(f"  Num examples = {self.total_train_dataset_size}")
+        logger.info(f"  Num Epochs = {self.args.num_train_epochs}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f"  Total global train batch size (w. parallel, distributed & accumulation) = {global_batch_size}")
+        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f"  Total optimization(update/completed) steps = {self.args.max_train_steps}")
+        logger.info(f"  Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_train_epochs}")
+        logger.info("************************************************************************************************")
+
+        # Only show the progress bar once on each machine.
+        progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process)
+
+        # set starting_epoch, completed_steps and resume_step of train_dataloader
+        completed_steps = 0
+        starting_epoch = 0
+        resume_step = None
+
+        if self.args.resume_from_checkpoint:
+            self.accelerator.load_state(self.args.resume_from_checkpoint)
+            self.accelerator.print(f"Resumed from checkpoint: {self.args.resume_from_checkpoint}")
+            path = os.path.basename(self.args.resume_from_checkpoint)
+            starting_epoch, completed_steps, resume_step = extract_epochs_and_steps(
+                path, self.num_update_steps_per_epoch, self.args.gradient_accumulation_steps
+            )
+
+        # update the progress_bar if load from checkpoint
+        progress_bar.update(completed_steps)
+
+        # monitor minimum eval_loss, stalling num, and best_step
+        min_eval_loss = float("inf")
+        stall_num = 0
+        best_step = None
+
+        # monitor train loss
+        reduce_loss = torch.tensor(0.0).to(self.model.device)
+        reduce_aux_loss = torch.tensor(0.0).to(self.model.device)
+        reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+        reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+        per_task_weight = self.args.task_weights
+
+        if self.args.weighted_loss_mode == "coba":
+            self.model.eval()
+            eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate(
+                completed_steps,
+                0,
+                min_eval_loss,
+                stall_num,
+                best_step,
+            )
+            self.model.train()
+            coba_status = CoBaStatus(
+                self.args.coba_warmup_steps,
+                self.args.coba_history_length,
+                self.args.coba_tau,
+                self.args.coba_update_interval,
+                self.args.coba_sample_valid_num,
+                self.valid_dataloader,
+            )
+            coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device)
+            coba_status.sample_valid_batch(self.model, completed_steps)
+            logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True)
+        else:
+            coba_status = None
+
+        # Training Loop!
+        for epoch in range(starting_epoch, self.args.num_train_epochs):
+            # set_epoch
+            # self.train_dataloader.set_epoch(epoch)
+
+            # if we early stop by some ckpts not converging
+            if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num:
+                break
+
+            if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
+                # We skip the first `n` batches in the dataloader when resuming from a checkpoint
+                active_dataloader = self.accelerator.skip_first_batches(self.train_dataloader, resume_step)
+            else:
+                active_dataloader = self.train_dataloader
+            tail_num = len(active_dataloader) - len(active_dataloader) % self.args.gradient_accumulation_steps
+            print(f"length of dataloader: {len(active_dataloader)}")
+
+            self.model.train()
+            # Inner Loop!
+            for step, batch in enumerate(active_dataloader):
+                if step == tail_num:
+                    break
+                with self.accelerator.accumulate(self.model):
+                    if step == 0:
+                        self.touch(batch, num_tokens=10)
+                    # forward
+                    outputs = self.model(
+                        input_ids=batch["input_ids"],
+                        attention_mask=batch["attention_mask"],
+                        position_ids=batch["position_ids"],
+                        return_dict=True,
+                    )
+
+                    if (
+                        self.args.weighted_loss_mode == "coba"
+                        and self.accelerator.sync_gradients
+                        and completed_steps % self.args.coba_update_interval == 0
+                        and completed_steps >= self.args.coba_warmup_steps
+                    ):
+                        with torch.no_grad():
+                            per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps)
+                            coba_status.log_per_task_weight += per_task_weight
+                            # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True)
+
+                    # loss
+                    loss, task_loss, _ = loss_func_mft(
+                        outputs=outputs,
+                        labels=batch["labels"],
+                        task_mask=batch["task_mask"],
+                        task_id=batch["task_id"],
+                        weighted_loss_mode=self.args.weighted_loss_mode,
+                        loss_mask=batch["loss_mask"],
+                        task_weights=per_task_weight,
+                    )
+
+                    # accelerator.print(len(outputs.router_logits), outputs.router_logits[0], outputs.router_logits[-1])
+                    # accelerator.print(batch['attention_mask'].shape, batch['attention_mask'])
+                    aux_loss = None
+                    if hasattr(self.model_config, "output_router_logits") and self.model_config.output_router_logits:
+                        if hasattr(self.model_config, "num_local_experts"):
+                            num_experts = self.model_config.num_local_experts
+                        elif hasattr(self.model_config, "num_experts"):
+                            num_experts = self.model_config.num_experts
+                        else:
+                            raise ValueError("model has no attribute num_local_experts or num_experts")
+                        aux_loss = load_balancing_loss_func(
+                            outputs.router_logits,
+                            num_experts,
+                            self.model_config.num_experts_per_tok,
+                            batch["attention_mask"],
+                        )
+                        aux_loss = self.model_config.router_aux_loss_coef * aux_loss.to(loss.device)
+                        loss += aux_loss  # make sure to reside in the same device
+
+                    # backward
+                    self.accelerator.backward(loss)
+                    # print(self.lr_scheduler.state_dict(), self.accelerator.process_index)
+                    # update(sync_gradients)
+                    self.optimizer.step()
+                    self.lr_scheduler.step()
+                    self.optimizer.zero_grad()
+                    # support args.min_lr
+                    if self.optimizer.param_groups[0]["lr"] <= self.args.min_lr:
+                        self.optimizer.param_groups[0]["lr"] = self.args.min_lr
+
+                    # accumulate reduce_loss and reduce_task_loss within a log_interval
+                    if not torch.isnan(loss):
+                        reduce_loss += loss.detach().float()
+                    if aux_loss is not None and not torch.isnan(aux_loss):
+                        reduce_aux_loss += aux_loss.detach().float()
+                    # self.print("task loss devices: ", reduce_task_loss.device, task_loss.device)
+                    reduce_task_loss += task_loss.detach().float()
+                    reduce_task_exist += (task_loss != 0).detach().float()
+
+                    # If the accelerator has performed an optimization step behind the scenes, a completed_step is done.
+                    if self.accelerator.sync_gradients:
+                        if (
+                            self.args.weighted_loss_mode == "coba"
+                            and completed_steps % self.args.coba_update_interval == 0
+                            and completed_steps >= 1
+                        ):
+                            coba_status.sample_valid_batch(self.model, completed_steps)
+                            # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True)
+
+                        # progress_bar.update(1)
+                        completed_steps += 1
+                        # monitoring training process and logging and tensorboarding
+                        if completed_steps % self.args.log_interval == 0:
+                            progress_bar.update(self.args.log_interval)
+                            if reduce_aux_loss > 0.0:
+                                self.print(f"[INFO] aux_loss: {reduce_aux_loss/self.args.log_interval}")
+                            self.accelerate_monitor(
+                                reduce_loss,
+                                reduce_task_loss,
+                                reduce_task_exist,
+                                completed_steps,
+                                coba_status,
+                            )
+                            # reset reduce_loss
+                            reduce_loss = torch.tensor(0.0).to(self.model.device)
+                            reduce_aux_loss = torch.tensor(0.0).to(self.model.device)
+                            reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+                            reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+
+                        # steps checkpointing
+                        if self.args.checkpointing_steps and completed_steps % self.args.checkpointing_steps == 0:
+                            output_dir = f"step_{completed_steps}"
+                            if self.args.output_dir is not None:
+                                output_dir = os.path.join(self.args.output_dir, output_dir)
+                            self.accelerate_saving_states(output_dir, completed_steps)
+
+                        # steps evaluation
+                        if completed_steps % self.args.evaluation_steps == 0 and self.valid_dataloader:
+                            self.model.eval()
+                            eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = self.accelerate_evaluate(
+                                completed_steps,
+                                step,
+                                min_eval_loss,
+                                stall_num,
+                                best_step,
+                            )
+                            self.model.train()
+
+                            # delete ckpts over args.saving_limit
+                            if self.accelerator.is_main_process and self.args.saving_limit:
+                                delete_ckpts_over_limits(self.args.output_dir, self.args.saving_limit, best_step)
+
+                            # early stopping when stalling for more than args.early_stopping_stall_num evaluations
+                            if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num:
+                                self.print(f"[WARNING] Early stopping at {completed_steps}")
+                                break
+
+                        if completed_steps >= self.args.max_train_steps:
+                            break
+                        self.accelerator.wait_for_everyone()
+
+            # epoch checkpointing
+            if self.args.epoch_checkpointing:
+                output_dir = f"epoch_{epoch + 1}"
+                if self.args.output_dir is not None:
+                    output_dir = os.path.join(self.args.output_dir, output_dir)
+                self.accelerate_saving_states(output_dir, completed_steps)
+
+        self.summary_writer.close()
diff --git a/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py b/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py
new file mode 100644
index 0000000..ca4347e
--- /dev/null
+++ b/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+import multiprocessing
+import os
+import sys
+import random
+import time
+import tqdm
+import glob
+import json
+import numpy as np
+
+
+# Add the grandparent directory of this file to sys.path
+current_path = os.path.abspath(__file__)
+parent_dir = os.path.dirname(os.path.dirname(current_path))
+grandparent_dir = os.path.dirname(parent_dir)
+sys.path.append(grandparent_dir)
+
+from tokenizer import init_tokenizer
+from pack_encoder import PackSSTBinEncoder, load_tokenizer
+from data import indexed_dataset
+
+from threading import Semaphore
+from colorama import Fore
+import lm_fmt as lmd
+
+
+def yield_from_files(files: list, semaphore):
+    """
+    Iterator over input documents 
+
+    :param fnames: list of filenames
+    """
+    def yielder(fname, semaphore):
+        with open(fname, 'r') as f:
+            for line in f:
+                semaphore.acquire()
+                yield json.loads(line)
+
+    for fname in files:
+        semaphore.acquire()
+        yield from yielder(fname, semaphore)
+
+def yield_from_files2(fnames: list, semaphore, sample_percent):
+    """
+    Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts /
+    other compressed formats. Also filters out empty documents.
+
+    :param fnames: list of filenames
+    """
+    def yielder(fname, semaphore):
+        try:
+            sample_interval = int(1/sample_percent)
+            for f in filter(lambda x: x, lmd.Reader(fname).stream_data(key=None)):
+                rand_value = random.randint(1, sample_interval*100)
+                if rand_value % sample_interval != 0:
+                    continue
+                semaphore.acquire()
+            
+                #rand_value = random.randint(1, sample_interval*100)
+                #if rand_value % sample_interval != 0:
+                #    yield None
+
+                yield f
+        except Exception as e:
+            print('####Exception:', e.args)
+            yield None
+
+    for fname in fnames:
+        semaphore.acquire()
+
+        yield from yielder(fname, semaphore)
+
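+# Usage sketch for the throttled readers above (file name is a placeholder): the consumer
+# must release() the semaphore once per consumed document, as core_process below does.
+#
+#   sem = Semaphore(1000)
+#   for doc in yield_from_files2(["/path/to/data.jsonl"], sem, sample_percent=1.0):
+#       sem.release()  # hand the slot back before doing real work on `doc`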
+
+def print_example_doc(input_ids, tokenizer):
+    print(Fore.YELLOW + f'INPUT IDS len: {len(input_ids)}')
+    print(Fore.BLUE + f'INPUT IDS:\n {input_ids}\n\n')
+
+    print(Fore.RED + f'DETOKENIZED INPUT:\n{tokenizer.decode(input_ids)}')
+
+
+def core_process(encoded_docs, semaphore, seq_length, tokenizer, encoder, builder, output_idx_file):
+    """
+    core of Data Pack SFT processing
+    """
+    input_ids_key = 'input_ids'
+
+    proc_start = time.time()
+    total_bytes_processed = 0
+    pbar = tqdm.tqdm()
+    sentence_dropped = 0
+    loss_token_cnt = 0
+
+    print("PRINT BEFORE STREAM PROCESS DATA")
+
+    print_example_count = 0  
+    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
+        total_bytes_processed += bytes_processed
+
+        # release semaphore so `yield_from_files` can add another file to the buffer
+        semaphore.release()
+
+        # add each tokenized document / sentence,
+        # For sft, each document has only one sample
+        input_ids_sentence = doc[input_ids_key][0]
+        if len(input_ids_sentence) < 1:
+            sentence_dropped += 1
+            continue
+
+        builder.add_item(np.array(input_ids_sentence, dtype=builder.dtype))
+        builder.end_document()
+        #builder.finalize_without_close(output_idx_file)
+        #builder.add_item_and_end_document_and_finalize(np.array(input_ids_sentence, dtype=builder.dtype), output_idx_file)
+
+        # print the first packed sample as example
+        if print_example_count < 1:
+            print_example_doc(input_ids_sentence, tokenizer)
+            print_example_count += 1
+
+        # log progress
+        if i % 100 == 0:
+            current = time.time()
+            elapsed = current - proc_start
+            mbs = total_bytes_processed / elapsed / 1024 / 1024
+            pbar.set_description(
+                f"Processed {i} documents ({i / elapsed:.2f} docs/s, {mbs:.2f} MB/s)."
+            )
+            if i != 0:
+                pbar.update(100)
+
+    # finalize the builder and write the index file
+    builder.finalize(output_idx_file)
+
+    print(Fore.RED + "\ndropped docs: {}".format(sentence_dropped))
+
+
+def process_dataset(dataset_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent):
+    """
+    Re-organize samples in the given data path into a Data Pack file.
+    """
+
+    # get all jsonl files and corresponding reading handler
+    files = glob.glob(os.path.join(dataset_path, '**/*.jsonl'), recursive=True)
+
+    # build a semaphore object to stop `yield_from_files` from getting ahead 
+    # of encoder.encode and hence building up memory
+    semaphore = Semaphore(1000 + parallel_num)
+
+    # build sample iterator
+    sample_iterator = yield_from_files2(files, semaphore, sample_percent)  
+
+    # load tokenizer
+    # tokenizer = load_tokenizer(model_path, tokenizer_type)
+    tokenizer = init_tokenizer(model_path)
+    print('TOKEN of id=2:', tokenizer.convert_ids_to_tokens(2))
+    print('ID of :', tokenizer.convert_tokens_to_ids(''))
+    print('TOKEN of id=0:', tokenizer.convert_ids_to_tokens(0))
+    print('ID of :', tokenizer.convert_tokens_to_ids(''))
+
+    # init encoder
+    encoder = PackSSTBinEncoder(seq_length, model_path)
+
+    # default dataset_name to the last path component of dataset_path
+    if dataset_name is None:
+        dataset_path = dataset_path[:-1] if dataset_path.endswith(os.path.sep) else dataset_path
+        dataset_name = dataset_path.split(os.path.sep)[-1]
+
+    # create writer builder
+    key = "input_ids"
+    output_prefix = os.path.join(output_path, dataset_name)
+    output_bin_file = "{}_{}.bin".format(
+        output_prefix, key
+    )
+    output_idx_file = "{}_{}.idx".format(
+        output_prefix, key
+    )
+    builder = indexed_dataset.make_builder(
+        output_bin_file,
+        impl="mmap",
+        vocab_size=tokenizer.vocab_size,
+    )
+
+    if parallel_num > 1:
+        pool = multiprocessing.Pool(parallel_num, initializer=encoder.initializer)
+        encoded_docs = pool.imap(encoder.encode, sample_iterator, chunksize=32)
+    else:
+        encoder.initializer()
+        encoded_docs = (encoder.encode(doc) for doc in sample_iterator)
+
+    core_process(encoded_docs, semaphore, seq_length, tokenizer, encoder, builder, output_idx_file)
+
+
+def main(data_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent):
+    """
+    Entry
+    """
+
+    process_dataset(data_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate a packed jsonl file in the Data Pack SFT way.")
+    parser.add_argument('--model-path', type=str, help='Path of a pretrained model which contains tokenizer-related files.')
+    parser.add_argument('--parallel', type=int, default=1, help='The num of parallel processing.')
+    parser.add_argument('--output-path', type=str, help='Path to store the generated result file.')
+    parser.add_argument('--data-path', type=str, default=None, help='Path of files to be processed')
+    parser.add_argument('--seq-length', type=int, default=4096, help='The max input length (i.e. the max number of tokens in a sample)')
+    # parser.add_argument('--eod-token-id', type=int, default=2, help='EOD token id')
+    # parser.add_argument('--pad-token-id', type=int, default=0, help='PAD token id')
+    # parser.add_argument('--tokenizer-type', type=str, choices=["LLAMATokenizer", None], default=None, help="What type of tokenizer to use. Default is None.")
+    parser.add_argument('--dataset-name', type=str, default=None, help='The generated result dataset name. The folder name is used by default.')
+    parser.add_argument('--sample-percent', type=float, default=1.0, help='Sample percentage')
+
+    args = parser.parse_args()
+    print('ARGS\n', '\n'.join([str(key) + ':' + str(value) for key,value in vars(args).items()]))
+
+    random.seed(9999)
+
+    main(args.data_path, args.output_path, args.model_path, args.parallel, args.seq_length, args.dataset_name, args.sample_percent)
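+
+# Example invocation (illustrative; paths are placeholders):
+#   python concat_sst_bin_tokenization.py \
+#       --model-path /path/to/base_model \
+#       --data-path /path/to/jsonl_dir \
+#       --output-path /path/to/output_dir \
+#       --dataset-name my_dataset \
+#       --parallel 8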
diff --git a/mft_peft_hf/src/data/tokenization/lm_dataformat.py b/mftcoder_accelerate/src/offline_tokenization/lm_fmt.py
similarity index 86%
rename from mft_peft_hf/src/data/tokenization/lm_dataformat.py
rename to mftcoder_accelerate/src/offline_tokenization/lm_fmt.py
index f0e121b..c922859 100644
--- a/mft_peft_hf/src/data/tokenization/lm_dataformat.py
+++ b/mftcoder_accelerate/src/offline_tokenization/lm_fmt.py
@@ -108,59 +108,33 @@ def handle_jsonl(jsonl_reader, get_meta, autojoin_paragraphs, para_joiner, key='
             assert not get_meta
             yield ob
             continue
-        
-        if isinstance(key, str):
-            text = ob.get(key)
-            if autojoin_paragraphs and isinstance(text, list):
-                text = para_joiner.join(text)
 
-            if get_meta:
-                yield text, (ob['meta'] if 'meta' in ob else {})
-            else:
-                yield text
-        elif isinstance(key, list):
-
-            task = ob.get(key[0], '') 
-            src_language = ob.get(key[1], '')
-            src_code = ob.get(key[2], '')
-            tgt_language = ob.get(key[3], '')
-            tgt_code = ob.get(key[4], '')
-            sql = ob.get(key[5], '')
-            prompt_in = ob.get(key[6], '')
-            content_in = ob.get(key[7], '')
-            bad_content_in = ob.get(key[8], '')
-
-            if task:
-                task = "task: " + task
-            if src_language:
-                src_language = "language: " + src_language
-            if sql:
-                sql = sql.strip() + '\n'
-                task = "task: text to sql\n"
-                src_language = 'language: text\n'
-                tgt_language = "language: sql\n"
-                prompt_in = prompt_in.strip() + '\n' 
-            elif tgt_language:
-                tgt_language =  "language: " + tgt_language
-
-            prompt = task + src_language + src_code + prompt_in + tgt_language
-            content =  tgt_code + sql + content_in
-            bad_content = bad_content_in
-
-            yield (prompt, content, bad_content)
+        if key is None:
+            yield ob
+            continue
+
+        text = ob[key]
+
+        if autojoin_paragraphs and isinstance(text, list):
+            text = para_joiner.join(text)
+
+        if get_meta:
+            yield text, (ob['meta'] if 'meta' in ob else {})
+        else:
+            yield text
 
 
 class Reader:
     def __init__(self, in_path):
         self.in_path = in_path
     
-    def stream_data(self, get_meta=False, threaded=False, key=['prompt', 'content', 'bad_content']):
+    def stream_data(self, get_meta=False, threaded=False, key=None):
         if not threaded:
-            yield from self._stream_data(get_meta, jsonl_key=key)
+            yield from self._stream_data(get_meta, key=key)
             return
 
         q = mp.Queue(1000)
-        p = mp.Process(target=self._stream_data_threaded, args=(q, get_meta))
+        p = mp.Process(target=self._stream_data_threaded, args=(q, get_meta), kwargs={"key": key})
         p.start()
         while p.is_alive():
             res = q.get()
@@ -172,7 +146,7 @@ def _stream_data_threaded(self, q, get_meta=False):
             q.put(data)
         q.put(None)
 
-    def _stream_data(self, get_meta=False, jsonl_key="text"):
+    def _stream_data(self, get_meta=False, key="text"):
         self.f_name = ""
         files = listdir_or_file(self.in_path)
         if not files:
@@ -192,11 +166,11 @@ def _stream_data(self, get_meta=False, jsonl_key="text"):
 
                 yield from self.read_dat(f)
             elif f.endswith('.jsonl'):
-                yield from self.read_jsonl(f, get_meta, key=jsonl_key)
+                yield from self.read_jsonl(f, get_meta, key=key)
             elif f.endswith('.jsonl.zst'):
-                yield from self.read_jsonl_zst(f, get_meta, key=jsonl_key)
+                yield from self.read_jsonl_zst(f, get_meta, key=key)
             elif f.endswith('.jsonl.zst.tar'):
-                yield from self.read_jsonl_tar(f, get_meta, jsonl_key=key)
+                yield from self.read_jsonl_tar(f, get_meta, key=key)
             elif f.endswith('.json.zst'):
                 assert not get_meta
 
@@ -383,4 +357,4 @@ def commit(self):
             fh.write(cdata)
 
         self.i += 1
-        self.data = []
\ No newline at end of file
+        self.data = []
diff --git a/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py b/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py
new file mode 100644
index 0000000..0678e27
--- /dev/null
+++ b/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py
@@ -0,0 +1,335 @@
+from transformers import AutoTokenizer, LlamaTokenizer
+from tokenizer import init_tokenizer
+
+
+def load_tokenizer(model_path, tokenizer_type=None):
+    """
+    Load tokenizer from the given model path.
+    """
+
+    def load_tokenizer_manual(model_path, tokenizer_type):
+        """
+        Load tokenizer by the concrete Tokenizer class instead of AutoTokenizer
+        """
+        try:
+            if tokenizer_type.lower() == "LlamaTokenizer".lower():
+                return LlamaTokenizer.from_pretrained(model_path)
+
+            raise Exception(f"Unsupported tokenizer type {tokenizer_type}")
+        except Exception:
+            raise Exception(f"Unable to load tokenizer {tokenizer_type} from the given path: {model_path}")
+
+    def load_tokenizer_auto(model_path):
+        """
+        Load tokenizer from the given path by HuggingFace AutoTokenizer
+        """
+        try:
+            # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)  # support CodeLlama
+            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+            return tokenizer
+        except Exception:
+            raise Exception(
+                f'Unable to load tokenizer from the given path: {model_path} using auto mode.\nPlease specify the tokenizer type with the command argument "--tokenizer-type" and retry.'
+            )
+
+    # First, try to load the tokenizer via HuggingFace AutoTokenizer; if that fails, fall back to the manual loader
+    try:
+        return load_tokenizer_auto(model_path)
+    except Exception as e:
+        print(str(e))
+        if tokenizer_type is not None:
+            return load_tokenizer_manual(model_path, tokenizer_type)
+        # no manual fallback available, re-raise the original error instead of silently returning None
+        raise
+
+
+class PackPFTEncoder:
+    """
+    A sample of this format will be:
+        <|role_start|>system<|role_end|> content of system_1
+        <|role_start|>human<|role_end|> content of human_1
+        <|role_start|>bot<|role_end|> content of bot_1
+        <|endoftext|>
+        <|role_start|>system<|role_end|> content of system_2
+        <|role_start|>human<|role_end|> content of human_2
+        <|role_start|>bot<|role_end|> content of bot_2
+        <|endoftext|>
+        <|role_start|>human<|role_end|> content of human_3
+        <|role_start|>bot<|role_end|> content of bot_3
+        <|endoftext|>
+        ....
+        <|role_start|>human<|role_end|> content of human_n
+        <|role_start|>bot<|role_end|> content of bot_n
+        <|endoftext|>
+        
+        <|pad|><|pad|>...<|pad|>
+
+    system part is optional, i.e. '<|role_start|>system<|role_end|> content of system_i'
+    """
+
+    def __init__(self, seq_length, eod_token_id, pad_token_id, role_start_tag, role_end_tag, mode="pft"):
+        self.mode = mode
+        self.seq_length = seq_length
+        self.eod_token_id = eod_token_id
+        self.pad_token_id = pad_token_id
+        self.role_start_tag = role_start_tag
+        self.role_end_tag = role_end_tag
+
+    def initializer(self, model_path, tokenizer_type=None):
+        # Use Encoder class as a container for global data
+        assert model_path is not None
+        self.tokenizer = load_tokenizer(model_path, tokenizer_type)
+
+    def encode(self, item):
+        encode_res = {
+            "input_ids": [],
+        }
+
+        item_len = sum([len(x["content"]) for x in item["chat_rounds"]])
+        for token_res in self.tokenize_chat_prompt(item):
+            for k, v in token_res.items():
+                encode_res[k].append(v)
+        return encode_res, item_len
+
+    def tokenize_chat_prompt(self, item):
+        # role_start_marker = self.tokenizer.encode(self.role_start_tag, add_special_tokens=False)
+        # role_end_marker = self.tokenizer.encode(self.role_end_tag, add_special_tokens=False)
+        end_marker = [self.eod_token_id]
+
+        input_ids = []
+        raw_input = ""
+        # loss_mask = []
+        for chat_round in item["chat_rounds"]:
+            role = chat_round["role"].strip()
+            # skip system prompt
+            # if role == 'system':
+            #    continue
+
+            content = chat_round["content"]
+            content = content if content.endswith("\n") else f"{content}\n"
+            text = f"{self.role_start_tag}{role}{self.role_end_tag}{content}"
+            chat_input_ids = self.tokenizer.encode(text, add_special_tokens=False)
+
+            # only bot turns are terminated with the EOD token
+            if role == "bot":
+                chat_input_ids = chat_input_ids + end_marker
+
+            input_ids += chat_input_ids
+
+        # if this sample's length exceeds the specified max length, drop it
+        # we don't pad a single sample here; padding tokens are appended to the combined (packed) sample later
+        if len(input_ids) > self.seq_length:
+            yield {}
+        else:
+            yield {"input_ids": input_ids}
+
+    def padding(self, key, data):
+        assert len(data) <= self.seq_length, f"padding sequence: {len(data)} > {self.seq_length}"
+        if key == "input_ids":
+            return data + [self.pad_token_id] * (self.seq_length - len(data))
+
+        if key == "loss_mask":
+            return data + [0] * (self.seq_length - len(data))
+
+        raise Exception("Should not reach here. There must be something wrong.")
+
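+# Illustrative input item for PackPFTEncoder.encode and PackSFTEncoder.encode (field values are made up):
+#   {"chat_rounds": [{"role": "human", "content": "write a hello world"},
+#                    {"role": "bot", "content": "print('hello world')"}]}
+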
+
+class PackSFTEncoder:
+    """
+    A sample of this format will be:
+        <|role_start|>system<|role_end|> content of system_1
+        <|role_start|>human<|role_end|> content of human_1
+        <|role_start|>bot<|role_end|> content of bot_1
+        <|endoftext|>
+        <|role_start|>system<|role_end|> content of system_2
+        <|role_start|>human<|role_end|> content of human_2
+        <|role_start|>bot<|role_end|> content of bot_2
+        <|endoftext|>
+        <|role_start|>human<|role_end|> content of human_3
+        <|role_start|>bot<|role_end|> content of bot_3
+        <|endoftext|>
+        ....
+        <|role_start|>human<|role_end|> content of human_n
+        <|role_start|>bot<|role_end|> content of bot_n
+        <|endoftext|>
+        
+        <|pad|><|pad|>...<|pad|>
+
+    system part is optional, i.e. '<|role_start|>system<|role_end|> content of system_i'
+    """
+
+    def __init__(self, seq_length, eod_token, role_start_tag, role_end_tag, mode="sft"):
+        self.mode = mode
+        self.seq_length = seq_length
+        self.eod_token = eod_token
+        self.role_start_tag = role_start_tag
+        self.role_end_tag = role_end_tag
+
+    def initializer(self, model_path, tokenizer_type=None):
+        # Use Encoder class as a container for global data
+        assert model_path is not None
+        self.tokenizer = load_tokenizer(
+            model_path, tokenizer_type
+        )  # AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    def encode(self, item):
+        encode_res = {"input_ids": [], "raw_input": []}
+
+        item_len = sum([len(x["content"]) for x in item["chat_rounds"]])
+        for token_res in self.tokenize_chat_prompt(item):
+            for k, v in token_res.items():
+                encode_res[k].append(v)
+        return encode_res, item_len
+
+    def tokenize_chat_prompt(self, item):
+        role_start_marker = self.tokenizer.encode(self.role_start_tag, add_special_tokens=False)
+        role_end_marker = self.tokenizer.encode(self.role_end_tag, add_special_tokens=False)
+        end_marker = [self.tokenizer.convert_tokens_to_ids(self.eod_token)]
+
+        input_ids = []
+        raw_input = ""
+        # loss_mask = []
+        for chat_round in item["chat_rounds"]:
+            role = chat_round["role"]
+            content = chat_round["content"]
+            content = content if content.endswith("\n") else f"{content}\n"
+            chat_input_ids = self.tokenizer.encode(content, add_special_tokens=False)
+            role_input_ids = self.tokenizer.encode(role, add_special_tokens=False)
+            role_raw_input = ""
+
+            if role != "bot":
+                # chat_loss_mask = [0] * len(role_start_marker) + [0] * len(role_input_ids) + [0] * len(role_end_marker) + [0] * len(chat_input_ids)
+                chat_input_ids = role_start_marker + role_input_ids + role_end_marker + chat_input_ids
+                role_raw_input = self.role_start_tag + role + self.role_end_tag + content
+            else:
+                # bot turns are terminated with the EOD token
+                # chat_loss_mask = [0] * len(role_start_marker) + [0] * len(role_input_ids) + [0] * len(role_end_marker) + [1] * len(chat_input_ids) + [1] * len(end_marker)
+                chat_input_ids = role_start_marker + role_input_ids + role_end_marker + chat_input_ids + end_marker
+                role_raw_input = self.role_start_tag + role + self.role_end_tag + content + self.eod_token
+
+            input_ids += chat_input_ids
+            raw_input += role_raw_input
+            # loss_mask += chat_loss_mask
+
+        # assert len(input_ids) == len(loss_mask)
+
+        # if this sample's length exceeds the specified max length, drop it
+        # we don't pad a single sample here; padding tokens are appended to the combined (packed) sample later
+        if len(input_ids) > self.seq_length:
+            yield {}
+        else:
+            yield {
+                "input_ids": input_ids,
+                "raw_input": raw_input,
+                # "loss_mask": loss_mask
+            }
+
+    def padding(self, key, data, pad_token_id):
+        assert len(data) <= self.seq_length, f"padding sequence: {len(data)} > {self.seq_length}"
+        if key == "input_ids":
+            return data + [pad_token_id] * (self.seq_length - len(data))
+
+        if key == "loss_mask":
+            return data + [0] * (self.seq_length - len(data))
+
+        raise Exception("Should not reach here. There must be something wrong.")
+
+
+class PackSSTBinEncoder:
+    """
+    A sample of this format will be:
+        content of sample_1
+        content of sample_2
+        ...
+        content of sample_n
+        <|pad|><|pad|>...<|pad|>
+    """
+
+    def __init__(self, seq_length, model_path):
+        self.seq_length = seq_length
+        self.model_path = model_path
+
+    def initializer(self):
+        # Use Encoder class as a container for global data
+        assert self.model_path is not None
+        # self.tokenizer = load_tokenizer(model_path, tokenizer_type) #AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        # PackSSTBinEncoder.tokenizer = load_tokenizer(self.model_path, self.tokenizer_type)
+        PackSSTBinEncoder.tokenizer = init_tokenizer(self.model_path)
+
+    def _encode_content(self, item, encode_res):
+        if "content" in item:
+            content = item["content"]
+        else:
+            content = item["text"]
+
+        item_len = len(content)
+
+        input_ids = self.tokenize_string(content)
+        encode_res["input_ids"].append(input_ids)
+
+        return encode_res, item_len
+
+    def _encode_chatml(self, item, encode_res):
+        input_ids = []
+        item_len = 0
+        one_round_content = ""
+        for i in range(len(item["chat_rounds"])):
+            chat_round = item["chat_rounds"][i]
+            role = chat_round["role"]
+            content = chat_round["content"]
+            content = content if content.endswith("\n") else f"{content}\n"
+            if role.lower() == "system":
+                continue
+            if role.lower() == "human":
+                one_round_content = content
+            else:
+                one_round_content += content
+                input_ids += self.tokenize_string(one_round_content)
+                item_len += len(one_round_content)
+
+        encode_res["input_ids"].append(input_ids)
+
+        return encode_res, item_len
+
+    def encode(self, item):
+        encode_res = {
+            "input_ids": [],
+        }
+
+        try:
+            if item is None:
+                encode_res["input_ids"].append([])
+                return encode_res, 0
+
+            if "content" in item or "text" in item:
+                return self._encode_content(item, encode_res)
+
+            if "chat_rounds" in item:
+                return self._encode_chatml(item, encode_res)
+        except Exception as e:
+            print("####JSON Exception", e, str(item))
+            encode_res["input_ids"].append([])
+            return encode_res, 0
+
+        raise Exception("Unsupported Format!")
+
+    def tokenize_string(self, text):
+        end_marker = [PackSSTBinEncoder.tokenizer.eos_token_id]
+
+        input_ids = []
+        try:
+            input_ids = PackSSTBinEncoder.tokenizer.encode(text, add_special_tokens=False)
+            input_ids = input_ids + end_marker
+            return input_ids
+        except Exception as e:
+            print("####Tokenization Exception:", e, text)
+            return []
+        except BaseException as e:
+            print("####Tokenization BaseException:", e, "Length of text", len(text))
+            return []
+
+    def padding(self, data, pad_token_id):
+        assert len(data) <= self.seq_length, f"padding sequence: {len(data)} > {self.seq_length}"
+        return data + [pad_token_id] * (self.seq_length - len(data))
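+
+
+# Minimal usage sketch for PackSSTBinEncoder (model path is a placeholder):
+#   encoder = PackSSTBinEncoder(seq_length=4096, model_path="/path/to/base_model")
+#   encoder.initializer()  # binds the tokenizer as a class attribute
+#   res, n_chars = encoder.encode({"text": "def add(a, b):\n    return a + b\n"})
+#   input_ids = res["input_ids"][0]  # token ids terminated with the tokenizer's EOS id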
diff --git a/mftcoder_accelerate/src/offline_tokenization/writer.py b/mftcoder_accelerate/src/offline_tokenization/writer.py
new file mode 100644
index 0000000..ab526a7
--- /dev/null
+++ b/mftcoder_accelerate/src/offline_tokenization/writer.py
@@ -0,0 +1,42 @@
+
+import threading
+import fcntl
+import json
+
+class JSONLWriter():
+    """
+    A writer used to save jsonl lines into a file.
+    """
+    def __init__(self, output_path, dataset_name):
+        self.output_path = output_path
+        self.out_file = open(output_path, 'w')
+        self.cache = []
+        self.cache_size = 4096
+        self.dataset_name = dataset_name
+        self.index = 0
+
+    def pack_into_jsonl(self, line_text):
+        new_item = {
+            "data_name": self.dataset_name,
+            "id": self.index,
+            "content": line_text
+        }
+
+        return new_item
+
+
+    def add_item(self, line_text):
+        if len(self.cache) >= self.cache_size:
+            self.flush()
+        
+        item = self.pack_into_jsonl(line_text)
+        self.cache.append(json.dumps(item))
+        self.index += 1
+
+    
+    def flush(self):
+        if not self.cache:
+            return
+        content = '\n'.join(self.cache)
+        fcntl.flock(self.out_file, fcntl.LOCK_EX)
+        self.out_file.write(f'{content}\n')
+        fcntl.flock(self.out_file, fcntl.LOCK_UN)
+        self.cache = []
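+
+    def close(self):
+        # flush any lines still buffered in the cache, then release the file handle;
+        # callers should invoke this once they are done adding items
+        self.flush()
+        self.out_file.close()
+
+
+# Usage sketch (output path is a placeholder):
+#   writer = JSONLWriter("/tmp/packed.jsonl", dataset_name="my_dataset")
+#   for line in ["sample one", "sample two"]:
+#       writer.add_item(line)
+#   writer.close()  # flushes the remaining cache and closes the file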
diff --git a/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py b/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py
new file mode 100644
index 0000000..26f8ec1
--- /dev/null
+++ b/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py
@@ -0,0 +1,103 @@
+"""
+# @author Chaoyu Chen
+# @date 2023/10/19
+
+Merge base and lora adaptor
+"""
+
+import os
+import sys
+import time
+import shutil
+import argparse
+from typing import List
+import torch
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
+from peft import LoraConfig, get_peft_model
+from peft import PeftModel
+
+# insert src as import path
+current_path = os.path.abspath(__file__)
+parent_dir = os.path.dirname(os.path.dirname(current_path))
+sys.path.insert(0, parent_dir)
+print("In merge_base_and_lora_to_hf.py, sys path:", sys.path)
+
+from tokenizer import init_tokenizer
+
+
+def copy_tokenizer_files(model_path: str, files_list: List[str], save_path: str):
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    for filename in files_list:
+        src_file = os.path.join(model_path, filename)
+
+        if os.path.exists(src_file):
+            dest_file = os.path.join(save_path, filename)
+            shutil.copy(src_file, dest_file)
+            print(f"Copied {filename} to {save_path}")
+        else:
+            print(f"File {filename} does not exist in {model_path}")
+
+
+if __name__ == "__main__":
+
+    # arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base_model_or_path", type=str, default=None)
+    parser.add_argument("--adaptor_path", type=str, default=None)
+    parser.add_argument("--model_type", type=str, default=None)
+    parser.add_argument("--merged_output_path", type=str, default=None)
+    args = parser.parse_args()
+
+    model_path = args.base_model_or_path
+    lora_adapter = args.adaptor_path
+    model_type = args.model_type
+    save_path = args.merged_output_path
+
+    t0 = time.time()
+
+    tokenizer = init_tokenizer(args.base_model_or_path)
+
+    base_model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+        # torch_dtype=torch.float32,
+        return_dict=True,
+        device_map="auto",
+    )
+    print("--------------------------------------Base Model--------------------------------------------")
+    print(base_model)
+    print("--------------------------------------------------------------------------------------------")
+
+    print("-----------------------------------Base Model Config----------------------------------------")
+    print(base_model.config)
+    print("--------------------------------------------------------------------------------------------")
+
+    # merge, save model and tokenizer
+    model_to_merge = PeftModel.from_pretrained(base_model, lora_adapter)
+    merged_model = model_to_merge.merge_and_unload()
+    # merged_model.to(torch.bfloat16)
+
+    print("---------------------------------Merged Model Config----------------------------------------")
+    print(merged_model.config)
+    print("--------------------------------------------------------------------------------------------")
+    merged_model.save_pretrained(save_path)
+
+    print("-------------------------------------Tokenizer----------------------------------------------")
+    print(tokenizer)
+    print("--------------------------------------------------------------------------------------------")
+    if model_type and model_type.lower() == "deepseek":
+        copy_tokenizer_files(
+            model_path,
+            ["tokenizer.model", "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"],
+            save_path,
+        )
+    else:
+        tokenizer.save_pretrained(save_path)
+
+    print(f"Merge finised: {save_path} saved, Cost {time.time() - t0:.2f}s")
diff --git a/mftcoder_accelerate/src/pefts/mft_accelerate.py b/mftcoder_accelerate/src/pefts/mft_accelerate.py
new file mode 100644
index 0000000..0a0d42a
--- /dev/null
+++ b/mftcoder_accelerate/src/pefts/mft_accelerate.py
@@ -0,0 +1,571 @@
+"""
+# @author Chaoyu Chen
+# @date 2024/10/24
+# @module mft_accelerate.py
+
+Accelerate + DeepSpeed/FSDP + QLoRA/LoRA/Full + Multi-task Finetuning
+
+Entry
+"""
+
+import os
+import sys
+import argparse
+import math
+import logging
+import json
+import time
+from tqdm.auto import tqdm
+import transformers
+import numpy as np
+import torch
+from torch import nn
+from dataclasses import dataclass
+from datasets import Dataset
+import datasets
+from torch.utils.data import DataLoader
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    get_linear_schedule_with_warmup,
+    set_seed,
+    BitsAndBytesConfig,
+    get_scheduler,
+)
+from peft import (
+    LoraConfig,
+    TaskType,
+    get_peft_model,
+    prepare_model_for_kbit_training,
+    PeftModel,
+)
+from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration
+from accelerate.logging import get_logger
+from datetime import timedelta
+from accelerate.utils import InitProcessGroupKwargs
+from transformers.optimization import Adafactor
+
+# insert src as import path
+current_path = os.path.abspath(__file__)
+parent_dir = os.path.dirname(os.path.dirname(current_path))
+sys.path.insert(0, parent_dir)
+print("In mft_accelerate.py, sys path:", sys.path)
+
+from tokenizer import build_tokenizer
+from data.multi_task_dataset import load_dataset_from_jsonl, compile_helper
+from data.data_utils import load_dataset_from_bin
+from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK
+
+from pefts.mft_trainer import MftTrainer
+from pefts.mft_arguments import MftTrainArgs
+from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS
+
+
+logger = get_logger(__name__)
+
+
+def get_task_mask(args, task_id):
+    task_num = len(TASK2ID)
+    task_mask = torch.zeros(task_id.shape[0], task_num)
+    task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1
+
+    return task_mask
+
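+# Example (illustrative): with 3 tasks and task_id of shape (batch, 1), e.g. [[0], [2]],
+# get_task_mask returns [[1., 0., 0.], [0., 0., 1.]].
+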
+
+def get_attention_mask_and_position_ids(data):
+    """Build masks and position id for left to right model."""
+
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    attention_mask = torch.ones((batch_size, seq_length), device=data.device)
+
+    # Position ids.
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data).clone()
+
+    return attention_mask, position_ids
+
+
+@dataclass
+class DataCollatorForMFTDataset(object):
+    args: None
+
+    def __call__(self, instances):
+        (input_ids, loss_mask, weights, task_id) = tuple(
+            [instance.get(key, None) for instance in instances]
+            for key in ("input_ids", "loss_mask", "weight", "task_id")
+        )
+
+        result_batch = {}
+        """
+        outputs = model(
+                    input_ids=batch['input_ids'],
+                    attention_mask=batch['attention_mask'],
+                    # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']),
+                    # labels=(batch['labels'], batch['loss_mask']),
+                    position_ids=batch['position_ids'])
+        """
+
+        # if loss_mask is not None:
+        loss_mask = torch.tensor(np.array(loss_mask)).long()
+        last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1)
+        if self.args.use_dynamic_padding:
+            # get last non-padding position
+            max_pos = last_one_pos.max().item() + 1
+        else:
+            max_pos = loss_mask.shape[-1]
+
+        if self.args.tokenize_mode == "sst" and self.args.padding_mode == "pack":
+            # sst + pack tokenization, remove last dirty data
+            result_batch["loss_mask"] = loss_mask.float()[:, 1 : max_pos - 1].contiguous()
+            input_ids = torch.tensor(np.array(input_ids)).long()
+            result_batch["input_ids"] = input_ids[:, : max_pos - 2].contiguous()
+            result_batch["labels"] = input_ids[:, 1 : max_pos - 1].contiguous()
+        else:
+            result_batch["loss_mask"] = loss_mask.float()[:, 1:max_pos].contiguous()
+            input_ids = torch.tensor(np.array(input_ids)).long()
+            # print(f"shape of input_ids: {input_ids.shape}")
+            result_batch["input_ids"] = input_ids[:, : max_pos - 1].contiguous()
+            result_batch["labels"] = input_ids[:, 1:max_pos].contiguous()
+
+        # Get the masks and position ids.
+        if self.args.model_type in ["mixtral", "qwen2_moe"]:
+            batch_size, seq_length = result_batch["input_ids"].shape
+            # bsz * seq_length
+            range_tensor = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1)
+            # attention_mask for padding tokens
+            attention_mask = (range_tensor <= last_one_pos.reshape(batch_size, 1)).long()
+            result_batch["attention_mask"], result_batch["position_ids"] = attention_mask, None
+        else:
+            # For decoder-only models, transformers will create them.
+            result_batch["attention_mask"], result_batch["position_ids"] = None, None
+
+        if task_id is not None:
+            task_id = torch.tensor(np.array(task_id))
+            result_batch["task_mask"] = get_task_mask(self.args, task_id)  # bsz * task_num
+            result_batch["task_id"] = task_id
+
+        return result_batch
+
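+# Illustrative shift performed by the collator above (non-pack mode): for a row
+# [t0, t1, t2, t3] with max_pos = 4, input_ids become [t0, t1, t2] and labels become
+# [t1, t2, t3], with loss_mask aligned to the labels.
+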
+
+def pprint_args(args, accelerator):
+    # compute the maximum key length for aligned printing
+    max_key_length = max(len(str(key)) for key in vars(args).keys())
+
+    message = ""
+    message += "====" * 60 + "\n"
+    message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n"
+    message += "====" * 60 + "\n"
+    accelerator.print(message)
+    accelerator.print("GPU: {}".format(torch.cuda.current_device()))
+
+
+def prepare_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train_config", type=str, default=None)
+
+    parser.add_argument("--data_paths", type=str, default=None)
+    parser.add_argument("--output_dir", type=str, default=None)
+    parser.add_argument("--tb_dir", type=str, default=None)
+    parser.add_argument("--pretrained_model_path", type=str, default=None)
+    parser.add_argument("--micro_batch_size", type=int, default=None)
+    parser.add_argument("--model_type", type=str, default=None)
+    parser.add_argument("--distributed_type", type=str, default="deepspeed")
+
+    parsed = parser.parse_args()
+    # get json configs
+    with open(parsed.train_config, "r") as f:
+        train_config = json.load(f)
+
+    # parse args from the config json
+    # args = argparse.Namespace(**train_config)
+    args = MftTrainArgs(**train_config)
+
+    # override args by cli arguments
+    if parsed.data_paths:
+        args.data_paths = parsed.data_paths
+    if parsed.output_dir:
+        args.output_dir = parsed.output_dir
+    if parsed.tb_dir:
+        args.tb_dir = parsed.tb_dir
+    if parsed.pretrained_model_path:
+        args.pretrained_model_path = parsed.pretrained_model_path
+        args.vocab_file = parsed.pretrained_model_path
+    if parsed.micro_batch_size:
+        args.per_device_train_batch_size = parsed.micro_batch_size
+        args.per_device_eval_batch_size = parsed.micro_batch_size
+    if parsed.model_type:
+        args.model_type = parsed.model_type
+
+    args.distributed_type = parsed.distributed_type
+
+    # refactor args
+
+    if args.peft_type == "qlora":
+        print_rank_0(f"[INFO] args.peft_type is set 'qlora', setting quantization to '4bit'")
+        args.quantization = "4bit"
+    else:
+        args.quantization = None
+
+    args.vocab_file = args.pretrained_model_path
+
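+    # data_paths is a bracketed, comma-separated string, e.g. "[/path/task_a,/path/task_b]",
+    # so every task gets a default data weight of 1.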
+    args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(","))) + "]"
+
+    # generate TASK2ID, ID2TASK
+    generate_task_id(args.data_paths)
+
+    if args.weighted_loss_mode == "coba":
+        args.task_weights = [1.0] * len(ID2TASK)
+    elif args.task_weights is not None:
+        args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")]
+        assert len(args.task_weights) == len(ID2TASK), "length of task_weights must equal the number of data_paths"
+    else:
+        args.task_weights = [1.0] * len(ID2TASK)
+
+    return args
+
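+# Illustrative launch command (config paths are placeholders):
+#   accelerate launch --config_file accelerate_ds_config.yaml mft_accelerate.py \
+#       --train_config configs/my_train_config.json --distributed_type deepspeed
+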
+
+def main():
+    t0 = time.time()
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    os.environ["HF_HUB_OFFLINE"] = "false"
+    # get input args, set TASK2ID, ID2TASK, refactor args
+    args = prepare_args()
+
+    # fix randomness
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # define accelerator
+    init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds))
+
+    if args.distributed_type and args.distributed_type.lower() == "fsdp":
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+            # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
+            limit_all_gathers=True,
+            sync_module_states=True,
+            use_orig_params=True,
+            cpu_offload=False,
+        )
+        accelerator = Accelerator(
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            fsdp_plugin=fsdp_plugin,
+            dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True),
+            kwargs_handlers=[init_process_kwargs],
+        )
+    else:
+        accelerator = Accelerator(
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True),
+            kwargs_handlers=[init_process_kwargs],
+        )
+
+    # print key infos
+    accelerator.print("In mft_accelerate.py, sys path:", sys.path)
+    accelerator.print(f"transformers.__version__: {transformers.__version__}")
+
+    # get world_size
+    args.world_size = accelerator.num_processes
+
+    # backup args
+    pprint_args(args, accelerator)
+    if accelerator.is_main_process:
+        if not os.path.exists(args.output_dir):
+            os.makedirs(args.output_dir)
+        with open(os.path.join(args.output_dir, "args.json"), "w") as f:
+            json.dump(args.dict(), f, indent=2)
+
+    # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest
+    latest = None
+    if os.path.exists(os.path.join(args.output_dir, "latest")):
+        with open(os.path.join(args.output_dir, "latest"), "r") as fl:
+            latest = json.load(fl)
+        accelerator.print(f"[INFO] Existing latest: {latest}")
+
+    if args.auto_resume and args.resume_from_checkpoint is None and latest:
+        args.resume_from_checkpoint = latest["latest_ckpt"]
+        if not args.peft_type:
+            # full-parameter training resumes by reloading the latest checkpoint as the base model
+            args.pretrained_model_path = args.resume_from_checkpoint
+        args.learning_rate = latest["lr"]
+    elif args.resume_from_checkpoint and (not args.peft_type):
+        args.pretrained_model_path = args.resume_from_checkpoint
+
+    # logger
+    logging.basicConfig(
+        format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+        # compile Cpp helper
+        compile_helper()
+        time.sleep(10)
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # get global_rank and local rank for current process
+    global_rank = accelerator.process_index
+    local_rank = accelerator.local_process_index
+    print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}")
+
+    # TASK2ID, ID2TASK
+    # generate_task_id(args.data_paths)
+
+    # multi task blendable dataset(sharded)
+    if args.load_raw_dataset:
+        print_rank_0("> load raw jsonl dataset")
+        train_dataset, valid_dataset = load_dataset_from_jsonl(
+            args=args, shard_data=True, world_size=args.world_size, global_rank=global_rank, local_rank=local_rank
+        )
+    else:
+        print_rank_0("> load tokenized bin dataset, refer to gpt_neox indexed dataset")
+        train_dataset, valid_dataset, _ = load_dataset_from_bin(args=args)
+
+    t1 = time.time()
+    logger.info(f"dataset loading time: {t1 - t0:.4f}")
+
+    # cuda memory
+    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
+    max_memory = f"{free_in_GB - 2}GB"
+    n_gpus = torch.cuda.device_count()
+    max_memory = {i: max_memory for i in range(n_gpus)}
+    accelerator.print("max memory: ", max_memory, n_gpus)
+
+    # target_modules, default all-linear for all linear layers
+    if args.target_modules:
+        target_modules = args.target_modules
+    else:
+        target_modules = "all-linear"
+
+    # peft config
+    if args.peft_type:
+        peft_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            inference_mode=False,
+            r=args.lora_rank,
+            lora_alpha=args.lora_alpha,
+            lora_dropout=args.lora_dropout,
+            target_modules=target_modules,
+            bias="lora_only",
+        )
+
+    # # Whether new special tokens need to be added
+    # num_added_toks = tokenizer.tokenizer.add_special_tokens(["", ""])
+    # accelerator.print("We have added", num_added_toks, "tokens")
+    # accelerator.print(f"role marker tokens {tokenizer.convert_tokens_to_ids('')} {tokenizer.convert_tokens_to_ids('')}, resized tokenizer_size: {len(tokenizer)}")
+
+    # creating base model
+    ModelClass = MODEL_TYPES[args.model_type]
+    if args.model_type in SUPPORT_IN_TRANSFORMERS:
+        accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers")
+        model = ModelClass.from_pretrained(
+            args.pretrained_model_path,
+            attn_implementation=args.attn_implementation,
+            torch_dtype=torch.bfloat16,
+            quantization_config=(
+                BitsAndBytesConfig(
+                    load_in_4bit=(args.quantization == "4bit"),
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_quant_storage=torch.bfloat16,
+                )
+                if args.quantization == "4bit"
+                else None
+            ),
+        )
+    else:
+        accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code")
+        model = ModelClass.from_pretrained(
+            args.pretrained_model_path,
+            torch_dtype=torch.bfloat16,
+            quantization_config=(
+                BitsAndBytesConfig(
+                    load_in_4bit=(args.quantization == "4bit"),
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_quant_storage=torch.bfloat16,
+                )
+                if args.quantization == "4bit"
+                else None
+            ),
+        )
+
+    # build a tokenizer for possible resizing or saving
+    tokenizer = build_tokenizer(args)
+    # Note: resize_token_embeddings expects to receive the full size of the new vocabulary,
+    # i.e. the length of the tokenizer.
+    # If new special tokens were added, the input and output embeddings need to be resized
+    # model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
+
+    accelerator.print("Model load_in_4bit: ", args.quantization == "4bit")
+
+    if args.peft_type == "lora":
+        model.gradient_checkpointing_enable()
+    elif args.peft_type == "qlora":
+        # prepare base model for 4bit model(cast non-4bit layers to fp32)
+        model = prepare_model_for_kbit_training(model)
+        # logging.info(f"device map: {model.hf_device_map}")
+    else:
+        model.gradient_checkpointing_enable()
+        if args.saving_limit is None or not isinstance(args.saving_limit, int) or args.saving_limit < 1:
+            # saving_limit is set automatically if needed
+            args.saving_limit = 2
+            accelerator.print(
+                "[WARNING] saving_limit must be a positive integer in full-parameter training, setting it to 2"
+            )
+
+    # Load PeftModel from a previous save or create a new PeftModel
+    if args.peft_type:
+        if not args.resume_from_checkpoint:
+            model = get_peft_model(model, peft_config)
+        else:
+            accelerator.print(f"[INFO] Resumed from checkpoint: {args.resume_from_checkpoint}")
+            # accelerator.load_state(args.resume_from_checkpoint)
+            model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True)
+
+        model.print_trainable_parameters()
+
+    t2 = time.time()
+    if accelerator.is_main_process:
+        logging.info(f"model loading time: {t2 - t1:.4f}")
+
+    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
+    if hasattr(model.config, "use_logn_attn"):
+        model.config.use_logn_attn = False  # special for qwen model
+    # load balance for moe training
+    if hasattr(model.config, "output_router_logits"):
+        model.config.output_router_logits = True
+    model_config = model.config
+    accelerator.print(model.config)
+
+    # dataloader
+    train_dataloader = DataLoader(
+        train_dataset,
+        shuffle=True,
+        collate_fn=DataCollatorForMFTDataset(args),
+        batch_size=args.per_device_train_batch_size,
+        pin_memory=True,
+        drop_last=True,
+    )
+    if valid_dataset:
+        valid_dataloader = DataLoader(
+            valid_dataset,
+            collate_fn=DataCollatorForMFTDataset(args),
+            batch_size=args.per_device_eval_batch_size,
+            pin_memory=True,
+            drop_last=True,
+        )
+    else:
+        valid_dataloader = None
+
+    # optimizer
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        accelerator.print("DISTRIBUTED TRAINING USING DEEPSPEED")
+        # from deepspeed.ops.adam import FusedAdam as Adam
+        # adam_optimizer = Adam
+        adam_optimizer = torch.optim.AdamW
+    elif accelerator.distributed_type == DistributedType.FSDP:
+        accelerator.print("DISTRIBUTED TRAINING USING FSDP")
+        if args.peft_type and getattr(accelerator.state, "fsdp_plugin", None) is not None:
+            from peft.utils.other import fsdp_auto_wrap_policy
+
+            accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
+        model = accelerator.prepare(model)
+        adam_optimizer = torch.optim.AdamW
+    else:
+        raise ValueError("Only support DeepSpeed and FSDP")
+
+    optimizer = adam_optimizer(
+        model.parameters(),
+        weight_decay=args.weight_decay,
+        lr=args.learning_rate,
+        betas=(0.9, 0.999),
+    )
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+    if isinstance(args.num_warmup_steps, float) and args.num_warmup_steps < 1.0:
+        args.num_warmup_steps = int(args.max_train_steps * args.num_warmup_steps) // accelerator.num_processes
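+        # e.g. (illustrative numbers) num_warmup_steps=0.05, max_train_steps=10000 and 8 processes
+        # give int(10000 * 0.05) // 8 = 62 warmup update steps per process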
+    accelerator.print(f"num_warmup_steps: {args.num_warmup_steps}")
+    lr_scheduler = get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        # scheduler_specific_kwargs={"last_epoch": scheduler_last_ep}
+    )
+    # prepare all
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        if valid_dataloader:
+            (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare(
+                model, train_dataloader, valid_dataloader, optimizer, lr_scheduler
+            )
+        else:
+            (model, train_dataloader, optimizer, lr_scheduler) = accelerator.prepare(
+                model, train_dataloader, optimizer, lr_scheduler
+            )
+
+    # prepare all except model, which is prepared before
+    elif accelerator.distributed_type == DistributedType.FSDP:
+        if valid_dataloader:
+            (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare(
+                optimizer, train_dataloader, valid_dataloader, lr_scheduler
+            )
+        else:
+            (optimizer, train_dataloader, lr_scheduler) = accelerator.prepare(
+                optimizer, train_dataloader, lr_scheduler
+            )
+    print(model.device)
+    accelerator.print(model)
+    # accelerator.print(model.config)
+
+    # Recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterward we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # zero 3 flag
+    is_ds_zero_3 = False
+    if getattr(accelerator.state, "deepspeed_plugin", None):
+        is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3
+        accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}")
+    elif getattr(accelerator.state, "fsdp_plugin", None):
+        accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}")
+
+    trainer = MftTrainer(
+        accelerator=accelerator,
+        model=model,
+        model_config=model_config,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        tokenizer=tokenizer,
+        num_update_steps_per_epoch=num_update_steps_per_epoch,
+        total_train_dataset_size=len(train_dataset),
+        args=args,
+    )
+    trainer.accelerate_train()
+    logger.info(f"Training Finished!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mftcoder_accelerate/src/pefts/mft_arguments.py b/mftcoder_accelerate/src/pefts/mft_arguments.py
new file mode 100644
index 0000000..9fee1cd
--- /dev/null
+++ b/mftcoder_accelerate/src/pefts/mft_arguments.py
@@ -0,0 +1,176 @@
+"""
+# @author Chaoyu Chen
+# @date 2023/10/19
+
+training arguments
+"""
+
+from dataclasses import dataclass, asdict
+from typing import List, Union
+
+
+@dataclass
+class MftTrainArgs:
+    # train data paths on shared FS
+    data_paths: Union[str, List[str]]
+
+    # output dir for saving adaptors in peft or full ckpts in full-parameter training
+    output_dir: str
+
+    # tensorboard dir for saving tensorboard logs
+    tb_dir: str
+
+    # pretrained_model_path, i.e. the path of the base model you want to train
+    pretrained_model_path: str
+
+    # model type of pretrained_model_path, supporting llama|qwen|starcoder|baichuan|chatglm2
+    model_type: str
+
+    # load from raw jsonl file or tokenized binary file
+    load_raw_dataset: bool = True
+
+    # weights of loss calculation for each task, None means equal weights
+    task_weights: Union[None, str] = None
+
+    # weights of data sampling, leave it None
+    data_weights: Union[None, str] = None
+
+    # hf loading model low_cpu_mem_usage
+    low_cpu_mem_usage: bool = True
+
+    # train/valid/test split
+    data_split: str = "98,2,0"
+
+    # padding or pack or concat
+    padding_mode: str = "padding"
+
+    # sft or sst
+    tokenize_mode: str = "sft"
+
+    # mft loss mode
+    weighted_loss_mode: str = "case3"
+
+    # lora or qlora or None(for full-parameter training)
+    peft_type: Union[None, str] = "qlora"
+
+    # if qlora, 4bit will be set, else None
+    quantization: Union[None, str] = "4bit"
+
+    # lora rank; the bigger, the more trainable parameters
+    lora_rank: int = 96
+
+    # lora alpha
+    lora_alpha: int = 32
+
+    # lora dropout
+    lora_dropout: float = 0.05
+
+    # lora targeting modules
+    target_modules: Union[None, str, List[str]] = None
+
+    # micro train batch size
+    per_device_train_batch_size: int = 8
+
+    # micro eval batch size, always the same as the micro train batch size
+    per_device_eval_batch_size: int = 8
+
+    # HF AutoTokenizer is supported, maybe more types
+    tokenizer_type: str = "AutoTokenizer"
+
+    # initial lr
+    learning_rate: float = 5e-5
+
+    # minimum lr
+    min_lr: float = 5e-6
+
+    # weight decay
+    weight_decay: float = 0.01
+
+    # gradient_accumulation_steps
+    gradient_accumulation_steps: int = 1
+
+    # lr_scheduler_type
+    lr_scheduler_type: str = "cosine"
+
+    # num_warmup_steps
+    num_warmup_steps: Union[int, float] = 0.05
+
+    # num_train_epochs
+    num_train_epochs: int = 4
+
+    # seed for reproducing
+    seed: int = 1234
+
+    # seq_length, context length
+    seq_length: int = 4096
+
+    # path of the adapter to resume from; None means no resuming
+    resume_from_checkpoint: Union[None, str] = None
+
+    # auto resume from latest ckpt if job restarted
+    auto_resume: bool = True
+
+    # num of steps for logging training loss
+    log_interval: int = 10
+
+    # num of steps for saving ckpt
+    checkpointing_steps: int = 100
+
+    # num of steps between evaluations (eval_loss), preferably the same as checkpointing_steps
+    evaluation_steps: int = 100
+
+    # max train steps, if None, depends on num_train_epochs
+    max_train_steps: Union[None, int] = None
+
+    # whether to checkpoint at the end of every epoch, maybe True in sst
+    epoch_checkpointing: bool = False
+
+    # shuffle before train/valid split
+    shuffle_before_split: bool = True
+
+    # DDP random sampler
+    use_random_sampler: bool = True
+
+    # whether to early-stop when eval loss has not improved in the past early_stopping_stall_num evaluation points
+    early_stopping: bool = True
+    early_stopping_stall_num: int = 5
+
+    # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota.
+    saving_limit: Union[None, int] = None
+
+    # if dynamic padding
+    use_dynamic_padding: bool = True
+
+    # warm-up steps for CoBa; the number of valid batches is recommended
+    coba_warmup_steps: int = 100
+    # history length of sampled valid loss used to fit the slope curve in CoBa; [2*coba_warmup_steps, 5*coba_warmup_steps] is recommended
+    coba_history_length: int = 200
+    # temperature for divergence factor in CoBa
+    coba_tau: int = 5
+    # iteration interval of update per task train weight in CoBa
+    coba_update_interval: int = 1
+    # the number of mini valid batches sampled at each updated iteration interval
+    coba_sample_valid_num: int = 1
+
+    # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2}
+    attn_implementation: str = "flash_attention_2"
+
+    # role markers, i.e. the prompt templates placed before each role: system, user and assistant
+    # role_markers: {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"}
+    role_markers: Union[None, dict] = None
+
+    distributed_type: Union[None, str] = None
+
+    init_timeout_seconds: Union[None, int] = 3600
+
+    # legacy, leave them
+    use_xformers: bool = True
+    trust_remote_code: bool = True
+    weight_by_num_documents: bool = True
+    make_vocab_size_divisible_by: int = 32
+    model_parallel_size: int = 1
+    use_slow_tokenizer: bool = False
+    world_size: int = 8
+
+    def dict(self):
+        return {k: str(v) for k, v in asdict(self).items()}
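+
+
+# Illustrative usage (assumes a JSON config whose keys match the fields above; the path is hypothetical):
+#   import json
+#   with open("train_config.json") as f:
+#       args = MftTrainArgs(**json.load(f))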
diff --git a/mftcoder_accelerate/src/pefts/mft_trainer.py b/mftcoder_accelerate/src/pefts/mft_trainer.py
new file mode 100644
index 0000000..a2b00fb
--- /dev/null
+++ b/mftcoder_accelerate/src/pefts/mft_trainer.py
@@ -0,0 +1,606 @@
+"""
+# @author qumu
+# @date 2024/4/12
+# @module trainer.py
+
+Accelerate + DeepSpeed/FSDP 
+QLoRA/LoRA/Full + SFT/MFT
+
+Trainer
+"""
+
+import gc
+import os
+import sys
+import threading
+import argparse
+import math
+import logging
+import json
+import time
+import transformers
+import numpy as np
+import psutil
+import shutil
+import torch
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+from typing import List, Optional, Tuple, Union
+from tqdm.auto import tqdm
+from accelerate.logging import get_logger
+from accelerate import Accelerator
+from transformers import set_seed
+
+# sys.path.append("..")
+from utils.common_utils import generate_task_id, TASK2ID, ID2TASK
+from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func
+
+logger = get_logger(__name__)
+
+
+def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str):
+    # create path if not exist
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    # copy each file in files_list to save_path
+    for filename in files_list:
+        src_file = os.path.join(mode_path, filename)
+
+        # copy only if src exists
+        if os.path.exists(src_file):
+            dest_file = os.path.join(save_path, filename)
+
+            # copy
+            shutil.copy(src_file, dest_file)
+            print(f"Copied {filename} to {save_path}")
+        else:
+            print(f"File {filename} does not exist in {mode_path}")
+
+
+def check_existing_ckpts(output_dir):
+    prefix = "step_"
+
+    if not os.path.exists(output_dir):
+        return []
+    # list all files and dirs
+    contents = os.listdir(output_dir)
+
+    # find dirs starting with "step_"
+    matching_folders = [
+        folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix)
+    ]
+
+    return matching_folders
+
+
+def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps):
+    """
+    extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training
+    """
+    # Extract `epoch_{i}` or `step_{i}`
+    training_difference = os.path.splitext(path)[0]
+
+    if "epoch" in training_difference:
+        starting_epoch = int(training_difference.replace("epoch_", ""))
+        resume_step = None
+        completed_steps = starting_epoch * num_update_steps_per_epoch
+        logger.info(f"Resume from exact Epoch {starting_epoch}: completed_steps {completed_steps}")
+    else:
+        # need to multiply `gradient_accumulation_steps` to reflect real steps
+        completed_steps = int(training_difference.replace("step_", ""))
+        starting_epoch = completed_steps // num_update_steps_per_epoch
+        resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps
+        logger.info(f"Resume from Epoch {starting_epoch} + step {resume_step}: completed_steps {completed_steps}")
+
+    return starting_epoch, completed_steps, resume_step
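+
+
+# Illustrative example: with num_update_steps_per_epoch=1000 and gradient_accumulation_steps=4,
+# a checkpoint folder named "step_1500" resumes at starting_epoch=1, completed_steps=1500 and
+# resume_step=(1500 % 1000) * 4 = 2000 dataloader batches into that epoch.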
+
+
+def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
+    for key, value in log_dict.items():
+        summary_writer.add_scalar(f"{key}", value, completed_steps)
+
+
+def delete_ckpts_over_limits(output_dir, saving_limit, best_step):
+    """delete ckpts more than saving_limits except for the best_step ckpt"""
+    existing_ckpts = check_existing_ckpts(output_dir)
+    logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}")
+    # sort step numbers in ascending order
+    ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts])
+    # delete the oldest steps except for the best step at present
+    if len(ckpt_steps) > saving_limit:
+        deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step]
+        # print(deletable_steps[:len(ckpt_steps) - saving_limit])
+        for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]:
+            shutil.rmtree(os.path.join(output_dir, f"step_{del_step}"))
+            logger.info(f"Removed ckpt step_{del_step}")
+
+
+class MftTrainer:
+    """
+    Multitask Fine-Tuning Trainer, supporting MFT/SFT/Continue-Training with LoRA/QLoRA/full-parameter training.
+    """
+
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        model,
+        model_config,
+        train_dataloader,
+        valid_dataloader,
+        optimizer,
+        lr_scheduler,
+        tokenizer,
+        num_update_steps_per_epoch,
+        total_train_dataset_size,
+        args,
+    ):
+        self.accelerator = accelerator
+        self.model = model
+        # hf model config
+        self.model_config = model_config
+        self.train_dataloader = train_dataloader
+        self.valid_dataloader = valid_dataloader
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.tokenizer = tokenizer
+        self.num_update_steps_per_epoch = num_update_steps_per_epoch
+        self.total_train_dataset_size = total_train_dataset_size
+        # training arguments
+        self.args = args
+        # tensorboard writer
+        self.summary_writer = SummaryWriter(log_dir=args.tb_dir)
+
+    def print(self, msg: str):
+        """
+        accelerator print, default on main process
+        Args:
+            msg:
+
+        Returns:
+
+        """
+        self.accelerator.print(msg)
+
+    def touch(self, batch, num_tokens=10):
+        """touch first and last tokens and labels for debugging usage"""
+        self.print(
+            f"step 1 batch shape: {batch['input_ids'].shape},\n"
+            f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}"
+            f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}"
+        )
+        self.print(f"first {num_tokens} input_ids and loss_mask")
+        for pt in range(1):
+            self.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
+            self.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
+
+    @staticmethod
+    def format_tensor(tensor, n):
+        return list(map(lambda x: round(x, n), tensor.tolist()))
+
+    def accelerate_saving_checkpoint(self, output_dir: str, completed_steps: int):
+        """
+        Saving lora adaptor or full checkpoint using accelerator
+        Args:
+            output_dir: exact dir for saving ckpt
+            completed_steps:
+
+        Returns:
+
+        """
+        self.accelerator.wait_for_everyone()
+
+        logger.info(f"[CHECKPOINT] Saving checkpoint", main_process_only=True)
+        unwrapped_model = self.accelerator.unwrap_model(self.model)
+        # self.print(f"unwrapped model type {type(unwrapped_model)}")
+        unwrapped_model.save_pretrained(
+            output_dir,
+            is_main_process=self.accelerator.is_main_process,
+            save_function=self.accelerator.save,
+            state_dict=self.accelerator.get_state_dict(self.model),
+        )
+        self.accelerator.wait_for_everyone()
+        # for full-parameter training, save whole ckpt and tokenizer together because it does not need a merge.
+        if not self.args.peft_type and self.accelerator.is_main_process:
+            if self.args.model_type.lower() == "deepseek":
+                copy_tokenizer_files(
+                    self.args.pretrained_model_path, ["tokenizer.json", "tokenizer_config.json"], output_dir
+                )
+            else:
+                self.tokenizer.save_pretrained(output_dir)
+
+            sf = os.path.join(output_dir, "model.safetensors")
+            index_file = os.path.join(output_dir, "model.safetensors.index.json")
+            if os.path.isfile(sf) and os.path.isfile(index_file):
+                self.print(f"Remove bug dummy ckpt {sf}")
+                os.remove(sf)
+
+        if self.accelerator.is_main_process:
+            latest = {
+                "latest_ckpt": output_dir,
+                "lr": self.optimizer.param_groups[0]["lr"],
+            }
+            with open(os.path.join(self.args.output_dir, "latest"), "w") as f:
+                json.dump(latest, f, indent=2)
+
+            logger.info(
+                f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved, latest: {latest}",
+                main_process_only=True,
+            )
+        self.accelerator.wait_for_everyone()
+
+    def accelerate_monitor(
+        self,
+        reduce_loss,
+        reduce_task_loss,
+        reduce_task_exist,
+        completed_steps,
+        coba_status=None,
+    ):
+        """
+        gather reduce_loss and reduce_task_loss from all N devices.
+        train logging and tensorboarding.
+        """
+        # gather reduce_loss and reduce_task_loss from all N devices
+        reduce_losses = self.accelerator.gather(reduce_loss).detach().float()
+        reduce_task_losses = self.accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK))
+        reduce_task_exists = self.accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK))
+
+        # get train loss and per-task train loss
+        train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps)
+        # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (self.args.log_interval * self.args.gradient_accumulation_steps)
+        train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0)
+
+        # logging and writing tensorboard
+        logger.info(
+            f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]"
+            f"[train_task_loss={self.format_tensor(train_task_loss, 4)}]"
+            f"[gather shape={list(reduce_losses.shape)}]"
+            f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]",
+            main_process_only=True,
+        )
+        if coba_status is not None:
+            if completed_steps > coba_status.coba_warmup_steps:
+                coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum(
+                    coba_status.log_per_task_weight
+                )
+            else:
+                coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK)
+            logger.info(
+                f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True
+            )
+        train_log_dict = {"Loss/train": train_loss}
+        for i in range(len(ID2TASK)):
+            train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i]
+            if coba_status is not None:
+                train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item()
+
+        if self.accelerator.is_main_process:
+            write_tensorboard(self.summary_writer, train_log_dict, completed_steps)
+
+        if coba_status is not None:
+            coba_status.log_per_task_weight = torch.zeros(len(ID2TASK))
+
+    def accelerate_evaluate(
+        self,
+        completed_steps,
+        step,
+        min_eval_loss,
+        stall_num,
+        best_step,
+    ):
+        """
+        evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices.
+        eval logging and tensorboarding.
+        """
+        losses = []
+        accumulated_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+        accumulated_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+        for valid_step, valid_batch in enumerate(self.valid_dataloader):
+            with torch.no_grad():
+                outputs = self.model(
+                    input_ids=valid_batch["input_ids"],
+                    attention_mask=valid_batch["attention_mask"],
+                    position_ids=valid_batch["position_ids"],
+                    return_dict=True,
+                )
+
+                loss, task_loss, _ = loss_func_mft(
+                    outputs=outputs,
+                    labels=valid_batch["labels"],
+                    task_mask=valid_batch["task_mask"],
+                    task_id=valid_batch["task_id"],
+                    weighted_loss_mode=self.args.weighted_loss_mode,
+                    loss_mask=valid_batch["loss_mask"],
+                    task_weights=self.args.task_weights,
+                )
+
+                losses.append(self.accelerator.gather(loss.repeat(self.args.per_device_eval_batch_size)))
+                accumulated_task_loss += task_loss.detach().float()
+                accumulated_task_exist += (task_loss != 0.0).detach().float()
+
+        self.accelerator.wait_for_everyone()
+        valid_batch_num = len(losses)
+        gathered_size = losses[0].shape
+        losses = torch.cat(losses)
+        # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK))
+        task_losses = self.accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK))
+        task_exists = self.accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK))
+
+        try:
+            eval_loss = torch.mean(losses)
+            # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num
+            eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0)
+            if eval_loss <= min_eval_loss:
+                min_eval_loss = eval_loss
+                stall_num = 0
+                best_step = completed_steps
+            else:
+                stall_num += 1
+            perplexity = math.exp(eval_loss)
+        except OverflowError:
+            perplexity = float("inf")
+
+        logger.info(
+            f"[EVAL][completed_steps={completed_steps}]"
+            f"[eval_loss={eval_loss:.6f}][eval_task_loss={self.format_tensor(eval_task_loss, 4)}]"
+            f"[perplexity={perplexity:.4f}][valid_batch_num={valid_batch_num}]"
+            f"[gather_size={list(gathered_size)}]",
+            main_process_only=True,
+        )
+        eval_log_dict = {
+            "Loss/valid": eval_loss,
+            "Perplexity/valid": perplexity,
+            "Epochs": round(completed_steps / self.num_update_steps_per_epoch, 2),
+        }
+        for i in range(len(ID2TASK)):
+            eval_log_dict[f"{ID2TASK[i]}_loss/valid"] = eval_task_loss[i]
+
+        if self.accelerator.is_main_process:
+            write_tensorboard(self.summary_writer, eval_log_dict, completed_steps)
+
+        return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step
+
+    def accelerate_train(self):
+        # Train!
+        if self.args.seed is not None:
+            set_seed(self.args.seed)
+
+        global_batch_size = (
+            self.args.per_device_train_batch_size
+            * self.accelerator.num_processes
+            * self.args.gradient_accumulation_steps
+        )
+        logger.info("************************************** Running training ****************************************")
+        logger.info(f"  Num examples = {self.total_train_dataset_size}")
+        logger.info(f"  Num Epochs = {self.args.num_train_epochs}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f"  Total global train batch size (w. parallel, distributed & accumulation) = {global_batch_size}")
+        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f"  Total optimization(update/completed) steps = {self.args.max_train_steps}")
+        logger.info(f"  Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_train_epochs}")
+        logger.info("************************************************************************************************")
+
+        # Only show the progress bar once on each machine.
+        progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process)
+
+        # set starting_epoch, completed_steps and resume_step of train_dataloader
+        completed_steps = 0
+        starting_epoch = 0
+        resume_step = None
+
+        if self.args.resume_from_checkpoint:
+            path = os.path.basename(self.args.resume_from_checkpoint)
+            starting_epoch, completed_steps, resume_step = extract_epochs_and_steps(
+                path, self.num_update_steps_per_epoch, self.args.gradient_accumulation_steps
+            )
+
+        # update the progress_bar if load from checkpoint
+        progress_bar.update(completed_steps)
+
+        # monitor minimum eval_loss, stalling num, and best_step
+        min_eval_loss = float("inf")
+        stall_num = 0
+        best_step = None
+
+        # monitor train loss
+        reduce_loss = torch.tensor(0.0).to(self.model.device)
+        reduce_aux_loss = torch.tensor(0.0).to(self.model.device)
+        reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+        reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+        per_task_weight = self.args.task_weights
+
+        if self.args.weighted_loss_mode == "coba":
+            self.model.eval()
+            eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate(
+                completed_steps,
+                0,
+                min_eval_loss,
+                stall_num,
+                best_step,
+            )
+            self.model.train()
+            coba_status = CoBaStatus(
+                self.args.coba_warmup_steps,
+                self.args.coba_history_length,
+                self.args.coba_tau,
+                self.args.coba_update_interval,
+                self.args.coba_sample_valid_num,
+                self.valid_dataloader,
+            )
+            coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device)
+            coba_status.sample_valid_batch(self.model, completed_steps)
+            logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True)
+        else:
+            coba_status = None
+
+        # Training Loop!
+        for epoch in range(starting_epoch, self.args.num_train_epochs):
+            # set_epoch
+            # self.train_dataloader.set_epoch(epoch)
+
+            # break if early stopping was triggered by eval loss not improving
+            if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num:
+                break
+
+            if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
+                # We skip the first `n` batches in the dataloader when resuming from a checkpoint
+                active_dataloader = self.accelerator.skip_first_batches(self.train_dataloader, resume_step)
+            else:
+                active_dataloader = self.train_dataloader
+            tail_num = len(active_dataloader) - len(active_dataloader) % self.args.gradient_accumulation_steps
+            print(f"length of dataloader: {len(active_dataloader)}")
+
+            self.model.train()
+            # Inner Loop!
+            for step, batch in enumerate(active_dataloader):
+                if step == tail_num:
+                    break
+                with self.accelerator.accumulate(self.model):
+                    if step == 0:
+                        self.touch(batch, num_tokens=10)
+                    # forward
+                    outputs = self.model(
+                        input_ids=batch["input_ids"],
+                        attention_mask=batch["attention_mask"],
+                        position_ids=batch["position_ids"],
+                        return_dict=True,
+                    )
+
+                    if (
+                        self.args.weighted_loss_mode == "coba"
+                        and self.accelerator.sync_gradients
+                        and completed_steps % self.args.coba_update_interval == 0
+                        and completed_steps >= self.args.coba_warmup_steps
+                    ):
+                        with torch.no_grad():
+                            per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps)
+                            coba_status.log_per_task_weight += per_task_weight
+                            # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True)
+
+                    # loss
+                    loss, task_loss, _ = loss_func_mft(
+                        outputs=outputs,
+                        labels=batch["labels"],
+                        task_mask=batch["task_mask"],
+                        task_id=batch["task_id"],
+                        weighted_loss_mode=self.args.weighted_loss_mode,
+                        loss_mask=batch["loss_mask"],
+                        task_weights=per_task_weight,
+                    )
+
+                    # accelerator.print(len(outputs.router_logits), outputs.router_logits[0], outputs.router_logits[-1])
+                    # accelerator.print(batch['attention_mask'].shape, batch['attention_mask'])
+                    aux_loss = None
+                    if hasattr(self.model_config, "output_router_logits") and self.model_config.output_router_logits:
+                        if hasattr(self.model_config, "num_local_experts"):
+                            num_experts = self.model_config.num_local_experts
+                        elif hasattr(self.model_config, "num_experts"):
+                            num_experts = self.model_config.num_experts
+                        else:
+                            raise ValueError("model has no attribute num_local_experts or num_experts")
+                        aux_loss = load_balancing_loss_func(
+                            outputs.router_logits,
+                            num_experts,
+                            self.model_config.num_experts_per_tok,
+                            batch["attention_mask"],
+                        )
+                        aux_loss = self.model_config.router_aux_loss_coef * aux_loss.to(loss.device)
+                        loss += aux_loss  # make sure to reside in the same device
+
+                    # backward
+                    self.accelerator.backward(loss)
+                    # print(self.lr_scheduler.state_dict(), self.accelerator.process_index)
+                    # update(sync_gradients)
+                    self.optimizer.step()
+                    self.lr_scheduler.step()
+                    self.optimizer.zero_grad()
+                    # support args.min_lr
+                    if self.optimizer.param_groups[0]["lr"] <= self.args.min_lr:
+                        self.optimizer.param_groups[0]["lr"] = self.args.min_lr
+
+                    # accumulate reduce_loss and reduce_task_loss within a log_interval
+                    if not torch.isnan(loss):
+                        reduce_loss += loss.detach().float()
+                    if aux_loss and not torch.isnan(aux_loss):
+                        reduce_aux_loss += aux_loss.detach().float()
+                    # self.print("task loss devices: ", reduce_task_loss.device, task_loss.device)
+                    reduce_task_loss += task_loss.detach().float()
+                    reduce_task_exist += (task_loss != 0).detach().float()
+
+                    # If the accelerator has performed an optimization step behind the scenes, a completed_step is done.
+                    if self.accelerator.sync_gradients:
+                        if (
+                            self.args.weighted_loss_mode == "coba"
+                            and completed_steps % self.args.coba_update_interval == 0
+                            and completed_steps >= 1
+                        ):
+                            coba_status.sample_valid_batch(self.model, completed_steps)
+                            # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True)
+
+                        # progress_bar.update(1)
+                        completed_steps += 1
+                        # monitoring training process and logging and tensorboarding
+                        if completed_steps % self.args.log_interval == 0:
+                            progress_bar.update(self.args.log_interval)
+                            if reduce_aux_loss > 0.0:
+                                self.print(f"[INFO] aux_loss: {reduce_aux_loss/self.args.log_interval}")
+                            self.accelerate_monitor(
+                                reduce_loss,
+                                reduce_task_loss,
+                                reduce_task_exist,
+                                completed_steps,
+                                coba_status,
+                            )
+                            # reset reduce_loss
+                            reduce_loss = torch.tensor(0.0).to(self.model.device)
+                            reduce_aux_loss = torch.tensor(0.0).to(self.model.device)
+                            reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+                            reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+
+                        # steps checkpointing
+                        if self.args.checkpointing_steps and completed_steps % self.args.checkpointing_steps == 0:
+                            output_dir = f"step_{completed_steps}"
+                            if self.args.output_dir is not None:
+                                output_dir = os.path.join(self.args.output_dir, output_dir)
+                            self.accelerate_saving_checkpoint(output_dir, completed_steps)
+
+                        # steps evaluation
+                        if completed_steps % self.args.evaluation_steps == 0 and self.valid_dataloader:
+                            self.model.eval()
+                            eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = self.accelerate_evaluate(
+                                completed_steps,
+                                step,
+                                min_eval_loss,
+                                stall_num,
+                                best_step,
+                            )
+                            self.model.train()
+
+                            # delete ckpts over args.saving_limit
+                            if self.accelerator.is_main_process and self.args.saving_limit:
+                                delete_ckpts_over_limits(self.args.output_dir, self.args.saving_limit, best_step)
+
+                            # early stopping when stalling for more than args.early_stopping_stall_num evaluations
+                            if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num:
+                                self.print(f"[WARNING] Early stopping at {completed_steps}")
+                                break
+
+                        if completed_steps >= self.args.max_train_steps:
+                            break
+                        self.accelerator.wait_for_everyone()
+
+            # epoch checkpointing
+            if self.args.epoch_checkpointing:
+                output_dir = f"epoch_{epoch + 1}"
+                if self.args.output_dir is not None:
+                    output_dir = os.path.join(self.args.output_dir, output_dir)
+                self.accelerate_saving_checkpoint(output_dir, completed_steps)
+
+        self.summary_writer.close()
+
+        # final save
+        # output_dir = f"final_step_{completed_steps}"
+        # if self.args.output_dir is not None:
+        #     output_dir = os.path.join(self.args.output_dir, output_dir)
+        # self.accelerate_saving_checkpoint(output_dir, completed_steps)
diff --git a/mftcoder_accelerate/src/run_offline_tokenization.sh b/mftcoder_accelerate/src/run_offline_tokenization.sh
new file mode 100644
index 0000000..ed916da
--- /dev/null
+++ b/mftcoder_accelerate/src/run_offline_tokenization.sh
@@ -0,0 +1,13 @@
+MODEL_PATH=
+DATA_PATH=
+DATASET_NAME=
+OUTPUT_PATH=
+
+python offline_tokenization/concat_sst_bin_tokenization.py \
+--model-path ${MODEL_PATH} \
+--data-path ${DATA_PATH} \
+--dataset-name ${DATASET_NAME} \
+--output-path ${OUTPUT_PATH} \
+--parallel 16 \
+--seq-length 4096 \
+--sample-percent 1.0
diff --git a/mftcoder_accelerate/src/tokenizer/__init__.py b/mftcoder_accelerate/src/tokenizer/__init__.py
new file mode 100644
index 0000000..20e88bb
--- /dev/null
+++ b/mftcoder_accelerate/src/tokenizer/__init__.py
@@ -0,0 +1,3 @@
+from .tokenizer import build_tokenizer
+from .tokenizer import init_tokenizer
+from .chat_template import MFTCoder_template
\ No newline at end of file
diff --git a/mftcoder_accelerate/src/tokenizer/chat_template.py b/mftcoder_accelerate/src/tokenizer/chat_template.py
new file mode 100644
index 0000000..3d2ad03
--- /dev/null
+++ b/mftcoder_accelerate/src/tokenizer/chat_template.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# @author Chaoyu Chen
+# @date 2023/12/25
+
+# store possible chat_template for tokenizers to prepare input string
+# -------------------------------------------------- Import ------------------------------------------------------------
+"""
+Usage:
+tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+messages = [
+    {"role": "system", "content": "Be smart"},
+    {"role": "human", "content": "Hello, how are you?"},
+    {"role": "bot", "content": "I'm doing great. How can I help you today?"},
+    {"role": "human", "content": "I'd like to show off how chat templating works!"},
+]
+prompts = tokenizer.apply_chat_template(messages, chat_template=MFTCoder_template, tokenize=False, add_generation_prompt=True)
+"""
+
+MFTCoder_template = (
+    "{% if messages[0]['role'] == 'system' %}"
+    "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
+    "{% set system_message = messages[0]['content'] %}"
+    "{% else %}"
+    "{% set loop_messages = messages %}"
+    "{% set system_message = false %}"
+    "{% endif %}"
+    "{% for message in loop_messages %}"  # Loop over all non-system messages
+    "{% if (message['role'] == 'user' or message['role'] == 'human') != (loop.index0 % 2 == 0) %}"
+    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+    "{% endif %}"
+    "{% if loop.index0 == 0 and system_message != false %}"  # Embed system message in first message
+    "{% set content = 'system\n' + system_message + '\n' %}"
+    "{% else %}"
+    "{% set content = '' %}"
+    "{% endif %}"
+    "{% if message['role'] == 'user' or message['role'] == 'human' %}"
+    "{{ content + 'human\n' + message['content'] + '\n' }}"
+    "{% elif message['role'] == 'assistant' or message['role'] == 'bot' %}"
+    "{{ 'bot\n' + message['content'] + '\n' +  eos_token + '\n'}}"
+    "{% else %}"
+    "{{ raise_exception('Only user/human and assistant/bot roles are supported!') }}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ 'bot\n' }}"
+    "{% endif %}"
+)
+
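+# Illustrative rendering (assuming the example messages from the usage note above and eos_token "</s>"):
+#   system\nBe smart\nhuman\nHello, how are you?\nbot\nI'm doing great. How can I help you today?\n</s>\n
+#   human\nI'd like to show off how chat templating works!\nbot\n
+# where the trailing "bot\n" comes from add_generation_prompt=True.
+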
+if __name__ == "__main__":
+    pass
diff --git a/mftcoder_accelerate/src/tokenizer/tokenizer.py b/mftcoder_accelerate/src/tokenizer/tokenizer.py
new file mode 100644
index 0000000..bc3ab56
--- /dev/null
+++ b/mftcoder_accelerate/src/tokenizer/tokenizer.py
@@ -0,0 +1,93 @@
+"""
+# @author Chaoyu Chen
+# @date 2023/6/19
+"""
+
+import numpy as np
+from typing import List, Union
+from utils.common_utils import print_rank_0
+from transformers import AutoTokenizer, AutoConfig
+from tokenizer.chat_template import MFTCoder_template
+
+
+def init_tokenizer(path):
+    """
+    Init a Huggingface tokenizer, resolving the eos_token from the tokenizer first, then the model config.
+    Set pad_token to the same value as eos_token for convenience.
+    :param path: model path or tokenizer path
+    :return: Tokenizer (TokenizerFast is preferred)
+    """
+    # tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False, legacy=False)
+    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+    config, unused_kwargs = AutoConfig.from_pretrained(path, trust_remote_code=True, return_unused_kwargs=True)
+
+    if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
+        print(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
+        eos_token_id = tokenizer.eos_token_id
+        eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
+    elif hasattr(tokenizer, "eos_token") and tokenizer.eos_token:
+        print(f"Initial eos_token {tokenizer.eos_token} from tokenizer")
+        eos_token = tokenizer.eos_token
+        eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
+    elif hasattr(config, "eos_token_id") and config.eos_token_id:
+        print(f"Initial eos_token_id {config.eos_token_id} from config.json")
+        eos_token_id = config.eos_token_id
+        eos_token = tokenizer.convert_ids_to_tokens(config.eos_token_id)
+    elif hasattr(config, "eos_token") and config.eos_token:
+        print(f"Initial eos_token {config.eos_token} from config.json")
+        eos_token = config.eos_token
+        eos_token_id = tokenizer.convert_tokens_to_ids(config.eos_token)
+    else:
+        raise ValueError(
+            "No available eos_token or eos_token_id, please provide eos_token by params or eos_token_id by config.json"
+        )
+    try:
+        tokenizer.eos_token = eos_token
+        tokenizer.eos_token_id = eos_token_id
+        # set pad_token to be the same as eos_token; this is ok because it will be masked out.
+        tokenizer.pad_token = eos_token
+        tokenizer.pad_token_id = eos_token_id
+    except Exception:
+        print("[WARNING] Cannot set tokenizer.eos_token")
+
+    tokenizer.add_bos_token = False
+    tokenizer.add_eos_token = False
+    tokenizer.chat_template = MFTCoder_template
+    print_rank_0(f"Tokenizer: {type(tokenizer)}")
+    print_rank_0(f"Length of tokenizer: {len(tokenizer)}")
+    print_rank_0(f"build_tokenizer pad_token_id: {tokenizer.pad_token_id}, eos_token_id: {tokenizer.eos_token_id}")
+    print_rank_0(f"build_tokenizer pad_token : {tokenizer.pad_token}, eos_token: {tokenizer.eos_token}")
+
+    return tokenizer
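+
+
+# Illustrative usage (the model path is hypothetical):
+#   tokenizer = init_tokenizer("/path/to/base_model")
+#   input_ids = tokenizer("def hello():", return_tensors="pt")["input_ids"]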
+
+
+def build_tokenizer(args):
+    """Initialize tokenizer."""
+    print_rank_0(f"> building {args.tokenizer_type} tokenizer ...")
+    # Select and instantiate the tokenizer.
+    if args.tokenizer_type.lower() == "AutoTokenizer".lower():
+        assert args.pretrained_model_path is not None
+        tokenizer = init_tokenizer(args.pretrained_model_path)
+    else:
+        raise NotImplementedError(f"{args.tokenizer_type} tokenizer is not implemented.")
+
+    # Add vocab size.
+    args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
+
+    return tokenizer
+
+
+def _vocab_size_with_padding(orig_vocab_size, args):
+    """Pad vocab size thus it is divisible by model parallel size and
+    still having GPU friendly size."""
+
+    after = orig_vocab_size
+    multiple = args.make_vocab_size_divisible_by * args.model_parallel_size
+    while (after % multiple) != 0:
+        after += 1
+    print_rank_0(
+        " > padded vocab (size: {}) with {} dummy tokens "
+        "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after)
+    )
+
+    return after
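+
+
+# Illustrative example: with make_vocab_size_divisible_by=32 and model_parallel_size=1,
+# an original vocab size of 32001 is padded to 32032 (31 dummy tokens),
+# while 32000 is already divisible and stays unchanged.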
diff --git a/mftcoder_accelerate/src/utils/__init__.py b/mftcoder_accelerate/src/utils/__init__.py
new file mode 100644
index 0000000..0bd6cec
--- /dev/null
+++ b/mftcoder_accelerate/src/utils/__init__.py
@@ -0,0 +1,2 @@
+from .common_utils import *
+from .loss_utils import *
diff --git a/mftcoder_accelerate/src/utils/agd.py b/mftcoder_accelerate/src/utils/agd.py
new file mode 100644
index 0000000..11929e3
--- /dev/null
+++ b/mftcoder_accelerate/src/utils/agd.py
@@ -0,0 +1,138 @@
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+
+Params = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
+
+LossClosure = Callable[[], float]
+OptLossClosure = Optional[LossClosure]
+Betas2 = Tuple[float, float]
+State = Dict[str, Any]
+OptFloat = Optional[float]
+Nus2 = Tuple[float, float]
+
+__all__ = ("AGD",)
+
+
+class AGD(torch.optim.Optimizer):
+    r"""AGD: an Auto-switchable Optimizer using Stepwise Gradient Difference as Preconditioning Matrix.
+    Arguments:
+        params (Params): Collection of parameters to be optimized, or an iterable of dictionaries specifying separate groups.
+        lr (float, optional): The learning rate. Default is 1e-3.
+        betas (tuple of 2 floats, optional): Coefficients used for computing running averages of gradient and its square. Default is (0.9, 0.999).
+        delta (float, optional): Small constant for numerical stability to prevent division by zero. Default is 1e-5.
+        weight_decay (float, optional): Weight decay coefficient. Default is 0.0.
+        amsgrad (bool, optional): If set to True, applies the AMSGrad variant of the optimizer. Default is False.
+        win (bool, optional): If set to True, applies the Win variant of the optimizer. Default is False.
+        clip (float, optional): Clip the update to [-clip, clip] to prevent abnormal updates. Default is None.
+    """
+
+    def __init__(
+        self,
+        params: Params,
+        lr: float = 1e-3,
+        betas: Betas2 = (0.9, 0.999),
+        delta: float = 1e-5,
+        weight_decay: float = 0.0,
+        amsgrad: bool = False,
+        win: bool = False,
+        clip: float = None,
+    ) -> None:
+        if lr <= 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if delta < 0.0:
+            raise ValueError("Invalid delta value: {}".format(delta))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if weight_decay < 0.0:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            delta=delta,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+            win=win,
+            clip=clip,
+        )
+        super(AGD, self).__init__(params, defaults)
+
+    def step(self, closure: OptLossClosure = None) -> OptFloat:
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            beta1, beta2 = group["betas"]
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    msg = "AGD does not support sparse gradients."
+                    raise RuntimeError(msg)
+
+                state = self.state[p]
+                # Lazy state initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if group["amsgrad"]:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if group["win"]:
+                        state["z"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                exp_avg, exp_avg_sq = (
+                    state["exp_avg"],
+                    state["exp_avg_sq"],
+                )
+
+                state["step"] += 1
+                exp_avg_old = exp_avg.detach().clone()
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                bias_correction1_old = 1 - beta1 ** (state["step"] - 1)
+                bias_correction1, bias_correction2 = (
+                    1 - beta1 ** state["step"],
+                    1 - beta2 ** state["step"],
+                )
+                update = (
+                    exp_avg * (1 / bias_correction1)
+                    if state["step"] == 1
+                    else exp_avg * (1 / bias_correction1) - exp_avg_old * (1 / bias_correction1_old)
+                )
+                exp_avg_sq.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+                if group["amsgrad"]:
+                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    update = max_exp_avg_sq.sqrt()
+                else:
+                    update = exp_avg_sq.sqrt()
+
+                delta_adjust = group["delta"] * np.sqrt(bias_correction2)
+                update.clamp_(min=delta_adjust)
+
+                lr_adjust = group["lr"] * np.sqrt(bias_correction2) / bias_correction1
+                update = exp_avg / update
+                if group["clip"] is not None:
+                    update.clamp_(min=-group["clip"], max=group["clip"])
+                weight_decay = group["weight_decay"]
+                if not group["win"]:
+                    p.data.mul_(1 - group["lr"] * weight_decay).add_(update, alpha=-lr_adjust)
+                else:
+                    z = state["z"]
+                    z.data.add_(update, alpha=-lr_adjust).mul_(1.0 / (1.0 + weight_decay * lr_adjust))
+                    lr_adjust2 = 2 * lr_adjust
+                    tao = 1.0 / (3.0 + lr_adjust2 * weight_decay)
+                    p.data.mul_(tao).add_(update, alpha=-tao * lr_adjust2).add_(z, alpha=2 * tao)
+        return loss
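+
+
+# Illustrative usage (hyperparameters are examples only; AGD is a drop-in replacement for torch.optim.AdamW):
+#   optimizer = AGD(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)
+#   loss.backward()
+#   optimizer.step()
+#   optimizer.zero_grad()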
diff --git a/mft_peft_hf/src/utils/common_utils.py b/mftcoder_accelerate/src/utils/common_utils.py
similarity index 73%
rename from mft_peft_hf/src/utils/common_utils.py
rename to mftcoder_accelerate/src/utils/common_utils.py
index 48d75e1..7b6ea30 100644
--- a/mft_peft_hf/src/utils/common_utils.py
+++ b/mftcoder_accelerate/src/utils/common_utils.py
@@ -1,10 +1,29 @@
 import os
 import math
 import torch
+from packaging import version
+import importlib.metadata
 
 TASK2ID = {}
 ID2TASK = {}
 
+
+def is_flash_attn_2_available():
+
+    # Let's add an extra check to see if cuda is available
+
+    if not torch.cuda.is_available():
+        return False
+
+    if torch.version.cuda:
+        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
+    elif torch.version.hip:
+        # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
+        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
+    else:
+        return False
+
+
 def print_rank_0(*message):
     """If distributed is initialized print only on rank 0."""
     if torch.distributed.is_initialized():
@@ -92,30 +111,24 @@ def get_tflops_new(args, batch_size, seq_len, step_time):
     L = args.num_hidden_layers
     h = args.hidden_size
     V = args.vocab_size
-    flops = (96 * batch_size * sl * L * h * h * (1 + sl / (6 * h) + V / (16 * L * h)) / step_time)
+    flops = 96 * batch_size * sl * L * h * h * (1 + sl / (6 * h) + V / (16 * L * h)) / step_time
     return human_readable_flops(flops)
 
 
-def get_tflops_megatron(total_model_param, hidden_size, num_hidden_layers, 
-                        batch_size_per_device, seq_len, step_time):
+def get_tflops_megatron(total_model_param, hidden_size, num_hidden_layers, batch_size_per_device, seq_len, step_time):
 
     ff = total_model_param * 6
     attn = seq_len * hidden_size * num_hidden_layers * 60
-    flops = (
-        batch_size_per_device
-        * seq_len
-        * (ff + attn)
-        / step_time
-    )
+    flops = batch_size_per_device * seq_len * (ff + attn) / step_time
     return human_readable_flops(flops)
 
 
 def generate_task_id(data_paths):
-    data_prefixes = list(data_paths[1:-1].split(','))
+    data_prefixes = list(data_paths[1:-1].split(","))
     print("data paths: ")
     print(data_prefixes)
 
     for i, prefix in enumerate(data_prefixes):
-        task_name = prefix.split('/')[-1]
+        task_name = prefix.split("/")[-1]
         TASK2ID[task_name] = i
         ID2TASK[i] = task_name
diff --git a/mftcoder_accelerate/src/utils/loss_utils.py b/mftcoder_accelerate/src/utils/loss_utils.py
new file mode 100644
index 0000000..5ca7c73
--- /dev/null
+++ b/mftcoder_accelerate/src/utils/loss_utils.py
@@ -0,0 +1,365 @@
+import sys
+import torch
+from utils.common_utils import print_rank_0, TASK2ID, ID2TASK
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+from dataclasses import dataclass
+import numpy as np
+from typing import List, Optional, Tuple, Union
+
+
+def get_task_mask(task_id):
+    task_num = len(TASK2ID)
+    task_mask = torch.zeros(task_id.shape[0], task_num)
+    task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1
+
+    return task_mask
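+
+
+# Illustrative example: with 3 registered tasks and task_id = [[2], [0]],
+# get_task_mask returns [[0., 0., 1.], [1., 0., 0.]].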
+
+
+def get_task_loss(task_losses, task_id):  # TODO
+    # fix task order
+    task_loss_per_batch = torch.zeros(len(ID2TASK)).to(device=task_id.device)
+    # count task samples
+    task_num_per_batch = torch.zeros(len(ID2TASK)).to(device=task_id.device)
+    for i in range(len(task_id)):
+        task_num_per_batch[task_id[i][0]] += 1
+        task_loss_per_batch[task_id[i][0]] = task_losses[task_id[i][0]]
+
+    return task_loss_per_batch, task_num_per_batch
+
+
+def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_mask=None, task_weights=None):
+    """
+    loss function for MFT loss
+    :param outputs:
+    :param labels:
+    :param task_mask:
+    :param task_id:
+    :param weighted_loss_mode:
+    :param loss_mask:
+    :param task_weights:
+    :return:
+    """
+    # task_id shape: [[1], [2], [4], [3], ..., [1]]
+    weighted = weighted_loss_mode
+    lm_logits = outputs["logits"]
+    labels = labels.to(device=lm_logits.device)
+    task_mask = task_mask.to(device=lm_logits.device)
+    task_id = task_id.to(device=lm_logits.device)
+    shift_logits = lm_logits.contiguous()
+    labels = labels.contiguous()
+    if task_weights is None:
+        task_weights = torch.ones(len(ID2TASK)).to(device=lm_logits.device) / len(ID2TASK)
+
+    bsz, seq_len = labels.shape
+    # loss_mask = None
+    if loss_mask is None:
+        ineffective_tokens_per_sample = (labels == -100).sum(dim=1)
+        effective_tokens_per_sample = -(ineffective_tokens_per_sample - seq_len)
+        effective_tokens = bsz * seq_len - ineffective_tokens_per_sample.sum()
+        loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)
+    else:
+        loss_mask = loss_mask.to(device=lm_logits.device)
+        loss_fct = CrossEntropyLoss(reduction="none")
+    losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))  # [B * L, 1]
+    losses = losses.contiguous().view(bsz, -1)
+    token_losses = (
+        losses.clone().detach().float() if loss_mask is None else losses.clone().detach().float() * loss_mask
+    )  # [B, L]
+    task_mask_trans = torch.transpose(task_mask, 0, 1)
+    unique_id = torch.unique(task_id)
+    if weighted_loss_mode == "case3" or weighted_loss_mode == "case4" or weighted_loss_mode == "coba":
+        loss = 0.0
+        weights_sum = 0.0
+        for i, w in enumerate(unique_id):
+            row_idx = torch.squeeze(task_id) == w.item()
+            task_weight = float(task_weights[w.item()])
+            weights_sum += task_weight
+            if weighted_loss_mode == "case3" or weighted_loss_mode == "coba":
+                if loss_mask is None:
+                    loss += (
+                        torch.sum(losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) * task_weight
+                    )
+                else:
+                    loss += torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :]) * task_weight
+            elif weighted_loss_mode == "case4":
+                if loss_mask is None:
+                    loss += (
+                        torch.mean(torch.sum(losses, dim=1)[row_idx] / effective_tokens_per_sample[row_idx])
+                        * task_weight
+                    )
+                else:
+                    loss += (
+                        torch.mean(torch.sum(losses * loss_mask, dim=1)[row_idx] / torch.sum(loss_mask, dim=1)[row_idx])
+                        * task_weight
+                    )
+
+        # loss /= len(unique_id)
+        loss /= weights_sum
+
+    elif weighted_loss_mode == "case2":
+        if loss_mask is None:
+            loss = torch.mean(torch.sum(losses, dim=1) / effective_tokens_per_sample)
+        else:
+            loss = torch.mean(torch.sum(losses * loss_mask, dim=1) / torch.sum(loss_mask, dim=1))
+    elif weighted_loss_mode == "case1":
+        # flatten losses & loss_mask tensor
+        if loss_mask is None:
+            # losses = losses.view(-1)
+            loss = torch.sum(losses.view(-1)) / effective_tokens
+        else:
+            # loss_mask = loss_mask.view(-1)
+            # losses = losses.view(-1)
+            loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.view(-1).sum()
+
+    # fix task order
+    task_loss = torch.zeros(len(ID2TASK)).to(device=task_id.device)
+    task_num = torch.zeros(len(ID2TASK)).to(device=task_id.device)
+    for i, w in enumerate(unique_id):
+        row_idx = torch.squeeze(task_id) == w.item()
+        if loss_mask is None:
+            task_loss[w] = torch.sum(token_losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx])
+            task_num[w] = len(effective_tokens_per_sample[row_idx])
+        else:
+            task_loss[w] = torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :])
+
+    return loss, task_loss, task_num
+
+
+def load_balancing_loss_func(
+    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
+) -> float:
+    r"""
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
+            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        attention_mask (`torch.Tensor`, None):
+            The attention_mask used in forward function
+            shape [batch_size X sequence_length] if not None.
+        num_experts (`int`, *optional*):
+            Number of experts
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits is None or not isinstance(gate_logits, tuple):
+        return 0
+
+    if isinstance(gate_logits, tuple):
+        compute_device = gate_logits[0].device
+        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that zeroes out all padding tokens, with the same shape as expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that zeroes out all padding tokens, with the same shape as tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+            router_per_expert_attention_mask, dim=0
+        )
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
+
+class MFTLossStatus:
+    def __init__(self):
+        super(MFTLossStatus, self).__init__()
+
+
+class CoBaStatus(MFTLossStatus):
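+    """
+    Track per-task validation losses and compute CoBa (Convergence Balancer) task weights.
+    The weights are derived from slopes fitted to per-task validation-loss histories and combine
+    a Relative Convergence Score, an Absolute Convergence Score and a Divergence Factor,
+    following the CoBa method (arXiv:2410.06741).
+    """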
+    def __init__(
+        self,
+        coba_warmup_steps=100,
+        coba_history_length=200,
+        coba_tau=5,
+        coba_update_interval=1,
+        coba_sample_valid_num=1,
+        valid_dataloader=None,
+    ):
+
+        super(CoBaStatus, self).__init__()
+        self.coba_warmup_steps = coba_warmup_steps
+        self.coba_history_length = coba_history_length
+        self.coba_tau = coba_tau
+        self.coba_update_interval = coba_update_interval
+        self.coba_sample_valid_num = coba_sample_valid_num
+        self.valid_dataloader = valid_dataloader
+        self.valid_dataloader_length = len(valid_dataloader)
+        self.valid_iterator = iter(valid_dataloader)
+        self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK))
+        self.history_task_valid_loss = None
+        self.per_task_slope_list = None
+        self.total_slope_list = None
+        self.minimum_weight = 1 / (len(ID2TASK) * 10)
+        self.valid_task_loss_begining = torch.ones(len(ID2TASK), dtype=torch.float64)
+        self.log_per_task_weight = torch.zeros(len(ID2TASK))
+
+    def coba_evaluate(self, model, v_batch, per_task_weight=None, coba_status=None):
+        model.eval()
+        with torch.no_grad():
+            valid_outputs = model(
+                input_ids=v_batch["input_ids"],
+                attention_mask=v_batch["attention_mask"],
+                position_ids=v_batch["position_ids"],
+            )
+
+            _, valid_task_loss, valid_task_num = loss_func_mft(
+                outputs=valid_outputs,
+                labels=v_batch["labels"],
+                task_mask=v_batch["task_mask"],
+                task_id=v_batch["task_id"],
+                weighted_loss_mode="coba",
+                loss_mask=v_batch["loss_mask"],
+                task_weights=None,
+            )
+
+            task_exist = (valid_task_loss != 0.0).float()
+            torch.distributed.all_reduce(valid_task_loss, op=torch.distributed.ReduceOp.SUM)
+            torch.distributed.all_reduce(task_exist, op=torch.distributed.ReduceOp.SUM)
+            valid_task_loss /= task_exist.clamp_(1.0)
+            valid_task_loss /= self.valid_task_loss_begining
+        model.train()
+        return valid_task_loss
+
+    def compute_per_task_weight(self, completed_steps=None):
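+        # CoBa weighting, step by step:
+        #   1) fit a slope to each task's recent validation-loss history over a sliding window;
+        #   2) RCS: softmax over the normalized per-task slopes of the current step;
+        #   3) ACS: softmax over the reverse-normalized per-task slope history at the current step;
+        #   4) blend RCS and ACS with a Divergence Factor computed from the total-loss slope history;
+        #   5) softmax the blended logits and rescale so no task weight falls below self.minimum_weight.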
+        task_num = len(ID2TASK)
+        task_slope_fitting = torch.ones(task_num, dtype=torch.float64)
+        start_step = max(0, completed_steps // self.coba_update_interval - self.coba_history_length)
+        history_steps = torch.arange(start_step, completed_steps, 1)
+        for i in range(task_num):
+            per_task_history_valid_loss = self.history_task_valid_loss[i][-len(history_steps):]
+            task_slope_fitting[i] = self.fit_window_slope(
+                history_steps, per_task_history_valid_loss, type="slope"
+            )
+        history_total_valid_loss, index = torch.max(self.history_task_valid_loss[:, -len(history_steps):], dim=0)
+        total_slope_fitting = self.fit_window_slope(
+            history_steps, history_total_valid_loss, type="slope"
+        )
+        if completed_steps == self.coba_warmup_steps:
+            self.per_task_slope_list = task_slope_fitting.unsqueeze(1)
+            self.total_slope_list = total_slope_fitting.unsqueeze(0)
+        else:
+            self.per_task_slope_list = torch.cat((self.per_task_slope_list, task_slope_fitting.unsqueeze(1)), dim=-1)
+            self.total_slope_list = torch.cat((self.total_slope_list, total_slope_fitting.unsqueeze(0)), dim=0)
+        
+        # Relative Convergence Score
+        normalize_task_slope = task_num * task_slope_fitting / task_slope_fitting.abs().sum()
+        rcs = F.softmax(normalize_task_slope, dim=-1)
+        
+        # Absolute Convergence Score
+        history_per_task_slope_list = self.per_task_slope_list[:, start_step:]
+        reverse_normalize_iter_slope = -len(history_per_task_slope_list[0]) * history_per_task_slope_list \
+                                       / history_per_task_slope_list.abs().sum(dim=-1, keepdim=True)
+
+        flatten_rn_iter_slope = reverse_normalize_iter_slope.T.reshape(-1)
+        current_step_rn_slope = flatten_rn_iter_slope[-task_num:]
+        acs = F.softmax(current_step_rn_slope, dim=-1)
+
+        # Divergence Factor
+        normalize_total_iter_slope = - len(self.total_slope_list) * self.total_slope_list \
+                                     / self.total_slope_list.abs().sum()
+        divergence_factor = F.softmax(normalize_total_iter_slope * self.coba_tau, dim=-1)[-1] \
+                          * len(self.total_slope_list)
+
+        weight_logits = divergence_factor * rcs + (1 - divergence_factor) * acs
+        per_task_weight = F.softmax(weight_logits * task_num, dim=-1)
+
+        if len((per_task_weight < self.minimum_weight).nonzero().squeeze(0)) > 0:
+            per_task_weight = per_task_weight * (1 - self.minimum_weight * task_num)
+            per_task_weight += self.minimum_weight
+
+        return per_task_weight
+    
+    def fit_window_slope(self, x, y, type="slope"):
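+        # Weighted least-squares fit of y against x over the non-zero history points:
+        # solve (X^T W X) w = X^T W y and return the leading coefficient (the slope).
+        # Note: the weights ws below are uniform, since 1 ** arange(n) is all ones.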
+
+        y = y[y != 0]
+        x = x[:len(y)]
+        
+        nonzero_index = torch.squeeze(torch.nonzero(y), dim=1)
+        y = torch.index_select(y, 0, nonzero_index)
+        x = torch.index_select(x, 0, nonzero_index)
+
+        ws = torch.flip(1 ** torch.arange(len(y)), dims=[0])
+        ws = ws.double()
+
+        if len(y) >= 2:
+            if type == "slope":
+                X = torch.stack((x, torch.ones_like(x, dtype=torch.float64))).T
+                X = X.double()
+            else:
+                X = torch.stack((x ** 2, x, torch.ones_like(x, dtype=torch.float64))).T
+
+            # implementation for numpy
+            # X_np = X.T @ (ws[:, None] * X)
+            # Y_np = X.T @ (ws * y)
+            # w = torch.from_numpy(np.linalg.solve(X_np.numpy(), Y_np.numpy()))
+
+            # implementation for torch
+            w = torch.linalg.solve(X.T @ (ws[:, None] * X), X.T @ (ws * y))
+
+            result = w[0]
+        else:
+            result = 0.0
+
+        return result
+
+    def sample_valid_batch(self, model, completed_steps):
+        self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK), dtype=torch.float64)
+        for i in range(self.coba_sample_valid_num):
+            if (
+                self.coba_sample_valid_num * completed_steps // self.coba_update_interval + i
+            ) % self.valid_dataloader_length == 0:
+                self.valid_iterator = iter(self.valid_dataloader)
+                v_batch = next(self.valid_iterator)
+            else:
+                v_batch = next(self.valid_iterator)
+            valid_task_loss = self.coba_evaluate(model, v_batch)
+            self.valid_task_loss_accumulated += valid_task_loss.detach().cpu().double()
+
+        self.valid_task_loss_accumulated /= self.coba_sample_valid_num
+        if self.history_task_valid_loss is None and completed_steps >= 1:
+            self.history_task_valid_loss = self.valid_task_loss_accumulated.unsqueeze(1)
+        elif self.history_task_valid_loss is not None:
+            self.history_task_valid_loss = torch.cat(
+                (self.history_task_valid_loss, self.valid_task_loss_accumulated.unsqueeze(1)), dim=-1
+            )
diff --git a/mftcoder_accelerate/src/utils/model_mapping.py b/mftcoder_accelerate/src/utils/model_mapping.py
new file mode 100644
index 0000000..8592e86
--- /dev/null
+++ b/mftcoder_accelerate/src/utils/model_mapping.py
@@ -0,0 +1,67 @@
+"""
+ @author qumu
+ transformers==4.40 is stable now
+"""
+
+# Models whose implementations and FlashAttention-2 support are built into Transformers (requires flash_attn>=2.1.0)
+from transformers import (
+    GPTNeoXForCausalLM,
+    GPTBigCodeForCausalLM,
+    LlamaForCausalLM,
+    MistralForCausalLM,
+    MixtralForCausalLM,
+    PhiForCausalLM,
+    GemmaForCausalLM,
+    Qwen2ForCausalLM,
+    Qwen2MoeForCausalLM,
+    Starcoder2ForCausalLM,
+)
+
+# Models in the local model dir that also support Transformers FA2
+from model.deepseek_v2.modeling_deepseek import DeepseekV2ForCausalLM
+
+# Models in the local model dir with self-contained implementations
+from model.aquila2.modeling_aquila import AquilaForCausalLM
+from model.baichuan2.modeling_baichuan import BaichuanForCausalLM
+from model.qwen.modeling_qwen import QWenLMHeadModel
+from model.chatglm2.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration2
+from model.chatglm3.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration3
+
+# from model.phi.modeling_mixformer_sequential import MixFormerSequentialForCausalLM
+
+MODEL_TYPES = {
+    "aquila2": AquilaForCausalLM,
+    "baichuan": BaichuanForCausalLM,
+    "chatglm2": ChatGLMForConditionalGeneration2,
+    "chatglm3": ChatGLMForConditionalGeneration3,
+    "code_llama": LlamaForCausalLM,
+    "deepseek": LlamaForCausalLM,
+    "gpt_neox": GPTNeoXForCausalLM,
+    "llama": LlamaForCausalLM,
+    "mistral": MistralForCausalLM,
+    "mixtral": MixtralForCausalLM,
+    "phi": PhiForCausalLM,
+    "qwen": QWenLMHeadModel,
+    "starcoder": GPTBigCodeForCausalLM,
+    "qwen2": Qwen2ForCausalLM,
+    "gemma": GemmaForCausalLM,
+    "qwen2_moe": Qwen2MoeForCausalLM,
+    "starcoder2": Starcoder2ForCausalLM,
+    "deepseek_v2": DeepseekV2ForCausalLM,
+}
+
+SUPPORT_IN_TRANSFORMERS = [
+    "code_llama",
+    "llama",
+    "deepseek",
+    "mistral",
+    "mixtral",
+    "gpt_neox",
+    "phi",
+    "starcoder",
+    "qwen2",
+    "qwen2_moe",
+    "gemma",
+    "starcoder2",
+    "deepseek_v2",
+]
diff --git a/mftcoder_accelerate/src/xxpo/custom_callbacks.py b/mftcoder_accelerate/src/xxpo/custom_callbacks.py
new file mode 100644
index 0000000..f38fa70
--- /dev/null
+++ b/mftcoder_accelerate/src/xxpo/custom_callbacks.py
@@ -0,0 +1,99 @@
+"""
+Customized Callbacks to use with the Trainer class and customize the training loop.
+"""
+
+import copy
+import dataclasses
+import json
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+from tqdm.auto import tqdm
+
+from transformers.trainer_utils import IntervalStrategy, has_length
+from transformers.training_args import TrainingArguments
+from transformers.utils import logging
+from transformers import TrainerCallback
+
+logger = logging.get_logger(__name__)
+
+
+class CustomProgressCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that displays the progress of training or evaluation.
+    """
+
+    def __init__(self):
+        self.training_bar = None
+        self.prediction_bar = None
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)
+        self.current_step = 0
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if state.is_world_process_zero and state.global_step % args.logging_steps == 0:
+            self.training_bar.update(args.logging_steps)
+            self.current_step = state.global_step
+        # pass
+
+    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
+        # if state.is_world_process_zero and has_length(eval_dataloader):
+        #     if self.prediction_bar is None:
+        #         self.prediction_bar = tqdm(
+        #             total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True
+        #         )
+        #     self.prediction_bar.update(1)
+        pass
+
+    def on_evaluate(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            if self.prediction_bar is not None:
+                self.prediction_bar.close()
+            self.prediction_bar = None
+
+    def on_predict(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            if self.prediction_bar is not None:
+                self.prediction_bar.close()
+            self.prediction_bar = None
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if state.is_world_process_zero and self.training_bar is not None:
+            # avoid modifying the logs object as it is shared between callbacks
+            logs = copy.deepcopy(logs)
+            # _ = logs.pop("total_flos", None)
+            # round numbers so that it looks better in console
+            if "epoch" in logs:
+                logs["epoch"] = round(logs["epoch"], 2)
+            # self.training_bar.write(str(logs))
+            logger.info(logs)
+
+    def on_train_end(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            self.training_bar.close()
+            self.training_bar = None
+
+
+class PrinterCallback(TrainerCallback):
+    """
+    A bare [`TrainerCallback`] that just prints the logs.
+    """
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        _ = logs.pop("total_flos", None)
+        if state.is_local_process_zero:
+            print(logs)
+
+
+class LogCallback(TrainerCallback):
+    """
+    A bare [`TrainerCallback`] that just forwards the logs to the logger.
+    """
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        _ = logs.pop("total_flos", None)
+        if state.is_local_process_zero:
+            logger.info(logs)
\ No newline at end of file
diff --git a/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py b/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py
new file mode 100644
index 0000000..4c93520
--- /dev/null
+++ b/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py
@@ -0,0 +1,484 @@
+"""
+# @author qumu
+# @date 2023/12/11
+# @module xxpo_accelerate.py
+
+Accelerate + DeepSpeed/FSDP + QLoRA/LoRA/Full + DPO/RPO/ORPO
+
+Entry
+"""
+
+import os
+import sys
+import argparse
+import math
+import logging
+import json
+import time
+from datetime import timedelta
+from tqdm.auto import tqdm
+from dataclasses import dataclass
+from typing import Dict, Optional, Union, List
+
+import datasets
+from datasets import Dataset, load_dataset, concatenate_datasets
+
+import torch
+from torch.utils.data import DataLoader
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
+
+import transformers
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    get_linear_schedule_with_warmup,
+    set_seed,
+    BitsAndBytesConfig,
+    get_scheduler,
+)
+from peft import (
+    LoraConfig,
+    TaskType,
+    get_peft_model,
+    prepare_model_for_kbit_training,
+    PeftModel,
+)
+from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration
+from accelerate.logging import get_logger
+from accelerate.utils import InitProcessGroupKwargs
+
+# insert src as import path
+current_path = os.path.abspath(__file__)
+parent_dir = os.path.dirname(os.path.dirname(current_path))
+sys.path.insert(0, parent_dir)
+
+from tokenizer import build_tokenizer
+
+from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK
+from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS
+
+logger = get_logger(__name__)
+
+
+from trl import (
+    DPOConfig,
+    DPOTrainer,
+    ORPOConfig,
+    ORPOTrainer,
+    ModelConfig,
+    get_kbit_device_map,
+    get_peft_config,
+    get_quantization_config,
+)
+from transformers.trainer_callback import (
+    CallbackHandler,
+    DefaultFlowCallback,
+    PrinterCallback,
+    ProgressCallback,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+)
+
+from xxpo.xxpo_arguments import XXPOTrainArgs
+from xxpo.custom_callbacks import CustomProgressCallback
+from xxpo.custom_callbacks import LogCallback
+
+
+def pprint_args(args, accelerator):
+    # compute the maximum string length among all keys for aligned printing
+    max_key_length = max(len(str(key)) for key in vars(args).keys())
+
+    message = ""
+    message += "====" * 60 + "\n"
+    message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n"
+    message += "====" * 60 + "\n"
+    accelerator.print(message)
+    accelerator.print("GPU: {}".format(torch.cuda.current_device()))
+
+
+def prepare_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train_config", type=str, default=None)
+
+    parser.add_argument("--data_paths", type=str, default=None)
+    parser.add_argument("--output_dir", type=str, default=None)
+    parser.add_argument("--tb_dir", type=str, default=None)
+    parser.add_argument("--pretrained_model_path", type=str, default=None)
+    parser.add_argument("--micro_batch_size", type=int, default=None)
+    parser.add_argument("--model_type", type=str, default=None)
+    parser.add_argument("--distributed_type", type=str, default="deepspeed")
+
+    parsed = parser.parse_args()
+    # get json configs
+    with open(parsed.train_config, "r") as f:
+        train_config = json.load(f)
+
+    # parse args from the config json
+    args = XXPOTrainArgs(**train_config)
+
+    # override args by cli arguments
+    if parsed.data_paths:
+        args.data_paths = parsed.data_paths
+    if parsed.output_dir:
+        args.output_dir = parsed.output_dir
+    if parsed.tb_dir:
+        args.tb_dir = parsed.tb_dir
+    if parsed.pretrained_model_path:
+        args.pretrained_model_path = parsed.pretrained_model_path
+        args.vocab_file = parsed.pretrained_model_path
+    if parsed.micro_batch_size:
+        args.per_device_train_batch_size = parsed.micro_batch_size
+        args.per_device_eval_batch_size = parsed.micro_batch_size
+    if parsed.model_type:
+        args.model_type = parsed.model_type
+
+    args.distributed_type = parsed.distributed_type
+
+    # refactor args
+
+    if args.peft_type == "qlora":
+        print_rank_0(f"[INFO] args.peft_type is set 'qlora', setting quantization to '4bit'")
+        args.quantization = "4bit"
+    else:
+        args.quantization = None
+
+    args.vocab_file = args.pretrained_model_path
+
+    return args
+
+
+def get_model(args, accelerator):
+    ModelClass = MODEL_TYPES[args.model_type]
+    if args.model_type in SUPPORT_IN_TRANSFORMERS:
+        accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers")
+        model = ModelClass.from_pretrained(
+            args.pretrained_model_path,
+            attn_implementation=args.attn_implementation,
+            torch_dtype=torch.bfloat16,
+            # device_map=get_kbit_device_map() if args.quantization == "4bit" else None,
+            quantization_config=(
+                BitsAndBytesConfig(
+                    load_in_4bit=(args.quantization == "4bit"),
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_quant_storage=torch.bfloat16,
+                )
+                if args.quantization == "4bit"
+                else None
+            ),
+        )
+    else:
+        accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code")
+        model = ModelClass.from_pretrained(
+            args.pretrained_model_path,
+            torch_dtype=torch.bfloat16,
+            quantization_config=(
+                BitsAndBytesConfig(
+                    load_in_4bit=(args.quantization == "4bit"),
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_quant_storage=torch.bfloat16,
+                )
+                if args.quantization == "4bit"
+                else None
+            ),
+        )
+
+    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
+
+    return model
+
+
+def chatml_to_dpo_format(
+    data_file: str,
+    tokenizer,
+    sanity_check: bool = False,
+    cache_dir: Optional[str] = None,
+    num_proc=16,
+) -> Dataset:
+    """Load the standard-paired dataset from Hugging Face and convert it to the necessary format.
+
+    The dataset is converted to a dictionary with the following structure:
+    {
+        'chosen': List[dict], chatml
+        'rejected': List[dict], chatml
+    }
+    """
+
+    dataset = load_dataset(
+        "json",
+        split="train",
+        data_files=data_file,
+        cache_dir=cache_dir,
+        verification_mode="no_checks",
+    )
+    original_columns = dataset.column_names
+
+    if sanity_check:
+        dataset = dataset.select(range(min(len(dataset), 100)))
+
+    def process(samples):
+        samples["prompt"] = [
+            tokenizer.apply_chat_template(chosen[:-1], tokenize=False, add_generation_prompt=True)
+            for chosen in samples["chosen"]
+        ]
+        samples["chosen"] = [chosen[-1]["content"] + tokenizer.eos_token for chosen in samples["chosen"]]
+        samples["rejected"] = [rejected[-1]["content"] + tokenizer.eos_token for rejected in samples["rejected"]]
+        return samples
+
+    return dataset.map(
+        process,
+        batched=True,
+        num_proc=num_proc,
+        # remove_columns=original_columns,
+    )
+
+
+def main():
+    t0 = time.time()
+    # os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    os.environ["HF_HUB_OFFLINE"] = "false"
+    # get input args, set TASK2ID, ID2TASK, refactor args
+    args = prepare_args()
+
+    # fix randomness
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # define accelerator
+    init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds))
+
+    if args.distributed_type and args.distributed_type.lower() == "fsdp":
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+            # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
+            limit_all_gathers=True,
+            sync_module_states=True,
+            use_orig_params=True,
+            cpu_offload=False,
+        )
+        accelerator = Accelerator(
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            fsdp_plugin=fsdp_plugin,
+            dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True),
+            kwargs_handlers=[init_process_kwargs],
+        )
+    else:
+        accelerator = Accelerator(
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True),
+            kwargs_handlers=[init_process_kwargs],
+        )
+
+    # print key infos
+    accelerator.print("In dpo_accelerate.py, sys path:", sys.path)
+    accelerator.print(f"transformers.__version__: {transformers.__version__}")
+
+    # get world_size
+    args.world_size = accelerator.num_processes
+
+    # backup args
+    pprint_args(args, accelerator)
+    if accelerator.is_main_process:
+        if not os.path.exists(args.output_dir):
+            os.makedirs(args.output_dir)
+        with open(os.path.join(args.output_dir, "args.json"), "w") as f:
+            json.dump(args.dict(), f, indent=2)
+
+    # deal with auto-resume: args.resume_from_checkpoint takes priority over auto_resume from the latest ckpt
+
+    # logger
+    logging.basicConfig(
+        format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # get global_rank and local rank for current process
+    global_rank = accelerator.process_index
+    local_rank = accelerator.local_process_index
+    print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}")
+
+    # 1. dataset
+
+    # build tokenizer
+    tokenizer = build_tokenizer(args)
+    # tokenizer.chat_template = MFTCoder_template
+
+    # Load the dpo dataset
+    all_datasets = []
+    # print(args.data_paths, type(args.data_paths))
+    if isinstance(args.data_paths, str):
+        args.data_paths = list(args.data_paths[1:-1].split(","))
+        # print(f"DATA_PATHS: {args.data_paths}")
+    for data_file in args.data_paths:
+        ds = chatml_to_dpo_format(data_file=data_file, tokenizer=tokenizer, sanity_check=args.sanity_check)
+        all_datasets.append(ds)
+
+    all_dataset = concatenate_datasets(all_datasets)
+    # all_dataset = all_dataset.filter(
+    #     lambda x: len(x["prompt"]) + len(x["chosen"]) <= args.max_length
+    #     and len(x["prompt"]) + len(x["rejected"]) <= args.max_length
+    # )
+    accelerator.print(f"Length of all_dataset: {len(all_dataset)}")
+
+    # split train/eval dataset
+    splits = [float(s) for s in args.data_split.split(",")][:2]
+    print(f"data splits: {splits}")
+
+    all_dataset = all_dataset.train_test_split(test_size=splits[1] / sum(splits), shuffle=True, seed=args.seed)
+    all_dataset = all_dataset.flatten_indices()
+
+    train_dataset, eval_dataset = all_dataset["train"], all_dataset["test"]
+    accelerator.print(f"Length of train_dataset: {len(train_dataset)}\nLength of eval_dataset: {len(eval_dataset)}")
+    print(eval_dataset[0])
+    t1 = time.time()
+    logger.info(f"dataset loading time: {t1 - t0:.4f}")
+
+    # cuda memory
+    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
+    max_memory = f"{free_in_GB - 2}GB"
+    n_gpus = torch.cuda.device_count()
+    max_memory = {i: max_memory for i in range(n_gpus)}
+    accelerator.print("max memory: ", max_memory, n_gpus)
+
+    # target_modules, default all-linear for all linear layers
+    if args.target_modules:
+        target_modules = args.target_modules
+    else:
+        target_modules = "all-linear"
+
+    # peft config
+    if args.peft_type:
+        peft_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            inference_mode=False,
+            r=args.lora_rank,
+            lora_alpha=args.lora_alpha,
+            lora_dropout=args.lora_dropout,
+            target_modules=target_modules,
+            bias="lora_only",
+        )
+    else:
+        peft_config = None
+
+    # creating base model
+    model = get_model(args, accelerator)
+    if args.ignore_bias_buffers:
+        # torch distributed hack
+        model._ddp_params_and_buffers_to_ignore = [
+            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
+        ]
+    accelerator.print("Model load_in_4bit: ", args.quantization == "4bit")
+
+    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
+    if hasattr(model.config, "use_logn_attn"):
+        model.config.use_logn_attn = False  # special for qwen model
+    # load balance for moe training
+    if hasattr(model.config, "output_router_logits"):
+        model.config.output_router_logits = True
+    model_config = model.config
+    accelerator.print(model.config)
+
+    t2 = time.time()
+    if accelerator.is_main_process:
+        logging.info(f"model loading time: {t2 - t1:.4f}")
+
+    # 4. initialize training arguments:
+    if args.xxpo == "dpo":
+        ConfigClass = DPOConfig
+    elif args.xxpo == "orpo":
+        ConfigClass = ORPOConfig
+    logging.info(f"{args.xxpo} Used.")
+
+    training_args = ConfigClass(
+        beta=args.beta,
+        rpo_alpha=args.rpo_alpha,
+        per_device_train_batch_size=args.per_device_train_batch_size,
+        per_device_eval_batch_size=args.per_device_eval_batch_size,
+        max_steps=args.max_steps,
+        num_train_epochs=args.num_train_epochs,
+        logging_steps=args.logging_steps,
+        save_strategy="steps",
+        eval_strategy="steps",
+        save_steps=args.save_steps,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        gradient_checkpointing=args.gradient_checkpointing,
+        learning_rate=args.learning_rate,
+        eval_steps=args.eval_steps,
+        output_dir=args.output_dir,
+        report_to="tensorboard",
+        logging_dir=args.tb_dir,
+        max_prompt_length=args.max_prompt_length,
+        max_length=args.max_length,
+        lr_scheduler_type=args.lr_scheduler_type,
+        warmup_steps=args.warmup_steps,
+        optim=args.optimizer_type,
+        bf16=True,
+        remove_unused_columns=False,
+        run_name="",
+        gradient_checkpointing_kwargs=dict(use_reentrant=args.gradient_checkpointing_use_reentrant),
+        seed=args.seed,
+        dataset_num_proc=args.dataset_num_proc,
+        disable_tqdm=args.disable_tqdm,
+        save_only_model=args.save_only_model,
+        save_total_limit=args.saving_limit,
+    )
+
+    # 5. initialize the DPO trainer
+    if not args.peft_type and args.xxpo == "dpo":
+        model_ref = get_model(args, accelerator)
+        model_ref.config.use_cache = False  # silence the warnings. Please re-enable for inference!
+    else:
+        model_ref = None
+
+    if args.xxpo == "dpo":
+        xxpo_trainer = DPOTrainer(
+            model,
+            ref_model=model_ref,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            tokenizer=tokenizer,
+            peft_config=peft_config,
+        )
+    elif args.xxpo == "orpo":
+        xxpo_trainer = ORPOTrainer(
+            model,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            tokenizer=tokenizer,
+            peft_config=peft_config,
+        )
+
+    # callbacks
+    if args.disable_tqdm:
+        xxpo_trainer.remove_callback(PrinterCallback)
+        xxpo_trainer.add_callback(LogCallback)
+    else:
+        xxpo_trainer.remove_callback(ProgressCallback)
+        xxpo_trainer.add_callback(CustomProgressCallback)
+
+    # 6. train
+    xxpo_trainer.train()
+
+    # 7. save
+    output_dir = os.path.join(args.output_dir, "epoch_final")
+    xxpo_trainer.save_model(output_dir)
+    # dpo_trainer.model.save_pretrained(output_dir)
+    logger.info(f"Training Finished!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mftcoder_accelerate/src/xxpo/xxpo_arguments.py b/mftcoder_accelerate/src/xxpo/xxpo_arguments.py
new file mode 100644
index 0000000..2b4c876
--- /dev/null
+++ b/mftcoder_accelerate/src/xxpo/xxpo_arguments.py
@@ -0,0 +1,170 @@
+"""
+# @author Chaoyu Chen
+# @date 2023/10/19
+
+training arguments
+"""
+
+from dataclasses import dataclass, asdict
+from typing import List, Union
+
+
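+# Illustrative minimal train_config JSON for this dataclass (paths and model_type are placeholders):
+# {
+#     "data_paths": "[/path/to/dpo_data_a.jsonl,/path/to/dpo_data_b.jsonl]",
+#     "output_dir": "/path/to/output",
+#     "tb_dir": "/path/to/tensorboard",
+#     "pretrained_model_path": "/path/to/base_model",
+#     "model_type": "qwen2",
+#     "peft_type": "lora",
+#     "xxpo": "dpo"
+# }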
+@dataclass
+class XXPOTrainArgs:
+    # train data paths on shared FS
+    data_paths: Union[str, List[str]]
+
+    # output dir for saving adaptors in peft or full ckpts in full-parameter training
+    output_dir: str
+
+    # tensorboard dir for saving tensorboard logs
+    tb_dir: str
+
+    # pretrained_model_path: path to the base model you want to train on
+    pretrained_model_path: str
+
+    # model type of pretrained_model_path, supporting llama|qwen|starcoder|baichuan|chatglm2, etc.
+    model_type: str
+
+    # train/valid/test split
+    data_split: str = "98,2,0"
+
+    # lora or qlora or None (for full-parameter training)
+    peft_type: Union[None, str] = "qlora"
+
+    # if qlora, 4bit will be set, else None
+    quantization: Union[None, str] = "4bit"
+
+    # lora rank; the bigger, the more trainable parameters
+    lora_rank: int = 96
+
+    # lora alpha
+    lora_alpha: int = 32
+
+    # lora dropout
+    lora_dropout: float = 0.05
+
+    # lora targeting modules
+    target_modules: Union[None, str, List[str]] = None
+
+    # dpo or orpo
+    xxpo: str = "dpo"
+
+    # dpo/orpo beta
+    beta: float = 0.1
+
+    rpo_alpha: Union[None, float] = None
+
+    # micro train batch size
+    per_device_train_batch_size: int = 8
+
+    # micro eval batch size, always the same as the micro train batch size
+    per_device_eval_batch_size: int = 8
+
+    # HF AutoTokenizer is supported; more tokenizer types may be added later
+    tokenizer_type: str = "AutoTokenizer"
+
+    # initial lr
+    learning_rate: float = 5e-5
+
+    # minimum lr
+    min_lr: float = 5e-6
+
+    # weight decay
+    weight_decay: float = 0.01
+
+    # gradient_accumulation_steps
+    gradient_accumulation_steps: int = 1
+
+    # lr_scheduler_type
+    lr_scheduler_type: str = "cosine"
+
+    # optimizer_type
+    optimizer_type: str = "adamw_torch"
+    # optimizer_type: str = "paged_adamw_32bit"
+
+    # gradient_checkpointing
+    gradient_checkpointing: bool = True
+    gradient_checkpointing_use_reentrant: bool = False
+
+    # num of warmup_steps
+    warmup_steps: Union[int, float] = 0.05
+
+    # num_train_epochs
+    num_train_epochs: int = 4
+
+    # seed for reproducing
+    seed: int = 1234
+
+    # seq_length, context length
+    seq_length: int = 4096
+
+    save_only_model: bool = True
+
+    # path of the adapter to resume from; None for not resuming training
+    resume_from_checkpoint: Union[None, str] = None
+
+    # auto resume from latest ckpt if job restarted
+    auto_resume: bool = True
+
+    # num of steps for logging training loss
+    logging_steps: int = 10
+
+    # num of steps for saving ckpt
+    save_steps: int = 100
+
+    # num of steps for evaluation (eval_loss); preferably the same as save_steps
+    eval_steps: int = 100
+
+    # max train steps, if None, depends on num_train_epochs
+    max_steps: int = -1
+
+    # whether to checkpoint at the end of every epoch, may be True in sst
+    epoch_checkpointing: bool = False
+
+    # shuffle before train/valid split
+    shuffle_before_split: bool = True
+
+    # whether to stop early when eval loss has not improved for the past early_stopping_stall_num evaluation points
+    early_stopping: bool = True
+    early_stopping_stall_num: int = 5
+
+    # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota.
+    saving_limit: Union[None, int] = None
+
+    # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2}
+    attn_implementation: str = "flash_attention_2"
+
+    # tokenizer chat template, if None, will use MFTCoder template
+    chat_template: Union[None, str] = None
+
+    distributed_type: Union[None, str] = None
+
+    init_timeout_seconds: Union[None, int] = 3600
+
+    make_vocab_size_divisible_by: int = 32
+    model_parallel_size: int = 1
+    use_slow_tokenizer: bool = False
+    world_size: int = 8
+
+    # max prompt length and max total sequence length (in tokens)
+    max_prompt_length: Union[None, int] = 2048
+    max_length: Union[None, int] = 4096
+
+    # num of processes for dataset preprocessing
+    dataset_num_proc: int = 1
+
+    # model_dtype[float16, bfloat16, float] for loading
+    dtype: str = "bfloat16"
+
+    # instrumentation
+    disable_tqdm: bool = False
+    sanity_check: bool = False
+
+    # debug argument for distributed training
+    # "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
+    # "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
+    ignore_bias_buffers: bool = True
+
+    def dict(self):
+        return {k: str(v) for k, v in asdict(self).items()}
diff --git a/mft_atorch/.gitignore b/mftcoder_atorch/.gitignore
similarity index 100%
rename from mft_atorch/.gitignore
rename to mftcoder_atorch/.gitignore
diff --git a/mft_atorch/README.md b/mftcoder_atorch/README.md
similarity index 100%
rename from mft_atorch/README.md
rename to mftcoder_atorch/README.md
diff --git a/mft_atorch/README_cn.md b/mftcoder_atorch/README_cn.md
similarity index 100%
rename from mft_atorch/README_cn.md
rename to mftcoder_atorch/README_cn.md
diff --git a/mft_atorch/arguments/get_arguments.py b/mftcoder_atorch/arguments/get_arguments.py
similarity index 100%
rename from mft_atorch/arguments/get_arguments.py
rename to mftcoder_atorch/arguments/get_arguments.py
diff --git a/mft_peft_hf/src/data/Makefile b/mftcoder_atorch/data/Makefile
similarity index 100%
rename from mft_peft_hf/src/data/Makefile
rename to mftcoder_atorch/data/Makefile
diff --git a/mft_peft_hf/src/data/__init__.py b/mftcoder_atorch/data/__init__.py
similarity index 100%
rename from mft_peft_hf/src/data/__init__.py
rename to mftcoder_atorch/data/__init__.py
diff --git a/mft_atorch/data/gpt2_multi_task_dataset.py b/mftcoder_atorch/data/gpt2_multi_task_dataset.py
similarity index 100%
rename from mft_atorch/data/gpt2_multi_task_dataset.py
rename to mftcoder_atorch/data/gpt2_multi_task_dataset.py
diff --git a/mft_peft_hf/src/data/helpers.cpp b/mftcoder_atorch/data/helpers.cpp
similarity index 100%
rename from mft_peft_hf/src/data/helpers.cpp
rename to mftcoder_atorch/data/helpers.cpp
diff --git a/mft_atorch/data/helpers.cpython-38-x86_64-linux-gnu.so b/mftcoder_atorch/data/helpers.cpython-38-x86_64-linux-gnu.so
similarity index 100%
rename from mft_atorch/data/helpers.cpython-38-x86_64-linux-gnu.so
rename to mftcoder_atorch/data/helpers.cpython-38-x86_64-linux-gnu.so
diff --git a/mft_atorch/data/preprocess_data.py b/mftcoder_atorch/data/preprocess_data.py
similarity index 100%
rename from mft_atorch/data/preprocess_data.py
rename to mftcoder_atorch/data/preprocess_data.py
diff --git a/mft_atorch/data/tokenization/lm_dataformat.py b/mftcoder_atorch/data/tokenization/lm_dataformat.py
similarity index 100%
rename from mft_atorch/data/tokenization/lm_dataformat.py
rename to mftcoder_atorch/data/tokenization/lm_dataformat.py
diff --git a/mft_atorch/data/tokenization/preprocess_data.py b/mftcoder_atorch/data/tokenization/preprocess_data.py
similarity index 100%
rename from mft_atorch/data/tokenization/preprocess_data.py
rename to mftcoder_atorch/data/tokenization/preprocess_data.py
diff --git a/mft_atorch/model/__init__.py b/mftcoder_atorch/model/__init__.py
similarity index 100%
rename from mft_atorch/model/__init__.py
rename to mftcoder_atorch/model/__init__.py
diff --git a/mft_atorch/model/build_model.py b/mftcoder_atorch/model/build_model.py
similarity index 100%
rename from mft_atorch/model/build_model.py
rename to mftcoder_atorch/model/build_model.py
diff --git a/mft_peft_hf/src/model/gpt_neox/__init__.py b/mftcoder_atorch/model/gpt_neox/__init__.py
similarity index 100%
rename from mft_peft_hf/src/model/gpt_neox/__init__.py
rename to mftcoder_atorch/model/gpt_neox/__init__.py
diff --git a/mft_atorch/model/gpt_neox/config.json b/mftcoder_atorch/model/gpt_neox/config.json
similarity index 100%
rename from mft_atorch/model/gpt_neox/config.json
rename to mftcoder_atorch/model/gpt_neox/config.json
diff --git a/mft_atorch/model/gpt_neox/configuration_gpt_neox.py b/mftcoder_atorch/model/gpt_neox/configuration_gpt_neox.py
similarity index 100%
rename from mft_atorch/model/gpt_neox/configuration_gpt_neox.py
rename to mftcoder_atorch/model/gpt_neox/configuration_gpt_neox.py
diff --git a/mft_peft_hf/src/model/gpt_neox/generation_config.json b/mftcoder_atorch/model/gpt_neox/generation_config.json
similarity index 100%
rename from mft_peft_hf/src/model/gpt_neox/generation_config.json
rename to mftcoder_atorch/model/gpt_neox/generation_config.json
diff --git a/mft_atorch/model/gpt_neox/modeling_gpt_neox.py b/mftcoder_atorch/model/gpt_neox/modeling_gpt_neox.py
similarity index 100%
rename from mft_atorch/model/gpt_neox/modeling_gpt_neox.py
rename to mftcoder_atorch/model/gpt_neox/modeling_gpt_neox.py
diff --git a/mft_atorch/model/gpt_neox/tokenization_gpt_neox_fast.py b/mftcoder_atorch/model/gpt_neox/tokenization_gpt_neox_fast.py
similarity index 100%
rename from mft_atorch/model/gpt_neox/tokenization_gpt_neox_fast.py
rename to mftcoder_atorch/model/gpt_neox/tokenization_gpt_neox_fast.py
diff --git a/mft_atorch/model/peft/__init__.py b/mftcoder_atorch/model/peft/__init__.py
similarity index 100%
rename from mft_atorch/model/peft/__init__.py
rename to mftcoder_atorch/model/peft/__init__.py
diff --git a/mft_atorch/model/peft/modeling_peft.py b/mftcoder_atorch/model/peft/modeling_peft.py
similarity index 100%
rename from mft_atorch/model/peft/modeling_peft.py
rename to mftcoder_atorch/model/peft/modeling_peft.py
diff --git a/mft_atorch/model/peft/tuner/__init__.py b/mftcoder_atorch/model/peft/tuner/__init__.py
similarity index 100%
rename from mft_atorch/model/peft/tuner/__init__.py
rename to mftcoder_atorch/model/peft/tuner/__init__.py
diff --git a/mft_atorch/model/peft/tuner/adalora.py b/mftcoder_atorch/model/peft/tuner/adalora.py
similarity index 100%
rename from mft_atorch/model/peft/tuner/adalora.py
rename to mftcoder_atorch/model/peft/tuner/adalora.py
diff --git a/mft_atorch/model/peft/tuner/bitfit.py b/mftcoder_atorch/model/peft/tuner/bitfit.py
similarity index 100%
rename from mft_atorch/model/peft/tuner/bitfit.py
rename to mftcoder_atorch/model/peft/tuner/bitfit.py
diff --git a/mft_atorch/model/peft/tuner/pe_base_model.py b/mftcoder_atorch/model/peft/tuner/pe_base_model.py
similarity index 100%
rename from mft_atorch/model/peft/tuner/pe_base_model.py
rename to mftcoder_atorch/model/peft/tuner/pe_base_model.py
diff --git a/mft_atorch/model/peft/tuner/roem.py b/mftcoder_atorch/model/peft/tuner/roem.py
similarity index 100%
rename from mft_atorch/model/peft/tuner/roem.py
rename to mftcoder_atorch/model/peft/tuner/roem.py
diff --git a/mft_atorch/model/peft/tuner/routelora.py b/mftcoder_atorch/model/peft/tuner/routelora.py
similarity index 100%
rename from mft_atorch/model/peft/tuner/routelora.py
rename to mftcoder_atorch/model/peft/tuner/routelora.py
diff --git a/mft_atorch/model/peft/tuner/unipelt.py b/mftcoder_atorch/model/peft/tuner/unipelt.py
similarity index 100%
rename from mft_atorch/model/peft/tuner/unipelt.py
rename to mftcoder_atorch/model/peft/tuner/unipelt.py
diff --git a/mft_atorch/model/peft/utils/__init__.py b/mftcoder_atorch/model/peft/utils/__init__.py
similarity index 100%
rename from mft_atorch/model/peft/utils/__init__.py
rename to mftcoder_atorch/model/peft/utils/__init__.py
diff --git a/mft_atorch/model/peft/utils/config.py b/mftcoder_atorch/model/peft/utils/config.py
similarity index 100%
rename from mft_atorch/model/peft/utils/config.py
rename to mftcoder_atorch/model/peft/utils/config.py
diff --git a/mft_atorch/model/peft/utils/mapping.py b/mftcoder_atorch/model/peft/utils/mapping.py
similarity index 100%
rename from mft_atorch/model/peft/utils/mapping.py
rename to mftcoder_atorch/model/peft/utils/mapping.py
diff --git a/mft_atorch/model/peft/utils/others.py b/mftcoder_atorch/model/peft/utils/others.py
similarity index 100%
rename from mft_atorch/model/peft/utils/others.py
rename to mftcoder_atorch/model/peft/utils/others.py
diff --git a/mft_atorch/tokenizer/__init__.py b/mftcoder_atorch/tokenizer/__init__.py
similarity index 100%
rename from mft_atorch/tokenizer/__init__.py
rename to mftcoder_atorch/tokenizer/__init__.py
diff --git a/mft_atorch/tokenizer/gpt2_tokenization.py b/mftcoder_atorch/tokenizer/gpt2_tokenization.py
similarity index 100%
rename from mft_atorch/tokenizer/gpt2_tokenization.py
rename to mftcoder_atorch/tokenizer/gpt2_tokenization.py
diff --git a/mft_atorch/tokenizer/tokenizer.py b/mftcoder_atorch/tokenizer/tokenizer.py
similarity index 100%
rename from mft_atorch/tokenizer/tokenizer.py
rename to mftcoder_atorch/tokenizer/tokenizer.py
diff --git a/mft_atorch/tokenizer/train_tokenizer.py b/mftcoder_atorch/tokenizer/train_tokenizer.py
similarity index 100%
rename from mft_atorch/tokenizer/train_tokenizer.py
rename to mftcoder_atorch/tokenizer/train_tokenizer.py
diff --git a/mft_atorch/train/__init__.py b/mftcoder_atorch/train/__init__.py
similarity index 100%
rename from mft_atorch/train/__init__.py
rename to mftcoder_atorch/train/__init__.py
diff --git a/mft_atorch/train/run_gpt_mft.sh b/mftcoder_atorch/train/run_gpt_mft.sh
similarity index 100%
rename from mft_atorch/train/run_gpt_mft.sh
rename to mftcoder_atorch/train/run_gpt_mft.sh
diff --git a/mft_atorch/train/run_gpt_mft_peft.sh b/mftcoder_atorch/train/run_gpt_mft_peft.sh
similarity index 100%
rename from mft_atorch/train/run_gpt_mft_peft.sh
rename to mftcoder_atorch/train/run_gpt_mft_peft.sh
diff --git a/mft_atorch/train/run_train.py b/mftcoder_atorch/train/run_train.py
similarity index 100%
rename from mft_atorch/train/run_train.py
rename to mftcoder_atorch/train/run_train.py
diff --git a/mft_atorch/train/trainer/atorch_trainer.py b/mftcoder_atorch/train/trainer/atorch_trainer.py
similarity index 100%
rename from mft_atorch/train/trainer/atorch_trainer.py
rename to mftcoder_atorch/train/trainer/atorch_trainer.py
diff --git a/mft_atorch/utils/__init__.py b/mftcoder_atorch/utils/__init__.py
similarity index 100%
rename from mft_atorch/utils/__init__.py
rename to mftcoder_atorch/utils/__init__.py
diff --git a/mft_atorch/utils/auto_accelerate_utils.py b/mftcoder_atorch/utils/auto_accelerate_utils.py
similarity index 100%
rename from mft_atorch/utils/auto_accelerate_utils.py
rename to mftcoder_atorch/utils/auto_accelerate_utils.py
diff --git a/mft_atorch/utils/common_utils.py b/mftcoder_atorch/utils/common_utils.py
similarity index 100%
rename from mft_atorch/utils/common_utils.py
rename to mftcoder_atorch/utils/common_utils.py
diff --git a/mft_atorch/utils/learning_rates.py b/mftcoder_atorch/utils/learning_rates.py
similarity index 100%
rename from mft_atorch/utils/learning_rates.py
rename to mftcoder_atorch/utils/learning_rates.py
diff --git a/mft_atorch/utils/merge_base_and_lora_to_hf.py b/mftcoder_atorch/utils/merge_base_and_lora_to_hf.py
similarity index 100%
rename from mft_atorch/utils/merge_base_and_lora_to_hf.py
rename to mftcoder_atorch/utils/merge_base_and_lora_to_hf.py
diff --git a/mft_atorch/utils/vocab.json b/mftcoder_atorch/utils/vocab.json
similarity index 100%
rename from mft_atorch/utils/vocab.json
rename to mftcoder_atorch/utils/vocab.json
diff --git a/requirements.txt b/requirements.txt
index dfff561..189518b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,18 @@
-numpy
-pandas
+numpy==1.23.5
+pandas==2.2.1
+torch==2.1.0
+tensorboard==2.11.0
+deepspeed==0.14.0
+transformers==4.44.2
+accelerate==0.31.0
+peft==0.10.0
+BitsAndBytes==0.43.0
+xformers==0.0.22.post7
+datasets
+ftfy
+packaging
 einops
 sentencepiece
-deepspeed==0.9.3
-transformers==4.32.0
-accelerate==0.21.0
-peft==0.4.0
-BitsAndBytes==0.40.2
-xformers==0.0.21
 ujson
 jsonlines
 tiktoken