diff --git a/.gitignore b/.gitignore index 68729d6..2e34447 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ .idea/ -.DS_Store \ No newline at end of file +.DS_Store +*.log +*/__pycache__/ +*.pyc \ No newline at end of file diff --git a/README.md b/README.md index 94aa7c9..468cc88 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning +# MFTCoder: High Accuracy and Efficiency Multi-task Fine-Tuning Framework
@@ -21,7 +21,11 @@
+ 🤗 [HuggingFace](https://huggingface.co/codefuse-ai) + • 🤖 [ModelScope](https://modelscope.cn/organization/codefuse-ai) + +
[[中文]](README_cn.md) [**English**] @@ -38,111 +42,143 @@ - [Models](#Models) - [Datasets](#Datasets) - [Star History](#Star-History) +- [Join Us](#Join-Us) ## News +🔥🔥🔥 [2024/10/31] We released **MFTCoder v0.5** mainly for MFTCoder-accelerate, which is now supporting preference alignment methods like **DPO/RPO/ORPO** in the new **xxpo** module, adding full-parameter continue-training in the additional **mpt** module along with its **offline_tokenization** module, updating selfpaced method to new convergence balance(CoBa) method for MFT in the original **pefts** module. + +🔥🔥🔥 [2024/10/31] Our paper [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) has been accepted by EMNLP-2024, which achieves balanced convergence across various tasks. + +🔥🔥🔥 [2024/05/20] We released **MFTCoder v0.4**, mainly for MFTCoder-accelerate. It supports **QLoRA + DeepSpeed Zero3** and **QLoRA + FSDP** as options allowing you training very large models. It now supports new models like Qwen2, Qwen2-MoE, Starcoder2, Gemma, etc. + +🔥🔥🔥 [2024/05/20] Our paper [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) has been accepted by KDD2024. + +🔥🔥🔥 [2024/05/20] [CodeFuse-StarCoder2-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) has been released, achieving a pass@1 (greedy decoding) score of 73.2% on HumanEval. + +🔥🔥 [2024/01/30] The model [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) fine-tuned with MFTCoder ranks first in HuggingFace [Big Code Models LeaderBoard](https://huggingface.co/spaces/bigcode/bigcode-models-leaderboard) + +🔥🔥 [2024/01/17] We released MFTCoder v0.3.0, mainly for MFTCoder-accelerate. It now supports new models like Mixtral(MoE), DeepSeek-coder, chatglm3. It supports FSDP as an option. It also supports Self-paced Loss as a solution for convergence balance in Multitask Fine-tuning. + +🔥🔥 [2024/01/17] [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) has been released, achieving a pass@1 (greedy decoding) score of 78.7% on HumanEval. It lists as top-1 LLM on Bigcode Leardboard in terms of win-rate, the official result is going to be published later. + +🔥🔥 [2024/01/17] [CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8X7B) has been released, achieving a pass@1 (greedy decoding) score of 56.1% on HumanEval. + 🔥🔥 [2023/11/07] [MFTCoder Paper](https://arxiv.org/abs/2311.02303) has been released on Arxiv, which discloses technique details of multi-task-fine-tuning. 🔥🔥 [2023/10/20] [CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) has been released, achieving a pass@1 (greedy decoding) score of 48.8% on HumanEval, which gains 16% absolute improvement over the base model [Qwen-14b](https://huggingface.co/Qwen/Qwen-14B) 🔥🔥 [2023/09/27] [CodeFuse-StarCoder-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) has been released, achieving a pass@1 (greedy decoding) score of 54.9% on HumanEval. -🔥🔥🔥 [2023/09/26]We are pleased to announce the release of the [4-bit quantized version of CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits). Despite the quantization process, the model still achieves a remarkable 73.8% accuracy (greedy decoding) on the HumanEval pass@1 metric. 
+🔥🔥 [2023/09/26]We are pleased to announce the release of the [4-bit quantized version of CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits). Despite the quantization process, the model still achieves a remarkable 73.8% accuracy (greedy decoding) on the HumanEval pass@1 metric. -🔥🔥🔥 [2023/09/07]We released **CodeFuse-CodeLlama-34B**, which achieves the **74.4% Python Pass@1** (greedy decoding) and surpasses GPT4 (2023/03/15) and ChatGPT-3.5 on the [HumanEval Benchmarks](https://github.com/openai/human-eval). +🔥🔥 [2023/09/07]We released [**CodeFuse-CodeLlama-34B**](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits), which achieves the **74.4% Python Pass@1** (greedy decoding) and surpasses GPT4 (2023/03/15) and ChatGPT-3.5 on the [HumanEval Benchmarks](https://github.com/openai/human-eval). -🔥🔥 [2023/08/26]We released MFTCoder which supports finetuning Code Llama, Llama, Llama2, StarCoder, ChatGLM2, CodeGeeX2, Qwen, and GPT-NeoX models with LoRA/QLoRA. +🔥🔥 [2023/08/26]We released MFTCoder-v0.1.0 which supports finetuning Code Llama, Llama, Llama2, StarCoder, ChatGLM2, CodeGeeX2, Qwen, and GPT-NeoX models with LoRA/QLoRA. ### HumanEval Performance | Model | HumanEval(Pass@1) | Date | |:----------------------------|:-----------------:|:-------:| -| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | -|**CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | -| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | -| GPT-4(zero-shot) | 67.0% | 2023/03 | -| PanGu-Coder2 15B | 61.6% | 2023/08 | -| **CodeFuse-StarCoder-15B** | **54.9%** | 2023/08 | -| CodeLlama-34b-Python | 53.7% | 2023/08 | -| **CodeFuse-QWen-14B** | **48.8%** | 2023/10 | -| CodeLlama-34b | 48.8% | 2023/08 | -| GPT-3.5(zero-shot) | 48.1% | 2022/11 | -| OctoCoder | 46.2% | 2023/08 | -| StarCoder-15B | 33.6% | 2023/05 | -| QWen-14B | 32.3% | 2023/10 | +| **CodeFuse-DeepSeek-33B** | **78.7%** | 2024/01 | +| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | +| **CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | +| **CodeFuse-StarCoder2-15B** | **73.2%** | 2023/05 | +| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | +| GPT-4(zero-shot) | 67.0% | 2023/03 | +| PanGu-Coder2 15B | 61.6% | 2023/08 | +| **CodeFuse-Mixtral-8x7B** | **56.1%** | 2024/01 | +| **CodeFuse-StarCoder-15B** | **54.9%** | 2023/08 | +| CodeLlama-34b-Python | 53.7% | 2023/08 | +| **CodeFuse-QWen-14B** | **48.8%** | 2023/10 | +| CodeLlama-34b | 48.8% | 2023/08 | +| GPT-3.5(zero-shot) | 48.1% | 2022/11 | +| OctoCoder | 46.2% | 2023/08 | +| StarCoder-15B | 33.6% | 2023/05 | +| QWen-14B | 32.3% | 2023/10 | ## Articles -[MFT Arxiv paper](https://arxiv.org/abs/2311.02303) +[MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning (KDD2024)](https://arxiv.org/abs/2311.02303) ## Introduction -**High Accuracy and efficiency multi-task fine-tuning framework for Code LLMs.** +**High Accuracy and efficiency Multi-task Fine-tuning framework for Code LLMs.** + +**MFTCoder** is an open-source project of CodeFuse for accurate and efficient Multi-task Fine-tuning(MFT) on Large Language Models(LLMs), especially on Code-LLMs(large language model for code tasks). +Moreover, we open source Code LLM models and code-related datasets along with the MFTCoder framework. -**CodeFuse-MFTCoder** is an open-source project of CodeFuse for multitasking Code-LLMs(large language model for code tasks), which includes models, datasets, training codebases and inference guides. 
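The news items and the HumanEval table above consistently report pass@1 under greedy decoding. As a point of reference, the snippet below is a minimal, hedged sketch of what that setting means in practice: a single deterministic completion per prompt generated with Hugging Face `transformers`. It is not the repository's evaluation harness; the model id is one of the released CodeFuse checkpoints, and the plain completion prompt is only illustrative, since each model card defines its own chat format.

```python
# Minimal sketch of "pass@1 (greedy decoding)": one deterministic completion
# per prompt, no sampling. Not the repository's evaluation harness.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "codefuse-ai/CodeFuse-CodeLlama-34B"  # any released CodeFuse checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

prompt = 'def fibonacci(n):\n    """Return the n-th Fibonacci number."""\n'
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding, matching the reported pass@1 setting
    )
completion = tokenizer.decode(
    output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(completion)
```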
In MFTCoder, we released two codebases for finetuning Large Language Models: -- ```mft_peft_hf``` is based on the HuggingFace Accelerate and deepspeed framework. -- ```mft_atorch``` is based on the [ATorch frameworks](https://github.com/intelligent-machine-learning/dlrover), which is a fast distributed training framework of LLM. +- **```MFTCoder-accelerate```** is a framework with accelerate and DeepSpeed/FSDP. All tech-stacks are open-source and vibrant. We highly recommend you try this framework and make your fintuning accurate and efficient. +- ```MFTCoder-atorch``` is based on the [ATorch frameworks](https://github.com/intelligent-machine-learning/dlrover), which is a fast distributed training framework of LLM. The aim of this project is to foster collaboration and share advancements in large language models, particularly within the domain of code development. ### Frameworks - + ### Highlights :white_check_mark: **Multi-task**: Train models on multiple tasks while maintaining a balance between them. The models can even generalize to new, previously unseen tasks. :white_check_mark: **Multi-model**: It integrates state-of-the-art open-source models such as gpt-neox, llama, llama-2, baichuan, Qwen, chatglm2, and more. (These finetuned models will be released in the near future.) -:white_check_mark: **Multi-framework**: It provides support for both HuggingFace Accelerate (with deepspeed) and [ATorch](https://github.com/intelligent-machine-learning/dlrover). +:white_check_mark: **Multi-framework**: It provides support for both Accelerate (with Deepspeed and FSDP) and ATorch -:white_check_mark: **Efficient fine-tuning**: It supports LoRA and QLoRA, enabling fine-tuning of large models with minimal resources. The training speed meets the demands of almost all fine-tuning scenarios. +:white_check_mark: **Efficient fine-tuning**: It supports LoRA, QLoRA as well as Full-parameters training, enabling fine-tuning of large models with minimal resources. The training speed meets the demands of almost all fine-tuning scenarios. The main components of this project include: - Support for both SFT (Supervised FineTuning) and MFT (Multi-task FineTuning). The current MFTCoder achieves data balance among multiple tasks, and future releases will achieve a balance between task difficulty and convergence speed during training. -- Support for QLoRA instruction fine-tuning, as well as LoRA fine-tuning. -- Support for most mainstream open-source large models, particularly those relevant to Code-LLMs, such as Code-LLaMA, Starcoder, Codegeex2, Qwen, GPT-Neox, and more. +- Support for QLoRA instruction fine-tuning, LoRA fine-tuning as well as Full-parameters fine-tuning. +- Support for most mainstream open-source large models, particularly those relevant to Code-LLMs, such as DeepSeek-coder, Mistral, Mixtral, Chatglm3, Code-LLaMA, Starcoder, Codegeex2, Qwen, GPT-Neox, and more. - Support for weight merging between the LoRA adaptor and base models, simplifying the inference process. - Release of 2 high-quality code-related instruction fine-tuning datasets: [Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k) and [CodeExercise-Python-27k](https://huggingface.co/datasets/codefuse-ai/CodeExercise-Python-27k). -- Release of 2 models: [CodeFuse-13B](https://huggingface.co/codefuse-ai/CodeFuse-13B) and [CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B). 
+- Release of many Code LLMs, please refer to organizations: [codefuse-ai on huggingface](https://huggingface.co/codefuse-ai) or [codefuse-ai on modelscope](https://modelscope.cn/organization/codefuse-ai). ## Requirements -To begin, ensure that you have successfully installed CUDA (version >= 11.4, preferably 11.7) along with the necessary drivers. Additionally, make sure you have installed torch (version 2.0.1). +To begin, ensure that you have successfully installed CUDA (version >= 11.4, preferably 12.1) along with the necessary drivers. Additionally, make sure you have installed torch (version >= 2.1.0). Next, we have provided an init_env.sh script to simplify the installation of required packages. Execute the following command to run the script: ```bash sh init_env.sh ``` -If you require flash attention, please refer to the following link for installation instructions: https://github.com/Dao-AILab/flash-attention +We highly recommend training with flash attention(version >= 2.3.0), please refer to the following link for installation instructions: https://github.com/Dao-AILab/flash-attention ## Training -🚀 [Huggingface accelerate + deepspeed Codebase for MFT(Multi-task Finetuning)](./mft_peft_hf/README.md) +As mentioned above, we open source two training frameworks. You could refer to their own READMEs for more details as followed. -🚀 [Atorch Codebase for MFT(Multi-task Finetuning)](./mft_atorch/README.md) +If you are familiar with open source ```transformers```, ```DeepSpeed``` or ```FSDP```, we highly recommend you try: +🚀🚀 [**MFTCoder-accelerate: Accelerate + Deepspeed/FSDP Codebase for MFT(Multi-task Finetuning)**](mftcoder_accelerate/README.md) -## Models -We are excited to release the following two CodeLLMs trained by MFTCoder, now available on Hugging Face: +If you want to explore some new framework like atorch, you could check: + +🚀 [MFTCoder-atorch: Atorch Codebase for MFT(Multi-task Finetuning)](mftcoder_atorch/README.md) -| Model | Base Model | Num of examples trained | Batch Size | Seq Length | -|--------------------------------------------------------------------------------------------|--------------------|-------------------------|------------|------------| -| [🔥🔥🔥 CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 600k | 80 | 4096 | -| [🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python| | | 4096 | -| [🔥🔥🔥 CodeFuse-StarCoder-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | Starcoder | 600k | 256 | 4096 | -| [🔥🔥🔥 CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 1100k | 256 | 4096 | -| [🔥 CodeFuse-13B](https://huggingface.co/codefuse-ai/CodeFuse-13B) | CodeFuse-13B | 66k | 64 | 4096 | +## Models +We are excited to release the following two CodeLLMs trained by MFTCoder, now available on both HuggingFace and ModelScope: +| Model | HuggingFace Links | ModelScope Links | Base Model | Num of examples trained | Batch Size | Seq Length | +|----------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|----------------------|-------------------------|------------|------------| +| 🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 600K | 
80 | 4096 | +| 🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 600K | 80 | 4096 | +| 🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 600K | 80 | 4096 | +| 🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | +| 🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 600K | 80 | 4096 | +| 🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 1.1 Million | 256 | 4096 | +| 🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 1.1 Million | 256 | 4096 | +| 🔥 CodeFuse-StarCoder2-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder2-15B) | Starcoder2-15B | 700K | 128 | 4096 | + ## Datasets We are also pleased to release two code-related instruction datasets, meticulously selected from a range of datasets to facilitate multitask training. Moving forward, we are committed to releasing additional instruction datasets covering various code-related tasks. | Dataset | Description | |-----------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [⭐ Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k) | Based on open-evol-instruction-80k, filter out low-quality, repeated, and similar instructions to HumanEval, thus get high-quality code instruction dataset. | +| [⭐ Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k) | Based on open-evol-instruction-80k, filter out low-quality, repeated, and similar instructions to HumanEval, thus get high-quality code instruction dataset. | | [⭐ CodeExercise-Python-27k](https://huggingface.co/datasets/codefuse-ai/CodeExercise-Python-27k) | python code exercise instruction dataset | @@ -171,3 +207,13 @@ If you find our work useful or helpful for your R&D works, please feel free to c +## Join-US + +We are the AI Native team within the Platform Technology Business Group at Ant Group, dedicated to the intelligentization of Ant Group's platform engineering. Established for over three years, our team has played a pivotal role in supporting the intelligent operation and maintenance of Ant Group's cloud computing infrastructure. Our mission is to build algorithm services and platforms with a wide user base through world-class technological innovation and impact, supporting the implementation of internal and external products and businesses. +Embracing an innovation-driven ethos, our team not only supports business implementation but also propels technological influence. 
Over the past three years, we have published more than 20 papers at top conferences like ICLR, NeurIPS, KDD, and ACL. Our innovative business outcomes have earned us two Ant Technology's highest T-Star awards and one SuperMA award from Ant Group. Our open-source project CodeFuse has received 4K stars as of February 2024, and our models have been downloaded over 1.5 million times on Huggingface and Modelscope. + +**We are on the lookout for top talents to join our vibrant team! If you're eager to develop your career in an environment filled with energy, innovation, and a culture of excellence, we welcome you to explore our career opportunities for both campus and experienced hires. Join us and be a part of creating the next milestone in the industry.** + +**Campus Recruitment**: https://hrrecommend.antgroup.com/guide.html?code=8uoP5mlus5DqQYbE_EnqcE2FD5JZH21MwvMUIb9mb6X3osXPuBraG54SyM8GLn_7 + +**Experienced Hires**: https://talent.antgroup.com/off-campus-position?positionId=1933830 diff --git a/README_cn.md b/README_cn.md index 04a8aff..3102d9f 100644 --- a/README_cn.md +++ b/README_cn.md @@ -1,4 +1,4 @@ -# MFTCoder: 多任务大模型代码能力微调框架 +# MFTCoder: 高效准确的多任务大模型微调框架
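Both the English and Chinese READMEs above list support for merging a trained LoRA/QLoRA adapter back into its base model so that inference no longer depends on the adapter. The repository ships its own merge script; the snippet below is only a minimal sketch of the same idea using the Hugging Face `peft` API, with placeholder paths.

```python
# Hedged sketch of LoRA-adapter weight merging with peft (illustrative only;
# paths are placeholders and the repo's own merge script may differ).
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_path = "path/to/base-model"
adapter_path = "path/to/lora-adapter"
merged_path = "path/to/merged-model"

base = AutoModelForCausalLM.from_pretrained(
    base_model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)
model = PeftModel.from_pretrained(base, adapter_path)
model = model.merge_and_unload()  # fold the LoRA deltas into the base weights

model.save_pretrained(merged_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
tokenizer.save_pretrained(merged_path)
```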
@@ -22,6 +22,11 @@
+ 🤗 [HuggingFace](https://huggingface.co/codefuse-ai) + • 🤖 [魔搭](https://modelscope.cn/organization/codefuse-ai) +
+ [**中文**] [[English]](README.md) @@ -36,9 +41,25 @@ - [训练](#训练) - [模型](#模型) - [数据集](#数据集) +- [加入我们](#加入我们) ## 新闻 +🔥🔥🔥 [2024/10/31] **MFTCoder-v0.5**发布,新增**xxpo**模块支持偏好对齐DPO/RPO/ORPO;新增**mpt**和**offline_tokenization**模块支持全量参数的加训;在原本的**pefts**模块(MFT)更新selfpaced收敛均衡技术并更名CoBa。 + +🔥🔥🔥 [2024/10/31] 我们的论文 [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) 已被 EMNLP 2024 接收,可以实现多任务收敛均衡。 + +🔥🔥🔥 [2024/05/20] **MFTCoder-v0.4**发布。新增支持**QLoRA+ DeepSpeed Zero3**, **QLoRA + FSDP**训练模式,可以更好的支持微调更大的模型,比如Qwen1.5-70B等。新增对Qwen2, Qwen2-MoE, Starcoder2, Gemma等模型的支持。 + +🔥🔥🔥 [2024/05/20] 我们的论文 [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) 已被 KDD 2024 接收. + +🔥🔥🔥 开源了[CodeFuse-StarCoder2-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B)模型,在HumanEval上可以达到73.2%,多代码语言能力均衡. + +🔥🔥 [2024/01/17] **MFTCoder-v0.3.0**发布。新增对Mixtral(MoE), DeepSeek等模型的支持;新增支持FSDP(Fully Sharded Data Parallel);新增Self-paced Loss, 支持多任务收敛均衡。 感兴趣详见微信公众号CodeFuse的文章[MFTCoder 重磅升级v0.3.0发布](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg) + +🔥🔥 [2024/01/17] 开源了[CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B)模型,在HumanEval pass@1(greedy decoding)上可以达到78.7%。该模型在Big Code榜单的结果近期发布,请关注公众号获取最新信息。 + +🔥🔥 [2024/01/17] 开源了[CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B)模型,在HumanEval pass@1(greedy decoding)上可以达到56.1%。感兴趣详见微信公众号CodeFuse的文章[MFTCoder提升Mixtral-8x7B混合专家模型的代码能力实践](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg) 🔥🔥 [2023/11/07] [MFTCoder论文](https://arxiv.org/abs/2311.02303)在Arxiv公布,介绍了多任务微调的技术细节。 @@ -46,28 +67,31 @@ 🔥🔥 [2023/09/27] 开源了[CodeFuse-StarCoder-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B)模型,在HumanEval pass@1(greedy decoding)上可以达到54.9%。 -🔥🔥🔥 [2023/09/26] [CodeFuse-CodeLlama-34B-4bits](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits)量化版本发布,量化后模型在HumanEval pass@1指标为73.8% (贪婪解码)。 +🔥🔥 [2023/09/26] [CodeFuse-CodeLlama-34B-4bits](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits)量化版本发布,量化后模型在HumanEval pass@1指标为73.8% (贪婪解码)。 -🔥🔥🔥 [2023/09/07]MFTCoder微调的模型**CodeFuse-CodeLlama-34B**在[HumanEval Benchmarks](https://github.com/openai/human-eval)的Python **Pass@1** 取得了**74.4%**(greedy decoding)的开源SOTA成绩。 +🔥🔥 [2023/09/07]MFTCoder微调的模型**CodeFuse-CodeLlama-34B**在[HumanEval Benchmarks](https://github.com/openai/human-eval)的Python **Pass@1** 取得了**74.4%**(greedy decoding)的开源SOTA成绩。 -🔥 [2023/08/26]MFTCoder支持使用LoRA/QLoRA对Code Llama、Llama、Llama2、StarCoder、ChatGLM2、CodeGeeX2、Qwen和GPT-NeoX模型进行微调。 +🔥🔥 [2023/08/26]MFTCoder-v0.1.0 支持使用LoRA/QLoRA对Code Llama、Llama、Llama2、StarCoder、ChatGLM2、CodeGeeX2、Qwen和GPT-NeoX模型进行微调。 ### HumanEval表现 -| 模型 | HumanEval(Pass@1) | 日期 | -|:----------------------------|:-----------------:|:-------:| -| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | -|**CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | -| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | -| GPT-4(zero-shot) | 67.0% | 2023/03 | -| PanGu-Coder2 15B | 61.6% | 2023/08 | -| **CodeFuse-StarCoder-15B** | **54.9%** | 2023/08 | -| CodeLlama-34b-Python | 53.7% | 2023/08 | -| **CodeFuse-QWen-14B** | **48.8%** | 2023/10 | -| CodeLlama-34b | 48.8% | 2023/08 | -| GPT-3.5(zero-shot) | 48.1% | 2022/11 | -| OctoCoder | 46.2% | 2023/08 | -| StarCoder-15B | 33.6% | 2023/05 | -| QWen-14B | 32.3% | 2023/10 | +| 模型 | HumanEval(Pass@1) | 日期 | +|:---------------------------------|:-----------------:|:-------:| +| **CodeFuse-DeepSeek-33B** | **78.7%** | 
2024/01 | +| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | +| **CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | +| **CodeFuse-StarCoder2-15B** | **73.2%** | 2023/05 | +| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | +| GPT-4(zero-shot) | 67.0% | 2023/03 | +| PanGu-Coder2 15B | 61.6% | 2023/08 | +| **CodeFuse-Mixtral-8x7B** | **56.1%** | 2024/01 | +| **CodeFuse-StarCoder-15B** | **54.9%** | 2023/08 | +| CodeLlama-34b-Python | 53.7% | 2023/08 | +| **CodeFuse-QWen-14B** | **48.8%** | 2023/10 | +| CodeLlama-34b | 48.8% | 2023/08 | +| GPT-3.5(zero-shot) | 48.1% | 2022/11 | +| OctoCoder | 46.2% | 2023/08 | +| StarCoder-15B | 33.6% | 2023/05 | +| QWen-14B | 32.3% | 2023/10 | ## 文章 @@ -82,53 +106,61 @@ **Codefuse-MFTCoder** 是一个开源的多任务代码大语言模型项目,包含代码大模型的模型、数据、训练等。我们希望通过开源,分享交流大语言模型在代码领域的进步。 ### 项目框架 - + ### 项目优势 :white_check_mark: **多任务**:一个模型同时支持多个任务,会保证多个任务之间的平衡,甚至可以泛化到新的没有见过的任务上去; :white_check_mark: **多模型**:支持最新的多个开源模型,包括gpt-neox,llama,llama-2,baichuan,Qwen,chatglm2等; -:white_check_mark: **多框架**:同时支持HuggingFace 和 [ATorch 框架](https://github.com/intelligent-machine-learning/dlrover); +:white_check_mark: **多框架**:既支持主流开源的Accelerate+DeepSpeed/FSDP,也支持新开源的[ATorch 框架](https://github.com/intelligent-machine-learning/dlrover); :white_check_mark: **高效微调**:支持LoRA和QLoRA,可以用很少的资源去微调很大的模型,且训练速度能满足几乎所有微调场景; 本项目主要内容如下: - 同时支持单任务SFT(Supervised FineTuning)和MFT(Multi-task FineTuning), 当前开源支持数据均衡,未来将持续开源难易均衡, 收敛均衡等 -- 支持QLoRA低成本高效指令微调、LoRA高效指令微调。 -- 支持绝大部分主流的开源大模型,重点关注代码能力优秀的开源大模型,如Qwen, GPT-Neox, Starcoder, Codegeex2, Code-LLaMA等。 +- 支持QLoRA低成本高效指令微调、LoRA高效指令微调、全量参数高精度微调。 +- 支持绝大部分主流的开源大模型,重点关注代码能力优秀的开源大模型,如DeepSeek-coder, Mistral, Mistral(MoE), Chatglm3, Qwen, GPT-Neox, Starcoder, Codegeex2, Code-LLaMA等。 - 支持lora与base model进行权重合并,推理更便捷。 - 整理并开源2个指令微调数据集:[Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k)和[CodeExercise-Python-27k](https://huggingface.co/datasets/codefuse-ai/CodeExercise-Python-27k)。 -- 开源2个[Codefuse系列指令微调模型权重]:[CodeFuse-13B](https://huggingface.co/codefuse-ai/CodeFuse-13B)和[CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B)。 +- 开源多个[Codefuse系列指令微调模型权重],具体参见我们的huggingface组织和modelscope组织下的模型:[codefuse-ai huggingface](https://huggingface.co/codefuse-ai) or [codefuse-ai 魔搭](https://modelscope.cn/organization/codefuse-ai)。 ## 环境 -首先, 你需要将CUDA(>=11.4, 推荐11.7)及其相关驱动安装成功,并确保其工作正常, 并且安装基本的torch(>=2.0.0) +首先, 你需要将CUDA(>=11.4, 推荐12.1)及其相关驱动安装成功,并确保其工作正常, 并且安装基本的torch(>=2.1.0) 在requirements.txt下固定了几个主要的python包的版本,执行如下脚本即可: ```bash sh init_env.sh ``` -如果希望使用flash attention, 安装请参考 https://github.com/Dao-AILab/flash-attention +我们强烈建议您安装flash attention(>=2.3.0), 安装请参考 https://github.com/Dao-AILab/flash-attention ## 训练 -🚀 [Huggingface accelerate + deepspeed Codebase for MFT(Multi-task Finetuning)](./mft_peft_hf/README.md) +如果你熟悉大模型训练的各种主流开源资源,例如 ```transformers```, ```DeepSpeed```, ```FSDP```等, 为了用开源项目快速上手高性能微调,我们建议您尝试: + +🚀🚀 [MFTCoder-accelerate: Accelerate + DeepSpeed/FSDP Codebase for MFT(Multi-task Finetuning)](mftcoder_accelerate/README.md) + -🚀 [Atorch Codebase for MFT(Multi-task Finetuning)](./mft_atorch/README.md) +如果你想探索一些新兴的训练框架,可以尝试: + +🚀 [MFTCoder-atorch: Atorch Codebase for MFT(Multi-task Finetuning)](mftcoder_atorch/README.md) ## 模型 -使用本项目的训练代码,以及上述训练数据,我们训练并在huggingface开源了以下模型。 +使用本项目的训练代码,以及上述训练数据,我们训练并在huggingface, modelscope开源了以下模型。 -| 模型 | 基座模型 | 训练数据 | Batch Size | Seq Length | -|---------------------------------------------------------------|----------------------|------|------------|------------| -| [🔥🔥🔥 
CodeFuse-CodeLlama-34B](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 60万 | 80 | 4096 | -| [🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | -| [🔥🔥🔥 CodeFuse-StarCoder-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | CodeLlama-34b-Python | 60万 | 80 | 4096 | -| [🔥🔥🔥 CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 | -| [🔥 CodeFuse-13B](https://huggingface.co/codefuse-ai/CodeFuse-13B) | CodeFuse-13B-Base | 6.6万 | 64 | 4096 | +| 模型 | HuggingFace链接 | 魔搭 链接 | 基座模型 | 训练数据 | Batch Size | Seq Length | +|--------------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|----------------------|------|------------|------------| +| 🔥🔥🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 60万 | 80 | 4096 | +| 🔥🔥🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 60万 | 80 | 4096 | +| 🔥🔥🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 60万 | 80 | 4096 | +| 🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | +| 🔥🔥🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 60万 | 80 | 4096 | +| 🔥🔥🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 | +| 🔥🔥🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 110万 | 256 | 4096 | +| 🔥🔥🔥 CodeFuse-StarCoder2-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder2-15B) | Starcoder2-15B | 70万 | 128 | 4096 | @@ -153,3 +185,15 @@ sh init_env.sh } ``` +## 加入我们 + +我们是平台技术事业群AI Native团队,负责蚂蚁蚂蚁集团平台工程的智能化,团队成立3年多以来,支持了蚂蚁集团云计算基础设施智能化运维的升级改造。团队的Mission是,通过世界级的技术创新和影响,构建有广泛用户的算法服务和平台,支撑内外部产品和业务落地。团队秉承创新基因,在支撑业务落地的同时,推动技术影响。3年以来在ICLR、NeurIPS、KDD、ACL等顶会发表论文20余篇,创新业务结果获得两次蚂蚁技术最高奖T-Star,1次蚂蚁集团最高奖SuperMA。开源项目CodeFuse获得4K点赞(2024年2月),Huggingface和modelscope上模型累积下载量超过150万次。 + +**我们正在寻找行业中的佼佼者加入我们的团队!如果您希望在一个充满活力、创新和卓越文化的环境中发展您的职业生涯,欢迎您查看我们的社招&校招机会,加入我们,一起创造下一个行业里程碑。** + +**校招**:https://hrrecommend.antgroup.com/guide.html?code=8uoP5mlus5DqQYbE_EnqcE2FD5JZH21MwvMUIb9mb6X3osXPuBraG54SyM8GLn_7 + +**社招**:https://talent.antgroup.com/off-campus-position?positionId=1933830 + +## 联系我们 + diff --git "a/assets/CodeFuse-AI\347\276\244.png" "b/assets/CodeFuse-AI\347\276\244.png" new file mode 100644 index 0000000..4e0c0a1 Binary files /dev/null and "b/assets/CodeFuse-AI\347\276\244.png" differ diff --git a/assets/img.jpg 
b/assets/img.jpg new file mode 100644 index 0000000..199cc8e Binary files /dev/null and b/assets/img.jpg differ diff --git a/assets/img.png b/assets/img.png deleted file mode 100644 index 0614694..0000000 Binary files a/assets/img.png and /dev/null differ diff --git a/assets/img_1.jpg b/assets/img_1.jpg new file mode 100644 index 0000000..bde7dac Binary files /dev/null and b/assets/img_1.jpg differ diff --git a/assets/img_1.png b/assets/img_1.png deleted file mode 100644 index 69ec9fd..0000000 Binary files a/assets/img_1.png and /dev/null differ diff --git a/init_env.sh b/init_env.sh index 7fdebf7..834b38d 100644 --- a/init_env.sh +++ b/init_env.sh @@ -1,4 +1,4 @@ -pip install torch==2.0.1 && \ -pip install tensorboard && \ +pip install torch==2.1.0 && \ +pip install tensorboard==2.11.0 && \ pip install packaging && \ -pip install -r requirements.txt \ No newline at end of file +pip install -r requirements.txt diff --git a/mft_peft_hf/README.md b/mft_peft_hf/README.md deleted file mode 100644 index 0d6705a..0000000 --- a/mft_peft_hf/README.md +++ /dev/null @@ -1,247 +0,0 @@ -# MFTCoder Training: Huggingface accelerate + DeepSpeed Framework -[](https://huggingface.co/codefuse-ai) - -= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + 
def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float32: 4, + np.float64: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, np_array): + assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype + bytes = self.out_file.write(np_array) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in np_array.shape: + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(np_array.shape)) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack("0: + self.i = max(map(lambda x: int(x.split("_")[1].split(".")[0]), os.listdir(out_dir))) + 1 + + def add_data(self, data): + self.data.append(data) + + def commit(self, archive_name=None): + # TODO: streaming + cctx = zstandard.ZstdCompressor(level=3) + + if archive_name is None: + archive_name = str(int(time.time())) + + res = b"".join( + map(lambda x: ("%016d" % len(x)).encode("UTF-8") + x, map(lambda x: x.encode("UTF-8"), self.data)) + ) + cdata = cctx.compress(res) + + with open(self.out_dir + "/data_" + str(self.i) + "_" + archive_name + ".dat.zst", "wb") as fh: + fh.write(cdata) + + self.i += 1 + self.data = [] + + +class JSONArchive: + def __init__(self, out_dir): + self.out_dir = out_dir + os.makedirs(out_dir, exist_ok=True) + self.data = [] + self.i = 0 + if os.path.exists(out_dir) and len(os.listdir(out_dir)) > 0: + self.i = max(map(lambda x: int(x.split("_")[1].split(".")[0]), os.listdir(out_dir))) + 1 + + def add_data(self, data): + self.data.append(data) + + def commit(self): + cctx = zstandard.ZstdCompressor(level=3) + + cdata = cctx.compress(json.dumps(self.data).encode("UTF-8")) + with open(self.out_dir + "/data_" + str(self.i) + "_" + str(int(time.time())) + ".json.zst", "wb") as fh: + fh.write(cdata) + + self.i += 1 + self.data = [] diff --git a/mft_peft_hf/src/data/gpt2_multi_task_dataset.py b/mftcoder_accelerate/src/data/multi_task_dataset.py similarity index 63% rename from mft_peft_hf/src/data/gpt2_multi_task_dataset.py rename to mftcoder_accelerate/src/data/multi_task_dataset.py index 
4bb3c66..63c4b27 100644 --- a/mft_peft_hf/src/data/gpt2_multi_task_dataset.py +++ b/mftcoder_accelerate/src/data/multi_task_dataset.py @@ -2,16 +2,19 @@ # @author Chaoyu Chen # @date 2023/8/18 +Load dataset in a distributed way. """ + import os import json import math import time +import glob import numpy as np import torch from functools import partial -from data.tokenization.preprocess_data import UniformEncoder -from utils.common_utils import TASK2ID, ID2TASK +from data.preprocess_data import UniformEncoder +from utils.common_utils import TASK2ID class GPT2FromRawDataset(torch.utils.data.Dataset): @@ -27,12 +30,12 @@ def __init__( self.name = name self.input_dataset = input_dataset - self.num_samples = len(self.input_dataset['input_ids']) + self.num_samples = len(self.input_dataset["input_ids"]) self.seq_length = seq_length self.weighted_loss_mode = weighted_loss_mode self.ds_weight = ds_weight - self.task_name = data_prefix.split('/')[-1] + self.task_name = data_prefix.split("/")[-1] self.task_id = TASK2ID[self.task_name] # Checks @@ -47,8 +50,7 @@ def __getitem__(self, idx): try: # Get the shuffled index. idx = idx % self.num_samples - idx_data = {key: self.input_dataset[key][idx] - for key in self.input_dataset} + idx_data = {key: self.input_dataset[key][idx] for key in self.input_dataset} if self.weighted_loss_mode: idx_data["weight"] = np.array([self.ds_weight], dtype=np.float32) @@ -115,9 +117,7 @@ def __init__(self, datasets, weights, global_num_samples, local_num_samples): print( "> RANK {} elapsed time for building blendable dataset indices: " - "{:.2f} (sec)".format( - torch.distributed.get_rank(), time.time() - start_time - ) + "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time) ) def calc_weights(self): @@ -146,32 +146,28 @@ def __getitem__(self, idx): def shuffle_arrays(arrays, set_seed=-1): - """Shuffles arrays in-place, in the same order, along axis=0 + """Shuffles arrays in-place, in the same order, along axis=0 - Parameters: - ----------- - arrays : List of NumPy arrays. - set_seed : Seed value if int >= 0, else seed is random. - """ - assert all(len(arr) == len(arrays[0]) for arr in arrays) - seed = np.random.randint(0, 2**(32 - 1) - 1) if set_seed < 0 else set_seed + Parameters: + ----------- + arrays : List of NumPy arrays. + set_seed : Seed value if int >= 0, else seed is random. 
+ """ + assert all(len(arr) == len(arrays[0]) for arr in arrays) + seed = np.random.randint(0, 2 ** (32 - 1) - 1) if set_seed < 0 else set_seed - for arr in arrays: - rstate = np.random.RandomState(seed) - rstate.shuffle(arr) + for arr in arrays: + rstate = np.random.RandomState(seed) + rstate.shuffle(arr) def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, local_rank=0): - # tokenization编码器 encoder = UniformEncoder(args, args.tokenize_mode) encoder.initializer() - data_prefixes = list(args.data_paths[1:-1].split(',')) - - # data_weights = list(map(float, args.data_weights[1:-1].split(','))) - # print("data weights: ") - # print(data_weights) + data_prefixes = list(args.data_paths[1:-1].split(",")) + splits = [] splits_string = args.data_split if splits_string.find(",") != -1: @@ -183,7 +179,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, while len(splits) < 3: splits.append(0.0) splits = splits[:3] - print(f'data splits: {splits}') + print(f"data splits: {splits}") all_train_datasets = [] all_valid_datasets = [] @@ -199,46 +195,50 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, # 不同数据集在不同文件夹下 for dataset_index in range(len(data_prefixes)): - files = os.listdir(data_prefixes[dataset_index]) + # files = os.listdir(data_prefixes[dataset_index]) + # get all jsonl files and corresponding reading handler + if data_prefixes[dataset_index].endswith(".jsonl"): + files = [data_prefixes[dataset_index]] + else: + files = glob.glob(os.path.join(data_prefixes[dataset_index], "**/*.jsonl"), recursive=True) + cur_dataset_input_ids = [] cur_dataset_loss_mask = [] # support multiple jsonl files under task dir - for file in files: - file_name = data_prefixes[dataset_index] + '/' + file - if os.path.isdir(file_name): - continue - fin = open(file_name, 'r') - print(f'[Global Rank {global_rank}] open file {file_name}') - - if args.padding_mode == 'padding' or args.padding_mode == 'pack': + for file_name in files: + fin = open(file_name, "r") + print(f"[Global Rank {global_rank}] open file {file_name}") + + if args.padding_mode == "padding" or args.padding_mode == "pack" or args.padding_mode == "concat": for i, line in enumerate(fin): # pre-sharding if shard_data and i % world_size != global_rank: continue - data = json.loads(line.rstrip('\n\r')) - features, length = encoder.encode(data) + data = json.loads(line.rstrip("\n\r")) + features, length = encoder.encode(data, verbose=(i < 1)) + # features, length = encoder.encode(data) # may have more samples - for idx in range(len(features['input_ids'])): - cur_dataset_input_ids.append(features['input_ids'][idx]) - cur_dataset_loss_mask.append(features['loss_mask'][idx]) - + for idx in range(len(features["input_ids"])): + cur_dataset_input_ids.append(features["input_ids"][idx]) + cur_dataset_loss_mask.append(features["loss_mask"][idx]) + fin.close() else: i = 0 for line in fin: - data = json.loads(line.rstrip('\n\r')) + data = json.loads(line.rstrip("\n\r")) features, length = encoder.encode(data) # 一个document可能编码不出sample,可能编码出多个sample - for idx in range(len(features['input_ids'])): + for idx in range(len(features["input_ids"])): # post-sharding if shard_data and i % world_size != global_rank: i += 1 continue i += 1 - cur_dataset_input_ids.append(features['input_ids'][idx]) - cur_dataset_loss_mask.append(features['loss_mask'][idx]) + cur_dataset_input_ids.append(features["input_ids"][idx]) + cur_dataset_loss_mask.append(features["loss_mask"][idx]) fin.close() - + 
cur_dataset_input_ids = np.array(cur_dataset_input_ids, dtype=np.float32) cur_dataset_loss_mask = np.array(cur_dataset_loss_mask, dtype=np.float32) cur_dataset_num_tokens = np.sum(cur_dataset_loss_mask, dtype=np.int32) @@ -246,60 +246,57 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, num_tokens.append(cur_dataset_num_tokens) total_sample_cnt.append(cur_dataset_sample_num) effective_token_rate.append(cur_dataset_num_tokens / (cur_dataset_sample_num * args.seq_length)) - + # shuffle before split shuffle_arrays([cur_dataset_input_ids, cur_dataset_loss_mask], args.seed) train_ratio = splits[0] / 100.0 train_num = int(math.ceil(train_ratio * cur_dataset_sample_num)) # split train/valid - cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[: train_num], cur_dataset_input_ids[train_num: ] - cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[: train_num], cur_dataset_loss_mask[train_num: ] + cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[:train_num], cur_dataset_input_ids[train_num:] + cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[:train_num], cur_dataset_loss_mask[train_num:] local_train_num += train_num - local_valid_num += (cur_dataset_sample_num - train_num) - - cur_train_dataset = {'input_ids': cur_train_input_ids, - 'loss_mask': cur_train_loss_mask - } - cur_valid_dataset = {'input_ids': cur_valid_input_ids, - 'loss_mask': cur_valid_loss_mask - } + local_valid_num += cur_dataset_sample_num - train_num + + cur_train_dataset = {"input_ids": cur_train_input_ids, "loss_mask": cur_train_loss_mask} + cur_valid_dataset = {"input_ids": cur_valid_input_ids, "loss_mask": cur_valid_loss_mask} print(f"[Global Rank {global_rank}]shape of cur train dataset: {cur_train_dataset['input_ids'].shape}") - print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}") + if local_valid_num > 0: + print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}") cur_train_ds = GPT2FromRawDataset( - 'train', + "train", data_prefixes[dataset_index], cur_train_dataset, args.seq_length, weighted_loss_mode=args.weighted_loss_mode, - ds_weight=splits[0] - ) - cur_valid_ds = GPT2FromRawDataset( - 'valid', - data_prefixes[dataset_index], - cur_valid_dataset, - args.seq_length, - weighted_loss_mode=args.weighted_loss_mode, - ds_weight=splits[1] + ds_weight=splits[0], ) - all_train_datasets.append(cur_train_ds) - all_valid_datasets.append(cur_valid_ds) all_train_datasets_length.append(len(cur_train_ds)) - all_valid_datasets_length.append(len(cur_valid_ds)) - - print(f'[Global Rank {global_rank}]num tokens: {num_tokens}') - print(f'[Global Rank {global_rank}]effective token rate: {effective_token_rate}') + if local_valid_num > 0: + cur_valid_ds = GPT2FromRawDataset( + "valid", + data_prefixes[dataset_index], + cur_valid_dataset, + args.seq_length, + weighted_loss_mode=args.weighted_loss_mode, + ds_weight=splits[1], + ) + all_valid_datasets.append(cur_valid_ds) + all_valid_datasets_length.append(len(cur_valid_ds)) + else: + cur_valid_ds = None + + print(f"[Global Rank {global_rank}]num tokens: {num_tokens}") + print(f"[Global Rank {global_rank}]effective token rate: {effective_token_rate}") num_tokens = [] ds_fn = partial(ds_weights_by_num_docs_sft) - train_loss_weights, valid_loss_weights = ( - ds_fn(all_train_datasets_length), - ds_fn(all_valid_datasets_length), - ) - + train_loss_weights = ds_fn(all_train_datasets_length) print(f"> train loss 
weights in rank {global_rank}: {train_loss_weights}") - print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}") + if all_valid_datasets_length: + valid_loss_weights = ds_fn(all_valid_datasets_length) + print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}") factor = 1 # calcualte common factor based on token cnt and total sample cnt @@ -307,51 +304,65 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length) factor /= sum([1.0 / w for w in train_loss_weights]) / len(train_loss_weights) print(f"> common denomination factor for CE loss in rank {global_rank}: {factor}") - + train_sample_weights = [x / sum(all_train_datasets_length) for x in all_train_datasets_length] - valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length] print(f"> train sample weights in rank {global_rank}: {train_sample_weights}") - print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}") + if all_valid_datasets_length: + valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length] + print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}") # recompute global_train_num and global_valid_num - + torch.distributed.barrier() device = f"cuda:{local_rank}" - + global_train_num_samples_tensor = torch.tensor(local_train_num, dtype=torch.int32) global_train_num_samples_tensor = global_train_num_samples_tensor.to(device) torch.distributed.all_reduce(global_train_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) global_train_num = global_train_num_samples_tensor.item() - - global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32) - global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device) - torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) - global_valid_num = global_valid_num_samples_tensor.item() print(f"> global train num in rank {global_rank}: {global_train_num}") - print(f"> global valid num in rank {global_rank}: {global_valid_num}") - + + if local_valid_num > 0: + global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32) + global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device) + torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) + global_valid_num = global_valid_num_samples_tensor.item() + print(f"> global valid num in rank {global_rank}: {global_valid_num}") + torch.distributed.barrier() - for i in range(len(all_train_datasets)): - print(f'loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}') blending_train_dataset = None if all_train_datasets: + for i in range(len(all_train_datasets)): + print( + f"loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}" + ) args.do_train = True for i in range(len(all_train_datasets)): all_train_datasets[i].update_ds_weight(train_loss_weights[i] / factor) - print(f'loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}') - blending_train_dataset = GPT2BlendableDataset(all_train_datasets, train_sample_weights, global_train_num, local_train_num) - - for i in range(len(all_train_datasets)): - print(f'loss weight of valid dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}') + print( + 
f"loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}" + ) + blending_train_dataset = GPT2BlendableDataset( + all_train_datasets, train_sample_weights, global_train_num, local_train_num + ) + blending_valid_dataset = None if all_valid_datasets: + for i in range(len(all_valid_datasets)): + print( + f"loss weight of valid dataset {i} before update in rank {global_rank}: {all_valid_datasets[i].ds_weight}" + ) args.do_valid = True for i in range(len(all_valid_datasets)): all_valid_datasets[i].update_ds_weight(valid_loss_weights[i] / factor) - print(f'loss weight of valid dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}') - blending_valid_dataset = GPT2BlendableDataset(all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num) - + print( + f"loss weight of valid dataset {i} after update in rank {global_rank}: {all_valid_datasets[i].ds_weight}" + ) + blending_valid_dataset = GPT2BlendableDataset( + all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num + ) + return blending_train_dataset, blending_valid_dataset @@ -360,9 +371,13 @@ def compile_helper(): is invoked on a single process.""" import os import subprocess + path = os.path.abspath(os.path.dirname(__file__)) ret = subprocess.run(["make", "-C", path]) if ret.returncode != 0: print("Making C++ dataset helpers module failed, exiting.") import sys + sys.exit(1) + else: + print("Making C++ dataset helpers module successfully.") diff --git a/mftcoder_accelerate/src/data/preprocess_data.py b/mftcoder_accelerate/src/data/preprocess_data.py new file mode 100644 index 0000000..f7226bd --- /dev/null +++ b/mftcoder_accelerate/src/data/preprocess_data.py @@ -0,0 +1,356 @@ +""" +# @author Chaoyu Chen +# @date 2023/9/13 +Preprocessing data and tokenization. 
+""" + +import os +import sys +import ftfy +import glob + +# print("In preprocess_data_new.py, sys path:", sys.path) + +from tokenizer import build_tokenizer + +CHAT_COL = "chat_rounds" +ROLE_COL = "role" +CONTENT_COL = "content" + +SYSTEM_COL = "system" +PROMPT_COL = "prompt" +ANSWER_COL = "answer" + +TEXT_COL = "text" + +table = {ord(f): ord(t) for f, t in zip(",。!?:【】()%#@&1234567890", ",.!?:[]()%#@&1234567890")} + + +def content_format(content: str): + # Replace non-breaking space with space + content = content.replace("\u202f", " ").replace("\xa0", " ") + + # change chinese punctuation to english ones + # text = text.translate(table) + # if not content.endswith("\n"): + content += "\n" + + return content + + +def is_text_format(data): + if "text" in data: + return True + else: + return False + + +def is_chatml_format(data): + if "chat_rounds" in data and len(data["chat_rounds"]) > 0: + return True + else: + return False + + +def is_prompt_answer_format(data): + if "prompt" in data and "answer" in data: + return True + else: + return False + + +def is_prompt_response_format(data): + if "prompt" in data and "response" in data: + return True + else: + return False + + +def is_input_output_format(data): + if "input" in data and "output" in data: + return True + else: + return False + + +def is_instruction_output_format(data): + if "instruction" in data and "output" in data: + return True + else: + return False + + +def is_instruction_response_format(data): + if "instruction" in data and "response" in data: + return True + else: + return False + + +def is_question_response_format(data): + if "question" in data and "response" in data: + return True + else: + return False + + +def is_question_answer_format(data): + if "question" in data and "answer" in data: + return True + else: + return False + + +def is_query_answer_format(data): + if "query" in data and "answer" in data: + return True + else: + return False + + +class Encoder(object): + tokenizer = None + + def __init__(self, args): + self.args = args + + def initializer(self): + # Use Encoder class as a container for global data + Encoder.tokenizer = build_tokenizer(self.args) + # self.tokenizer = build_tokenizer(self.args) + + def pure_encode(self, content): + return Encoder.tokenizer.encode(content, add_special_tokens=False) + + def encode(self, text): + if self.args.ftfy: + text = ftfy.fix_text(text) + ids = {} + for key in self.args.jsonl_keys: + doc_ids = [] + text_ids = self.pure_encode(text) + if len(text_ids) > 0: + doc_ids.append(text_ids) + if self.args.append_eod: + doc_ids[-1].append(Encoder.tokenizer.eos_token_id) + ids[key] = doc_ids + return ids, len(text) + + +class UniformEncoder(Encoder): + def __init__(self, args, mode="sft"): + super().__init__(args) + self.verbose = False + self.mode = mode + # seq_length + 1 for shifting + if args.load_raw_dataset: + self.seq_length = args.seq_length + 1 + self.stride = args.seq_length + else: + self.seq_length = args.seq_length + + self.remain_input_ids = [] + self.remain_loss_mask = [] + + def encode(self, data, verbose=False): + self.verbose = verbose + encode_res = {"input_ids": [], "loss_mask": []} + + if is_prompt_answer_format(data): + data_type = "prompt_answer" + elif is_prompt_response_format(data): + data_type = "prompt_response" + elif is_input_output_format(data): + data_type = "input_output" + elif is_instruction_output_format(data): + data_type = "instruction_output" + elif is_instruction_response_format(data): + data_type = "instruction_response" + elif 
is_question_response_format(data): + data_type = "question_response" + elif is_question_answer_format(data): + data_type = "question_answer" + elif is_query_answer_format(data): + data_type = "query_answer" + elif is_chatml_format(data): + data_type = "chatML" + elif is_text_format(data): + data_type = "text" + else: + raise ValueError( + f"data_type does not support" + f"please use chatML or prompt/answer, prompt/response, question/response, " + f"instruction/output, input/output, instruction/output or text(only for pretrain)" + ) + + length = 0 + if data_type == "chatML": + for chat in data["chat_rounds"]: + length += len(chat["content"]) + elif data_type == "text": + length += len(data["text"]) + else: + # update key + global PROMPT_COL, ANSWER_COL + PROMPT_COL, ANSWER_COL = tuple(data_type.split("_")) + length = len(data[PROMPT_COL]) + len(data[ANSWER_COL]) + + for token_res in self._tokenize_fields(data, data_type=data_type): + for k, v in token_res.items(): + encode_res[k].append(v) + + return encode_res, length + + def _tokenize_fields(self, data, data_type): + if self.mode == "sft": + if self.args.role_markers: + system_marker = self.args.role_markers["system"] + user_marker = self.args.role_markers["user"] + assistant_marker = self.args.role_markers["assistant"] + else: + system_marker = "system\n" + user_marker = "" + + @property + def eos_token_id(self): + return self.get_command("human\n" + assistant_marker = "") + out = out[: special_index] + token_length = len(tokenizer.encode_plus(out)["input_ids"]) + convert_tokens = convert_tokens[:token_length] + probs = probs[:token_length] + + if len(out) > 0 and out[0] == " ": + out = out[1:] + + convert_tokens = convert_tokens[1:] + probs = probs[1:] + return out @add_start_docstrings( """ The LLaMa Model transformer with a sequence classification head on top (linear layer). - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + [`AquilaForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a @@ -913,13 +1047,16 @@ def _reorder_cache(past_key_values, beam_idx): padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). 
""", - LLAMA_START_DOCSTRING, + AQUILA_START_DOCSTRING, ) -class LlamaForSequenceClassification(LlamaPreTrainedModel): +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->AQUILA,Llama->Aquila +class AquilaForSequenceClassification(AquilaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.model = LlamaModel(config) + self.model = AquilaModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing @@ -931,7 +1068,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -978,7 +1115,9 @@ def forward( sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( + logits.device + ) else: sequence_lengths = -1 @@ -1018,3 +1157,4 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + diff --git a/mft_peft_hf/src/model/baichuan/configuration_baichuan.py b/mftcoder_accelerate/src/model/baichuan2/configuration_baichuan.py similarity index 91% rename from mft_peft_hf/src/model/baichuan/configuration_baichuan.py rename to mftcoder_accelerate/src/model/baichuan2/configuration_baichuan.py index 9da5dfc..592d274 100644 --- a/mft_peft_hf/src/model/baichuan/configuration_baichuan.py +++ b/mftcoder_accelerate/src/model/baichuan2/configuration_baichuan.py @@ -2,6 +2,7 @@ from transformers.configuration_utils import PretrainedConfig + class BaichuanConfig(PretrainedConfig): model_type = "baichuan" keys_to_ignore_at_inference = ["past_key_values"] @@ -23,7 +24,7 @@ def __init__( eos_token_id=2, tie_word_embeddings=False, gradient_checkpointing=False, - use_xformers=False, + z_loss_weight=0, **kwargs, ): self.vocab_size = vocab_size @@ -36,8 +37,8 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache - self.gradient_checkpointing = gradient_checkpointing - self.use_xformers = use_xformers + self.z_loss_weight = z_loss_weight + self.gradient_checkpointing = (gradient_checkpointing,) super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, @@ -45,4 +46,3 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) - diff --git a/mftcoder_accelerate/src/model/baichuan2/generation_utils.py b/mftcoder_accelerate/src/model/baichuan2/generation_utils.py new file mode 100644 index 0000000..5771699 --- /dev/null +++ b/mftcoder_accelerate/src/model/baichuan2/generation_utils.py @@ -0,0 +1,83 @@ +from typing import List +from queue import Queue + +import torch + + +def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0): + def _parse_messages(messages, split_role="user"): + system, rounds = "", [] + round = [] + for i, message in enumerate(messages): + if message["role"] == "system": + assert i == 0 + system = message["content"] + continue + if message["role"] == split_role and round: + rounds.append(round) + round = [] + round.append(message) + if round: + rounds.append(round) + return system, 
rounds + + max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens + max_input_tokens = model.config.model_max_length - max_new_tokens + system, rounds = _parse_messages(messages, split_role="user") + system_tokens = tokenizer.encode(system) + max_history_tokens = max_input_tokens - len(system_tokens) + + history_tokens = [] + for round in rounds[::-1]: + round_tokens = [] + for message in round: + if message["role"] == "user": + round_tokens.append(model.generation_config.user_token_id) + else: + round_tokens.append(model.generation_config.assistant_token_id) + round_tokens.extend(tokenizer.encode(message["content"])) + if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens: + history_tokens = round_tokens + history_tokens # concat left + if len(history_tokens) < max_history_tokens: + continue + break + + input_tokens = system_tokens + history_tokens + if messages[-1]["role"] != "assistant": + input_tokens.append(model.generation_config.assistant_token_id) + input_tokens = input_tokens[-max_input_tokens:] # truncate left + return torch.LongTensor([input_tokens]).to(model.device) + + +class TextIterStreamer: + def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False): + self.tokenizer = tokenizer + self.skip_prompt = skip_prompt + self.skip_special_tokens = skip_special_tokens + self.tokens = [] + self.text_queue = Queue() + self.next_tokens_are_prompt = True + + def put(self, value): + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + else: + if len(value.shape) > 1: + value = value[0] + self.tokens.extend(value.tolist()) + self.text_queue.put( + self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)) + + def end(self): + self.text_queue.put(None) + + def __iter__(self): + return self + + def __next__(self): + value = self.text_queue.get() + if value is None: + raise StopIteration() + else: + return value + diff --git a/mftcoder_accelerate/src/model/baichuan2/modeling_baichuan.py b/mftcoder_accelerate/src/model/baichuan2/modeling_baichuan.py new file mode 100644 index 0000000..9f9875b --- /dev/null +++ b/mftcoder_accelerate/src/model/baichuan2/modeling_baichuan.py @@ -0,0 +1,827 @@ +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. + +from .configuration_baichuan import BaichuanConfig +from .generation_utils import build_chat_input, TextIterStreamer + +import math +from threading import Thread +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F +from transformers import PreTrainedModel, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.generation.utils import GenerationConfig +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging, ContextManagers + +import os +from contextlib import contextmanager +from accelerate import init_empty_weights + +logger = logging.get_logger(__name__) + +try: + from xformers import ops as xops +except ImportError: + xops = None + logger.warning( + "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers." 
+ ) + + +def _get_interleave(n): + def _get_interleave_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + _get_interleave_power_of_2(closest_power_of_2) + + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def _fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + +def _buffered_future_mask(tensor, maxpos, alibi, attn_heads): + _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1) + _future_mask = _future_mask.unsqueeze(0) + alibi + new_future_mask = _future_mask.to(tensor) + return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos] + + +def _gen_alibi_mask(tensor, n_head, max_pos): + slopes = torch.Tensor(_get_interleave(n_head)) + position_point = torch.arange(max_pos) - max_pos + 1 + position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1) + diag = torch.diag(position_point[0]) + position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + alibi = alibi.view(n_head, 1, max_pos) + alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1) + alibi_mask = alibi_mask.unsqueeze(0) + alibi + return alibi_mask + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, epsilon=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(hidden_size)) + self.epsilon = epsilon + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) + + # convert into half-precision + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class MLP(torch.nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class BaichuanAttention(torch.nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.model_max_length + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}" + ) + self.W_pack = torch.nn.Linear( + self.hidden_size, 3 * self.hidden_size, bias=False + ) + self.o_proj = torch.nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: 
Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = ( + proj.unflatten(-1, (3, self.hidden_size)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + ) + query_states = ( + proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + key_states = ( + proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + value_states = ( + proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + if xops is not None: + attn_weights = None + # query_states = query_states.transpose(1, 2) + # key_states = key_states.transpose(1, 2) + # value_states = value_states.transpose(1, 2) + # attn_output = xops.memory_efficient_attention( + # query_states, key_states, value_states, attn_bias=attention_mask + # ) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask) + attn_output = attn_output.transpose(1, 2) + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: + if q_len == 1: # inference with cache + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -1:, :] + else: + attention_mask = attention_mask[:, -1:, :] + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BaichuanLayer(torch.nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = BaichuanAttention(config=config) + self.mlp = MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, epsilon=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + 
hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BaichuanPreTrainedModel(PreTrainedModel): + config_class = BaichuanConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BaichuanLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, torch.nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, torch.nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BaichuanModel): + module.gradient_checkpointing = value + + +class BaichuanModel(BaichuanPreTrainedModel): + def __init__(self, config: BaichuanConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.n_head = config.num_attention_heads + self.embed_tokens = torch.nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = torch.nn.ModuleList( + [BaichuanLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + + self.gradient_checkpointing = config.gradient_checkpointing + self.post_init() + self.max_cache_pos = config.model_max_length + self.first_run = True + self.alibi_mask = None + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_alibi_mask(self, tensor, seq_length_with_past): + if self.training: + slopes = torch.Tensor(_get_interleave(self.n_head)) + position_point = ( + torch.arange(seq_length_with_past) - seq_length_with_past + 1 + ) + position_point = ( + position_point.unsqueeze(0) + .unsqueeze(0) + .expand(self.n_head, seq_length_with_past, -1) + ) + diag = torch.diag(position_point[0]) + position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose( + -1, -2 + ) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + mask = _buffered_future_mask( + tensor, seq_length_with_past, alibi, self.n_head + ) + else: + if self.first_run: + self.first_run = False + self.register_buffer( + "future_mask", + _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to( + tensor + ), + persistent=False, + ) + if seq_length_with_past > self.max_cache_pos: + self.max_cache_pos = seq_length_with_past + self.register_buffer( + "future_mask", + _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to( + tensor + ), + persistent=False, + ) + mask = self.future_mask[ + : self.n_head, :seq_length_with_past, :seq_length_with_past + ] + return mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: 
Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot provide both input_ids and inputs_embeds simultaneously" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You need to provide input_ids or inputs_embeds") + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + seq_length_with_past = seq_length + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.training: + if ( + self.alibi_mask is None + or self.alibi_mask.shape[-1] != seq_length_with_past + ): + self.alibi_mask = self.get_alibi_mask( + inputs_embeds, seq_length_with_past + ) + alibi_mask = self.alibi_mask + else: + alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) + + if attention_mask is not None: + if len(attention_mask.shape) == 2: + expanded_mask = attention_mask.to(alibi_mask.dtype) + expanded_mask = torch.tril( + torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0) + ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0) + else: + expanded_mask = attention_mask + bsz = inputs_embeds.size(0) + src_len, tgt_len = alibi_mask.size()[-2:] + expanded_mask = ( + expanded_mask.unsqueeze(1) + .expand(bsz, 1, src_len, tgt_len) + .to(alibi_mask.dtype) + ) + inverted_mask = 1.0 - expanded_mask + inverted_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min + ) + attention_mask = inverted_mask + alibi_mask.unsqueeze(0) + else: + attention_mask = alibi_mask + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class NormHead(nn.Module): + def __init__(self, hidden_size, vocab_size, bias=False): + super().__init__() + self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size))) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.first_flag = True + + def forward(self, hidden_states): + if self.training: + norm_weight = nn.functional.normalize(self.weight) + self.first_flag = True + elif self.first_flag: + self.first_flag = False + self.weight.data = nn.functional.normalize(self.weight) + norm_weight = self.weight + else: + norm_weight = self.weight + return nn.functional.linear(hidden_states, norm_weight) + +_init_weights = True +@contextmanager +def no_init_weights(_enable=True): + global _init_weights + old_init_weights = _init_weights + if _enable: + _init_weights = False + try: + yield + finally: + _init_weights = old_init_weights + + +class BaichuanForCausalLM(BaichuanPreTrainedModel): + def __init__(self, config, *model_args, **model_kwargs): + super().__init__(config, *model_args, **model_kwargs) + self.model = BaichuanModel(config) + self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False) + #if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']: + if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False): + try: + from .quantizer import quantize_offline, init_model_weight_int4 + except ImportError: + raise ImportError(f"Needs quantize_offline to run quantize.") + quantize_offline(self, 4) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head 
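A minimal standalone sketch (illustrative only, not lines from this patch) of the projection performed by the NormHead class added above: the output-projection weight rows are L2-normalized before the linear map (re-normalized every training step, cached once at inference via first_flag), so each logit depends on the direction of its vocabulary row rather than its scale. The toy sizes below are assumptions for demonstration only.

# Sketch of the NormHead projection above: L2-normalize each vocab row, then project.
# Toy sizes are illustrative assumptions, not values from this patch.
import torch
import torch.nn.functional as F

hidden_size, vocab_size = 8, 16
weight = torch.randn(vocab_size, hidden_size)     # same shape as NormHead.weight
hidden_states = torch.randn(2, 3, hidden_size)    # (batch, seq_len, hidden_size)

norm_weight = F.normalize(weight)                 # unit-norm rows, as in the training branch
logits = F.linear(hidden_states, norm_weight)     # (batch, seq_len, vocab_size)
print(logits.shape)                               # torch.Size([2, 3, 16])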
+ + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=False, + proxies=None, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder="", + _from_auto=False, + _from_pipeline=None, + **kwargs, + ) + else: + model_kwargs = kwargs + + if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']: + try: + from .quantizer import init_model_weight_int4 + from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map + from accelerate.utils import CustomDtype + from accelerate.utils import get_balanced_memory + except ImportError: + raise ImportError(f"Needs import model weight init func to run quantize.") + # Instantiate model. + init_contexts = [no_init_weights(_enable=True)] + init_contexts.append(init_empty_weights()) + with ContextManagers(init_contexts): + model = cls(config) + + model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin') + state_dict = torch.load(model_file, map_location="cpu") + model.is_quantized = True + + device_map = kwargs.pop("device_map", None) + torch_dtype = kwargs.pop("torch_dtype", None) + if device_map is not None: + kwargs = {"no_split_module_classes": model._no_split_modules} + target_dtype = CustomDtype.INT4 + max_memory = get_balanced_memory( + model, + dtype=target_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=None, + **kwargs, + ) + kwargs["max_memory"] = max_memory + device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs) + model = init_model_weight_int4(config, model, state_dict) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=False, + proxies=None, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder="", + _from_auto=False, + _from_pipeline=None, + **kwargs, + ) + except (OSError, TypeError): + logger.info( + "Generation config file not found, using a generation config created from the model config." 
+ ) + pass + + if device_map is not None: + dispatch_model(model, device_map=device_map) + + return model + + return super(BaichuanForCausalLM, cls).from_pretrained(pretrained_model_name_or_path, *model_args, + config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, + use_safetensors=use_safetensors, **kwargs) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + softmax_normalizer = shift_logits.max(-1).values ** 2 + z_loss = self.config.z_loss_weight * softmax_normalizer.mean() + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + z_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def quantize(self, bits: int): + try: + from .quantizer import quantize_online + except ImportError: + raise ImportError(f"Needs QLinear to run quantize.") + return quantize_online(self, bits) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + return tuple( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) + for layer_past in past_key_values + ) + + def _build_chat_input( + self, tokenizer, messages: List[dict], 
max_new_tokens: int = 0 + ): + max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens + max_input_tokens = self.config.model_max_length - max_new_tokens + max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens) + total_input, round_input = [], [] + for i, message in enumerate(messages[::-1]): + content_tokens = tokenizer.encode(message["content"]) + if message["role"] == "user": + round_input = ( + [self.generation_config.user_token_id] + + content_tokens + + round_input + ) + if ( + total_input + and len(total_input) + len(round_input) > max_input_tokens + ): + break + else: + total_input = round_input + total_input + if len(total_input) >= max_input_tokens: + break + else: + round_input = [] + elif message["role"] == "assistant": + round_input = ( + [self.generation_config.assistant_token_id] + + content_tokens + + [self.generation_config.eos_token_id] + + round_input + ) + else: + raise ValueError(f"message role not supported yet: {message['role']}") + total_input = total_input[-max_input_tokens:] # truncate left + total_input.append(self.generation_config.assistant_token_id) + total_input = torch.LongTensor([total_input]).to(self.device) + return total_input + + def chat(self, tokenizer, messages: List[dict], stream=False, + generation_config: Optional[GenerationConfig]=None): + generation_config = generation_config or self.generation_config + input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens) + if stream: + streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + Thread(target=self.generate, kwargs=dict( + inputs=input_ids, streamer=streamer, + generation_config=generation_config, + )).start() + return streamer + else: + outputs = self.generate(input_ids, generation_config=generation_config) + response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True) + return response diff --git a/mftcoder_accelerate/src/model/baichuan2/quantizer.py b/mftcoder_accelerate/src/model/baichuan2/quantizer.py new file mode 100644 index 0000000..232fc74 --- /dev/null +++ b/mftcoder_accelerate/src/model/baichuan2/quantizer.py @@ -0,0 +1,215 @@ +try: + import bitsandbytes as bnb + from bitsandbytes.nn.modules import Params4bit, Int8Params +except ImportError: + print('import bitsandbytes Error') + +from accelerate import init_empty_weights +import torch + +def Params4bitCuda(self, device): + self.data = self.data.cuda(device) + self.quant_state[0] = self.quant_state[0].cuda(device) + self.quant_state[4][0] = self.quant_state[4][0].cuda(device) + self.quant_state[4][1][0] = self.quant_state[4][1][0].cuda(device) + self.quant_state[4][1][1] = self.quant_state[4][1][1].cuda(device) + + self.quant_state[6] = self.quant_state[6].cuda(device) + return self + +class Linear4bitOnline(torch.nn.Module): + def __init__(self, weight, bias, quant_type): + super().__init__() + self.weight = Params4bit( + weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type + ) + self.compute_dtype = None + #self.weight.cuda(weight.device) + self.bias = bias + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, "quant_state", None) is None: + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." 
+ ) + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ) + + out = out.to(inp_dtype) + + return out + +class Linear8bitLtOnline(torch.nn.Module): + def __init__( + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): + super().__init__() + assert ( + not memory_efficient_backward + ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index + + # Necessary for stacked layers + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + weight.data, + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, + ) + self.bias = bias + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + return out + +def quantize_offline(model, bits: int): + assert (bits == 4), f'bits: {bits} is not supported' + + for i, layer in enumerate(model.model.layers): + layer.self_attn.W_pack = bnb.nn.Linear4bit( + layer.self_attn.W_pack.weight.shape[1], + layer.self_attn.W_pack.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + layer.self_attn.o_proj = bnb.nn.Linear4bit( + layer.self_attn.o_proj.weight.shape[1], + layer.self_attn.o_proj.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + + layer.mlp.gate_proj = bnb.nn.Linear4bit( + layer.mlp.gate_proj.weight.shape[1], + layer.mlp.gate_proj.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + layer.mlp.down_proj = bnb.nn.Linear4bit( + layer.mlp.down_proj.weight.shape[1], + layer.mlp.down_proj.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + layer.mlp.up_proj = bnb.nn.Linear4bit( + layer.mlp.up_proj.weight.shape[1], + layer.mlp.up_proj.weight.shape[0], + False, + torch.float16, + compress_statistics=True, + quant_type="nf4", + ) + return model + +def quantize_online(model, bits: int): + def quant(weight, bias=None): + if bits == 8: + linear = Linear8bitLtOnline( + weight, + bias, + has_fp16_weights=False, + threshold=6.0, + ) + if bias is not None: + linear.bias = torch.nn.Parameter(bias) + elif bits == 4: + linear = Linear4bitOnline( + weight, + bias, + quant_type="nf4", #fp4/nf4 + ) + else: + raise ValueError("quantize 
only support 4/8 bit") + return linear + + for i, layer in enumerate(model.model.layers): + layer.self_attn.W_pack = quant(layer.self_attn.W_pack.weight) + layer.self_attn.o_proj = quant(layer.self_attn.o_proj.weight) + layer.mlp.gate_proj = quant(layer.mlp.gate_proj.weight) + layer.mlp.down_proj = quant(layer.mlp.down_proj.weight) + layer.mlp.up_proj = quant(layer.mlp.up_proj.weight) + return model + +def init_model_weight_int4(config, model, state_dict): + #replace Params4bit.cuda with Params4bitCuda + Params4bit.cuda = Params4bitCuda + + for i in range(config.num_hidden_layers): + weight_data = state_dict[f'model.layers.{i}.self_attn.W_pack.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.self_attn.W_pack.weight.quant_state'] + model.model.layers[i].self_attn.W_pack.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + weight_data = state_dict[f'model.layers.{i}.self_attn.o_proj.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.self_attn.o_proj.weight.quant_state'] + model.model.layers[i].self_attn.o_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + weight_data = state_dict[f'model.layers.{i}.mlp.gate_proj.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.mlp.gate_proj.weight.quant_state'] + model.model.layers[i].mlp.gate_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + weight_data = state_dict[f'model.layers.{i}.mlp.up_proj.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.mlp.up_proj.weight.quant_state'] + model.model.layers[i].mlp.up_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + weight_data = state_dict[f'model.layers.{i}.mlp.down_proj.weight.data'] + weight_quant_state = state_dict[f'model.layers.{i}.mlp.down_proj.weight.quant_state'] + model.model.layers[i].mlp.down_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state) + + model.model.layers[i].input_layernorm.weight = state_dict[f'model.layers.{i}.input_layernorm.weight'] + model.model.layers[i].post_attention_layernorm.weight = state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] + + model.model.embed_tokens.weight = state_dict['model.embed_tokens.weight'] + model.model.norm.weight = state_dict['model.norm.weight'] + model.lm_head.weight = state_dict['lm_head.weight'] + return model \ No newline at end of file diff --git a/mft_peft_hf/src/model/baichuan/tokenization_baichuan.py b/mftcoder_accelerate/src/model/baichuan2/tokenization_baichuan.py similarity index 85% rename from mft_peft_hf/src/model/baichuan/tokenization_baichuan.py rename to mftcoder_accelerate/src/model/baichuan2/tokenization_baichuan.py index 1275eb0..978963e 100644 --- a/mft_peft_hf/src/model/baichuan/tokenization_baichuan.py +++ b/mftcoder_accelerate/src/model/baichuan2/tokenization_baichuan.py @@ -48,10 +48,26 @@ def __init__( **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + bos_token = ( + 
AddedToken(bos_token, lstrip=False, rstrip=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False) + if isinstance(pad_token, str) + else pad_token + ) super().__init__( bos_token=bos_token, eos_token=eos_token, @@ -122,7 +138,9 @@ def convert_tokens_to_string(self, tokens): out_string += self.sp_model.decode(current_sub_tokens) return out_string - def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + def save_vocabulary( + self, save_directory, filename_prefix: Optional[str] = None + ) -> Tuple[str]: """ Save the vocabulary and special tokens file to a directory. @@ -137,10 +155,14 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) logger.error(f"Vocabulary path ({save_directory}) should be a directory") return out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -161,7 +183,10 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): return output def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. 
This method is called when adding @@ -180,7 +205,9 @@ def get_special_tokens_mask( """ if already_has_special_tokens: return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, ) bos_token_id = [1] if self.add_bos_token else [] @@ -229,4 +256,3 @@ def create_token_type_ids_from_sequences( output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) return output - diff --git a/mft_peft_hf/src/model/chatglm2/config.json b/mftcoder_accelerate/src/model/chatglm2/config.json similarity index 100% rename from mft_peft_hf/src/model/chatglm2/config.json rename to mftcoder_accelerate/src/model/chatglm2/config.json diff --git a/mft_peft_hf/src/model/chatglm2/configuration_chatglm.py b/mftcoder_accelerate/src/model/chatglm2/configuration_chatglm.py similarity index 100% rename from mft_peft_hf/src/model/chatglm2/configuration_chatglm.py rename to mftcoder_accelerate/src/model/chatglm2/configuration_chatglm.py diff --git a/mft_peft_hf/src/model/chatglm2/modeling_chatglm.py b/mftcoder_accelerate/src/model/chatglm2/modeling_chatglm.py similarity index 100% rename from mft_peft_hf/src/model/chatglm2/modeling_chatglm.py rename to mftcoder_accelerate/src/model/chatglm2/modeling_chatglm.py diff --git a/mft_peft_hf/src/model/chatglm2/quantization.py b/mftcoder_accelerate/src/model/chatglm2/quantization.py similarity index 100% rename from mft_peft_hf/src/model/chatglm2/quantization.py rename to mftcoder_accelerate/src/model/chatglm2/quantization.py diff --git a/mft_peft_hf/src/model/chatglm2/tokenization_chatglm.py b/mftcoder_accelerate/src/model/chatglm2/tokenization_chatglm.py similarity index 99% rename from mft_peft_hf/src/model/chatglm2/tokenization_chatglm.py rename to mftcoder_accelerate/src/model/chatglm2/tokenization_chatglm.py index d4ce416..01d8278 100644 --- a/mft_peft_hf/src/model/chatglm2/tokenization_chatglm.py +++ b/mftcoder_accelerate/src/model/chatglm2/tokenization_chatglm.py @@ -66,7 +66,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask", "position_ids"] def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs): - super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs) + self.name = "GLMTokenizer" self.vocab_file = vocab_file @@ -76,6 +76,7 @@ def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces "bot\n" + elif self.mode == "pretrain": + system_marker = "" + user_marker = "" + assistant_marker = "" + else: + raise ValueError(f"tokenize_mode does not support {self.mode}, please use sft or pretrain") + + sft_end_marker_ids = [Encoder.tokenizer.eos_token_id] + # uniform SST,SFT,MFT + input_ids = [] + loss_mask = [] + + if data_type == "chatML": + chat = data[CHAT_COL] + if chat[0][ROLE_COL] == "system": + sys_content_ids = self.pure_encode(system_marker + content_format(chat[0][CONTENT_COL])) + chat = chat[1:] + input_ids += sys_content_ids + loss_mask += [0] * len(sys_content_ids) + + for i, r in enumerate(chat): + role = r[ROLE_COL] + content = r[CONTENT_COL] + content = content_format(content) + if (role == "human" or role == "user") != (i % 2 == 0): + raise ValueError( + "Conversation roles must alternate user/assistant/user/assistant/... 
or human/bot/human/bot/...')" + ) + + # compute loss only for assistant's content and eos token afterward + if role == "human" or role == "user": + content_ids = self.pure_encode(user_marker + content + assistant_marker) + input_ids += content_ids + loss_mask += [0] * len(content_ids) + elif role == "bot" or role == "assistant" or role == "gpt": + content_ids = self.pure_encode(content) + sft_end_marker_ids + input_ids += content_ids + loss_mask += [1] * len(content_ids) + extra_ids = self.pure_encode("\n") + input_ids += extra_ids + loss_mask += [0] * len(extra_ids) + else: + raise ValueError(f"Role {role} not supported.") + + elif data_type == "text": + text = data[TEXT_COL] + text = content_format(text) + text_ids = self.pure_encode(text) + sft_end_marker_ids + input_ids += text_ids + loss_mask += [1] * len(text_ids) + else: + system = data.get(SYSTEM_COL, "") + prompt = data[PROMPT_COL] + answer = data[ANSWER_COL] + + system = content_format(system_marker + system) if system else "" + prompt = content_format(prompt) + answer = content_format(answer) + + prompt_ids = self.pure_encode(system + user_marker + prompt + assistant_marker) + answer_ids = self.pure_encode(answer) + sft_end_marker_ids + + input_ids += prompt_ids + answer_ids + loss_mask += [0] * len(prompt_ids) + [1] * len(answer_ids) + + # print(self.mode) + if self.mode == "pretrain": + # change loss mask to all 1s + input_ids = input_ids + loss_mask = [1] * len(loss_mask) + elif self.mode == "sft": + # do nothing + input_ids = input_ids + loss_mask = loss_mask + + if self.verbose: + print(f"original data:\n{data}") + print(f"decoding back:\n{Encoder.tokenizer.decode(input_ids)}") + + assert len(input_ids) == len(loss_mask) + if self.args.padding_mode == "padding": + if len(input_ids) <= self.seq_length: + yield self.padding(input_ids, loss_mask) + + # drop if too long + else: + yield {} + elif self.args.padding_mode == "concat": + input_ids = self.remain_input_ids + input_ids + loss_mask = self.remain_loss_mask + loss_mask + if len(input_ids) < self.seq_length: + self.remain_input_ids = input_ids + self.remain_loss_mask = loss_mask + assert len(self.remain_input_ids) == len(self.remain_loss_mask) + yield {} + else: + cursor = 0 + while cursor + self.seq_length <= len(input_ids): + yield { + "input_ids": input_ids[cursor : cursor + self.seq_length], + "loss_mask": loss_mask[cursor : cursor + self.seq_length], + } + cursor = cursor + self.stride + self.remain_input_ids = input_ids[cursor:] + self.remain_loss_mask = loss_mask[cursor:] + assert len(self.remain_input_ids) == len(self.remain_loss_mask) + yield {} + elif self.args.padding_mode == "pack": + if len(input_ids) > self.seq_length: + yield {} + elif len(self.remain_input_ids) + len(input_ids) > self.seq_length: + input_ids, self.remain_input_ids = self.remain_input_ids, input_ids + loss_mask, self.remain_loss_mask = self.remain_loss_mask, loss_mask + assert len(input_ids) == len(loss_mask) + yield self.padding(input_ids, loss_mask) + else: + self.remain_input_ids = self.remain_input_ids + input_ids + self.remain_loss_mask = self.remain_loss_mask + loss_mask + assert len(self.remain_input_ids) == len(self.remain_loss_mask) + yield {} + + def padding(self, input_ids, loss_mask): + pad_id = Encoder.tokenizer.pad_token_id + assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)} > {self.seq_length}" + input_ids += [pad_id] * (self.seq_length - len(input_ids)) + loss_mask += [0] * (self.seq_length - len(loss_mask)) + return {"input_ids": input_ids, 
"loss_mask": loss_mask} + + +def find_jsonl_fnames(paths): + fnames = [] + for p in paths: + if not os.path.isdir(p): + if p.endswith(".jsonl"): + print(f"loading from {p}") + fnames.append(p) + else: + p_list = glob.glob(p + "/*") + for p_ in p_list: + if p_.endswith(".jsonl"): + print(f"loading from {p_}") + fnames.append(p_) + return fnames diff --git a/mftcoder_accelerate/src/ds_multinode_launch.sh b/mftcoder_accelerate/src/ds_multinode_launch.sh new file mode 100755 index 0000000..dca0670 --- /dev/null +++ b/mftcoder_accelerate/src/ds_multinode_launch.sh @@ -0,0 +1,44 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2024/5/20 +# Description: # Launch script on Multiple Nodes + +# Run this script on all Nodes. + +# You need to export your number of nodes and number of GPUs per node first. +N_NODE=4 +N_GPU_PER_NODE=8 + +# You need to export $MACHINE_RANK, $MASTER_ADDR, $MASTER_PORT automatically for each Node. + +# config path +CONFIG="configs/xxx_train_config.json" + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_ds_config.yaml \ +accelerate launch \ + --num_machines $N_NODE \ + --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \ + --use_deepspeed \ + --deepspeed_multinode_launcher 'standard' \ + --zero_stage 2 \ + --offload_optimizer_device 'cpu' \ + --offload_param_device 'none' \ + --gradient_accumulation_steps 1 \ + --gradient_clipping 1.0 \ + --zero3_init_flag false \ + --zero3_save_16bit_model false \ + --main_training_function 'main' \ + --mixed_precision 'bf16' \ + --dynamo_backend 'no' \ + --same_network \ + --machine_rank $MACHINE_RANK \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + --rdzv_backend 'static' \ + pefts/mft_accelerate.py --train_config "$CONFIG" --distributed_type "deepspeed" \ No newline at end of file diff --git a/mftcoder_accelerate/src/ds_single_launch.sh b/mftcoder_accelerate/src/ds_single_launch.sh new file mode 100755 index 0000000..d6c84bb --- /dev/null +++ b/mftcoder_accelerate/src/ds_single_launch.sh @@ -0,0 +1,38 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2023/12/11 +# Description: An alternative(Command line) way to launch DeepSpeed training + +# Launch script on single node +N_GPU_PER_NODE=8 + +# config path +CONFIG="configs/xxx_train_config.json" + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_ds_config.yaml \ +accelerate launch \ + --num_machines 1 \ + --num_processes $N_GPU_PER_NODE \ + --use_deepspeed \ + --zero_stage 2 \ + --offload_optimizer_device 'cpu' \ + --offload_param_device 'none' \ + --gradient_accumulation_steps 1 \ + --gradient_clipping 1.0 \ + --zero3_init_flag false \ + --zero3_save_16bit_model false \ + --main_training_function 'main' \ + --mixed_precision 'bf16' \ + --dynamo_backend 'no' \ + --same_network \ + --machine_rank 0 \ + --rdzv_backend 'static' \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ + --distributed_type "deepspeed" \ + > MFTCoder-training-"$TODAY".log 2>&1 & diff --git a/mftcoder_accelerate/src/ds_zero3_single_launch.sh b/mftcoder_accelerate/src/ds_zero3_single_launch.sh new file mode 100755 index 0000000..5f581c9 --- /dev/null +++ b/mftcoder_accelerate/src/ds_zero3_single_launch.sh @@ -0,0 +1,38 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2024/5/20 +# Description: An alternative(Command line) way to launch 
DeepSpeed training + +# Launch script on single node +N_GPU_PER_NODE=8 + +# config path +CONFIG="configs/xxx_train_config.json" + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_ds_config.yaml \ +accelerate launch \ + --num_machines 1 \ + --num_processes $N_GPU_PER_NODE \ + --use_deepspeed \ + --zero_stage 3 \ + --offload_optimizer_device 'cpu' \ + --offload_param_device 'cpu' \ + --gradient_accumulation_steps 1 \ + --gradient_clipping 1.0 \ + --zero3_init_flag true \ + --zero3_save_16bit_model true \ + --main_training_function 'main' \ + --mixed_precision 'bf16' \ + --dynamo_backend 'no' \ + --same_network \ + --machine_rank 0 \ + --rdzv_backend 'static' \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ + --distributed_type "deepspeed" \ + > MFTCoder-training-"$TODAY".log 2>&1 & diff --git a/mftcoder_accelerate/src/fsdp_single_launch.sh b/mftcoder_accelerate/src/fsdp_single_launch.sh new file mode 100755 index 0000000..2dc8f89 --- /dev/null +++ b/mftcoder_accelerate/src/fsdp_single_launch.sh @@ -0,0 +1,43 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2023/12/11 +# Description: An alternative(command line) way to launch FSDP training + +# Launch script on single node +N_GPU_PER_NODE=8 + +# config path +CONFIG="configs/xxx_train_config.json" + +# fsdp_transformer_layer_cls_to_wrap, choose the DecoderLayer +WRAP_MODULE="LlamaDecoderLayer" + + + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_fsdp_config.yaml \ +accelerate launch \ + --use_fsdp \ + --num_machines=1 \ + --num_processes=$N_GPU_PER_NODE \ + --fsdp_sharding_strategy=1 \ + --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \ + --fsdp_state_dict_type=FULL_STATE_DICT \ + --fsdp_backward_prefetch_policy=BACKWARD_PRE \ + --fsdp_transformer_layer_cls_to_wrap=$WRAP_MODULE \ + --fsdp_offload_params=false \ + --main_training_function=main \ + --mixed_precision=bf16 \ + --dynamo_backend=no \ + --same_network \ + --machine_rank=0 \ + --rdzv_backend=static \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ + --distributed_type "fsdp" \ + > MFTCoder-training-"$TODAY".log 2>&1 & + diff --git a/mft_peft_hf/src/model/__init__.py b/mftcoder_accelerate/src/model/__init__.py similarity index 100% rename from mft_peft_hf/src/model/__init__.py rename to mftcoder_accelerate/src/model/__init__.py diff --git a/mft_peft_hf/src/model/llama2/configuration_llama.py b/mftcoder_accelerate/src/model/aquila2/configuration_aquila.py similarity index 53% rename from mft_peft_hf/src/model/llama2/configuration_llama.py rename to mftcoder_accelerate/src/model/aquila2/configuration_aquila.py index 83132bd..62097dd 100644 --- a/mft_peft_hf/src/model/llama2/configuration_llama.py +++ b/mftcoder_accelerate/src/model/aquila2/configuration_aquila.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its @@ -17,22 +17,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" LLaMA model configuration""" +""" Aquila model configuration""" -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging +from transformers import PretrainedConfig -logger = logging.get_logger(__name__) -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): +class AquilaConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + This is the configuration class to store the configuration of a [`AquilaModel`]. It is used to instantiate an Aquila model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. + defaults will yield a similar configuration to that of the Aquila-7B. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -40,8 +35,8 @@ class LlamaConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] + Vocabulary size of the Aquila model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`AquilaModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 11008): @@ -50,19 +45,6 @@ class LlamaConfig(PretrainedConfig): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): @@ -77,35 +59,26 @@ class LlamaConfig(PretrainedConfig): relevant if `config.is_decoder=True`. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear and dynamic. 
Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - Example: ```python - >>> from transformers import LlamaModel, LlamaConfig + >>> from transformers import AquilaModel, AquilaConfig - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() + >>> # Initializing a Aquila aquila-7b style configuration + >>> configuration = AquilaConfig() - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) + >>> # Initializing a model from the aquila-7b style configuration + >>> model = AquilaModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" - model_type = "llama" + model_type = "aquila" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, - vocab_size=32000, + vocab_size=100008, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, @@ -121,6 +94,7 @@ def __init__( eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, + rope_theta=10000.0, rope_scaling=None, use_xformers=True, **kwargs, @@ -130,21 +104,22 @@ def __init__( self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_xformers = use_xformers # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads + + self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.pretraining_tp = pretraining_tp self.use_cache = use_cache + self.rope_theta = rope_theta self.rope_scaling = rope_scaling - self._rope_scaling_validation() + self.use_xformers = use_xformers super().__init__( pad_token_id=pad_token_id, @@ -154,23 +129,3 @@ def __init__( **kwargs, ) - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/mft_peft_hf/src/model/llama2/modeling_llama.py b/mftcoder_accelerate/src/model/aquila2/modeling_aquila.py similarity index 78% rename from mft_peft_hf/src/model/llama2/modeling_llama.py rename to mftcoder_accelerate/src/model/aquila2/modeling_aquila.py index 6aa560d..d0d9309 100644 --- a/mft_peft_hf/src/model/llama2/modeling_llama.py +++ b/mftcoder_accelerate/src/model/aquila2/modeling_aquila.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its @@ -17,12 +17,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch LLaMA model.""" +""" PyTorch Aquila model.""" import math from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -31,13 +30,23 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig +from .configuration_aquila import AquilaConfig +from transformers import ( + LogitsProcessorList, + MinLengthLogitsProcessor, + TopKLogitsWarper, + TemperatureLogitsWarper, + TopPLogitsWarper, + StoppingCriteriaList, + MaxLengthCriteria, + BitsAndBytesConfig, +) import xformers.ops logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "LlamaConfig" +_CONFIG_FOR_DOC = "AquilaConfig" # Copied from transformers.models.bart.modeling_bart._make_causal_mask @@ -48,7 +57,7 @@ def _make_causal_mask( Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -73,10 +82,11 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) -class LlamaRMSNorm(nn.Module): +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Aquila +class AquilaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - LlamaRMSNorm is equivalent to T5LayerNorm + AquilaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -84,13 +94,14 @@ def __init__(self, hidden_size, eps=1e-6): def forward(self, hidden_states): input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + + return (self.weight * hidden_states).to(input_dtype) -class LlamaRotaryEmbedding(torch.nn.Module): +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Aquila +class AquilaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -98,7 +109,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) + self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( @@ -125,9 +136,9 @@ def forward(self, x, seq_len=None): self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Aquila +class AquilaLinearScalingRotaryEmbedding(AquilaRotaryEmbedding): + """AquilaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor @@ -144,9 +155,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Aquila +class AquilaDynamicNTKScalingRotaryEmbedding(AquilaRotaryEmbedding): + """AquilaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor @@ -160,7 +171,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) + self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) @@ -189,10 +200,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed -class LlamaMLP(nn.Module): +# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Aquila +class AquilaMLP(nn.Module): def __init__(self, config): super().__init__() - self.pretraining_tp = config.pretraining_tp + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -201,17 +213,21 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) up_proj_slices = self.up_proj.weight.split(slice, dim=0) down_proj_slices = self.down_proj.weight.split(slice, dim=1) - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] down_proj = sum(down_proj) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) @@ -231,10 +247,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class LlamaAttention(nn.Module): +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Aquila +class AquilaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): + def __init__(self, config: AquilaConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -242,8 +258,8 @@ def __init__(self, config: LlamaConfig): self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta self.use_xformers = config.use_xformers if (self.head_dim * self.num_heads) != 
self.hidden_size: @@ -259,17 +275,27 @@ def __init__(self, config: LlamaConfig): def _init_rope(self): if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + self.rotary_emb = AquilaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + self.rotary_emb = AquilaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, ) elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + self.rotary_emb = AquilaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") @@ -288,19 +314,21 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: @@ -334,12 +362,13 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask(), - op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) + attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) attn_output = attn_output.contiguous().view(bsz, q_len, -1) else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - + 
attn_weights = torch.clamp(attn_weights, min=-1024., max=1024.) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" @@ -366,10 +395,10 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) @@ -379,14 +408,15 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Aquila +class AquilaDecoderLayer(nn.Module): + def __init__(self, config: AquilaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = AquilaAttention(config=config) + self.mlp = AquilaMLP(config) + self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -442,8 +472,7 @@ def forward( return outputs - -LLAMA_START_DOCSTRING = r""" +AQUILA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -453,7 +482,7 @@ def forward( and behavior. Parameters: - config ([`LlamaConfig`]): + config ([`AquilaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
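As a quick aid for following the `AquilaRotaryEmbedding`, `AquilaLinearScalingRotaryEmbedding` and `AquilaDynamicNTKScalingRotaryEmbedding` hunks above, here is a minimal, self-contained sketch of the cos/sin cache math they implement. The dimensions, base and scaling factor below are example values, not values from the patch; the dynamic-NTK base rescaling mirrors the formula visible in the diff, and the linear-scaling division of positions follows the upstream LLaMA classes these are copied from.

```python
# Illustrative sketch (example values): how the rotary cos/sin caches are built.
# "linear" divides the position index by the factor; "dynamic" NTK rescales the base
# once seq_len exceeds max_position_embeddings.
import torch

def rope_cos_sin(seq_len, dim=128, base=10000.0, max_pos=2048, scaling=None, factor=1.0):
    if scaling == "dynamic" and seq_len > max_pos:
        # NTK-aware rescaling of the base, as in AquilaDynamicNTKScalingRotaryEmbedding
        base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=inv_freq.dtype)
    if scaling == "linear":
        # Linear scaling stretches the position index by the factor
        t = t / factor
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # duplicated halves, as in the cached cos/sin
    return emb.cos(), emb.sin()

cos, sin = rope_cos_sin(4096, scaling="dynamic", factor=2.0)
print(cos.shape)  # torch.Size([4096, 128])
```

With `rope_scaling=None` (the default), this reduces to the plain cache that `AquilaRotaryEmbedding` precomputes at init, with the base taken from the new `rope_theta` field.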
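The attention forward in this hunk keeps two paths: an xformers fast path with a causal `LowerTriangularMask`, and an eager path that clamps the raw attention weights to [-1024, 1024] before the softmax. Below is a hedged sketch of the xformers call only; the tensor shapes, dtype and CUDA device are assumptions for illustration, and it lets xformers pick a kernel instead of pinning `MemoryEfficientAttentionFlashAttentionOp` as the patch does.

```python
# Sketch of the xformers fast path (example shapes; assumes a CUDA device and fp16).
# Inputs are laid out [batch, seq_len, num_heads, head_dim], matching the transposed
# query/key/value states in AquilaAttention.forward above.
import torch
import xformers.ops

bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64
q = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# LowerTriangularMask applies causal masking without materializing a full attention mask.
out = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=xformers.ops.LowerTriangularMask()
)
print(out.shape)  # torch.Size([2, 16, 8, 64])
```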
@@ -461,14 +490,15 @@ def forward( @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, + "The bare Aquila Model outputting raw hidden-states without any specific head on top.", + AQUILA_START_DOCSTRING, ) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Aquila +class AquilaPreTrainedModel(PreTrainedModel): + config_class = AquilaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] + _no_split_modules = ["AquilaDecoderLayer"] _skip_keys_device_placement = "past_key_values" def _init_weights(self, module): @@ -483,11 +513,11 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): + if isinstance(module, AquilaModel): module.gradient_checkpointing = value -LLAMA_INPUTS_DOCSTRING = r""" +AQUILA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -552,25 +582,26 @@ def _set_gradient_checkpointing(self, module, value=False): @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, + "The bare Aquila Model outputting raw hidden-states without any specific head on top.", + AQUILA_START_DOCSTRING, ) -class LlamaModel(LlamaPreTrainedModel): +# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->AQUILA,Llama->Aquila +class AquilaModel(AquilaPreTrainedModel): """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`AquilaDecoderLayer`] Args: - config: LlamaConfig + config: AquilaConfig """ - def __init__(self, config: LlamaConfig): + def __init__(self, config: AquilaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layers = nn.ModuleList([AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -582,7 +613,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -606,7 +636,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em return combined_attention_mask - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -689,7 +719,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, output_attentions, None) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -698,7 +728,6 @@ def custom_forward(*inputs): hidden_states, attention_mask, position_ids, - None, ) else: layer_outputs = decoder_layer( @@ -734,14 +763,13 @@ def custom_forward(*inputs): attentions=all_self_attns, ) - -class LlamaForCausalLM(LlamaPreTrainedModel): +# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->AQUILA,Llama->Aquila +class AquilaForCausalLM(AquilaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp + self.model = AquilaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -766,7 +794,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -793,18 +821,18 @@ def forward( Example: ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from transformers import AutoTokenizer, AquilaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> model = AquilaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> prompt = "Hey, are you consciours? Can you talk to me?" 
>>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -827,9 +855,9 @@ def forward( ) hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) @@ -899,12 +927,118 @@ def _reorder_cache(past_key_values, beam_idx): ) return reordered_past + def predict(self, text, tokenizer=None, + max_gen_len=200, top_p=0.95, + seed=1234, topk=100, + temperature=0.9, + sft=True, convo_template = "aquila-chat", + device = "cuda"): + + vocab = tokenizer.get_vocab() + #device = device + id2word = {v:k for k, v in vocab.items()} + + + set_random_seed(seed) + if temperature == 0: + topk = 1 + temperature = 1.0 + if sft: + tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=2048, convo_template=convo_template) + tokens = torch.tensor(tokens)[None,].to(device) + else : + tokens = tokenizer.encode_plus(text)["input_ids"] + print(tokenizer.decode(tokens)) + tokens = torch.tensor(tokens)[None,].to(device) + input_length = len(tokens[0]) + with torch.no_grad(): + + # instantiate logits processors + logits_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(1, eos_token_id=100007), + ] + ) + # instantiate logits processors + logits_warper = LogitsProcessorList( + [ + TopPLogitsWarper(top_p), + TopKLogitsWarper(topk), + TemperatureLogitsWarper(temperature), + + ] + ) + + stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=input_length + max_gen_len)]) + out = self.sample( + tokens, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + return_dict_in_generate=True, + output_scores=True, + ) + + + # print(out) + out_ids = out["sequences"][0][input_length:].cpu().numpy() + + out_scores = out["scores"] + + out_scores = torch.cat(out_scores, dim=0) + out_scores = torch.nn.functional.softmax(out_scores, dim=-1).cpu().numpy() + + probs = [] + for i in range(len(out_ids)): + probs.append(float(out_scores[i][out_ids[i]])) + + # print(f"probs is {probs}") + + convert_tokens = [] + for t in out_ids: + if t == 100006: + convert_tokens.append("[CLS]") + else : + convert_tokens.append(id2word.get(t, "[unkonwn_token]")) + + out_text = tokenizer.decode(out_ids.tolist()) + + + out = out_text + + if "###" in out: + special_index = out.index("###") + out = out[: special_index] + token_length = len(tokenizer.encode_plus(out)["input_ids"]) + convert_tokens = convert_tokens[:token_length] + probs = probs[:token_length] + + if "[UNK]" in out: + special_index = out.index("[UNK]") + out = out[:special_index] + token_length = 
len(tokenizer.encode_plus(out)["input_ids"]) + convert_tokens = convert_tokens[:token_length] + probs = probs[:token_length] + + if "" in out: + special_index = out.index("": self.tokenizer.eos_id, " ": self.tokenizer.pad_id } + super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs) def get_command(self, token): if token in self.special_tokens: diff --git a/mft_peft_hf/src/model/chatglm2/tokenizer_config.json b/mftcoder_accelerate/src/model/chatglm2/tokenizer_config.json similarity index 100% rename from mft_peft_hf/src/model/chatglm2/tokenizer_config.json rename to mftcoder_accelerate/src/model/chatglm2/tokenizer_config.json diff --git a/mftcoder_accelerate/src/model/chatglm3/config.json b/mftcoder_accelerate/src/model/chatglm3/config.json new file mode 100644 index 0000000..37933c8 --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/config.json @@ -0,0 +1,42 @@ +{ + "_name_or_path": "THUDM/chatglm3-6b", + "model_type": "chatglm", + "architectures": [ + "ChatGLMModel" + ], + "auto_map": { + "AutoConfig": "configuration_chatglm.ChatGLMConfig", + "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification" + }, + "add_bias_linear": false, + "add_qkv_bias": true, + "apply_query_key_layer_scaling": true, + "apply_residual_connection_post_layernorm": false, + "attention_dropout": 0.0, + "attention_softmax_in_fp32": true, + "bias_dropout_fusion": true, + "ffn_hidden_size": 13696, + "fp32_residual_connection": false, + "hidden_dropout": 0.0, + "hidden_size": 4096, + "kv_channels": 128, + "layernorm_epsilon": 1e-05, + "multi_query_attention": true, + "multi_query_group_num": 2, + "num_attention_heads": 32, + "num_layers": 28, + "original_rope": true, + "padded_vocab_size": 65024, + "post_layer_norm": true, + "rmsnorm": true, + "seq_length": 8192, + "use_cache": true, + "torch_dtype": "float16", + "transformers_version": "4.30.2", + "tie_word_embeddings": false, + "eos_token_id": 2, + "pad_token_id": 0 +} \ No newline at end of file diff --git a/mftcoder_accelerate/src/model/chatglm3/configuration_chatglm.py b/mftcoder_accelerate/src/model/chatglm3/configuration_chatglm.py new file mode 100644 index 0000000..3560018 --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/configuration_chatglm.py @@ -0,0 +1,61 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + classifier_dropout=None, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = 
ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.classifier_dropout = classifier_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) \ No newline at end of file diff --git a/mftcoder_accelerate/src/model/chatglm3/modeling_chatglm.py b/mftcoder_accelerate/src/model/chatglm3/modeling_chatglm.py new file mode 100644 index 0000000..e75568b --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/modeling_chatglm.py @@ -0,0 +1,1293 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any +from copy import deepcopy + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" +_CONFIG_FOR_DOC = "ChatGLMConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm3-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * 
config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
+ """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. 
+ self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, + device=query_layer.device + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], + device=attention_scores.device, dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. 
+ # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, **_config_to_kwargs(config) + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, + device=device, **_config_to_kwargs(config) + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. 
+ # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. 
If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. 
+ def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. 
+ """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, + device=input_ids.device), full_attention_mask), dim=-1) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. 
+ if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, + dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, + dtype=config.torch_dtype, **init_kwargs) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask], dim=-1) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary 
positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + is_first_forward: bool = True, + **kwargs + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + if past_key_values is not None: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + "use_cache": use_cache + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: 
Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. 
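+ Note: this model keeps its key/value cache in [seq_len, batch, ...] layout, so `beam_idx` is applied along dimension 1 rather than dimension 0.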
+ """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, output, history): + content = "" + history = deepcopy(history) + for response in output.split("<|assistant|>"): + metadata, content = response.split("\n", maxsplit=1) + if not metadata.strip(): + content = content.strip() + history.append({"role": "assistant", "metadata": metadata, "content": content}) + content = content.replace("[[训练时间]]", "2023年") + else: + history.append({"role": "assistant", "metadata": metadata, "content": content}) + if history[0]["role"] == "system" and "tools" in history[0]: + content = "\n".join(content.split("\n")[1:-1]) + def tool_call(**kwargs): + return kwargs + parameters = eval(content) + content = {"name": metadata.strip(), "parameters": parameters} + else: + content = {"name": metadata.strip(), "content": content} + return content, history + + @torch.inference_mode() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user", + max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + inputs = tokenizer.build_chat_input(query, history=history, role=role) + inputs = inputs.to(self.device) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + history.append({"role": role, "content": query}) + response, history = self.process_response(response, history) + return response, history + + @torch.inference_mode() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user", + past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, + logits_processor=None, return_past_key_values=False, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if past_key_values is None: + inputs = tokenizer.build_chat_input(query, history=history, role=role) + else: + inputs = tokenizer.build_chat_input(query, role=role) + inputs = inputs.to(self.device) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs['attention_mask'] = attention_mask + history.append({"role": role, "content": query}) + for outputs 
in self.stream_generate(**inputs, past_key_values=past_key_values, + eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, + **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response, new_history = self.process_response(response, history) + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.inference_mode() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + model_kwargs["use_cache"] = generation_config.use_cache + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. 
Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, + **kwargs) + return self + + +class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.num_labels = config.num_labels + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + + self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) + if config.classifier_dropout is not None: + self.dropout = nn.Dropout(config.classifier_dropout) + else: + self.dropout = None + self.config = config + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: 
Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + full_attention_mask=full_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + pooled_hidden_states = hidden_states[-1] + if self.dropout is not None: + pooled_hidden_states = self.dropout(pooled_hidden_states) + logits = self.classifier_head(pooled_hidden_states) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze().float(), labels.squeeze()) + else: + loss = loss_fct(logits.float(), labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/mftcoder_accelerate/src/model/chatglm3/quantization.py b/mftcoder_accelerate/src/model/chatglm3/quantization.py new file mode 100644 index 0000000..cb95bfe --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/quantization.py @@ -0,0 +1,188 @@ +from torch.nn import Linear +from torch.nn.parameter import Parameter + +import bz2 +import torch +import base64 +import ctypes +from transformers.utils import logging + +from typing import List +from functools import partial + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = 
"$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqy
sTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2euno
r+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + "int4WeightCompression", + "int4WeightExtractionFloat", + "int4WeightExtractionHalf", + "int8WeightExtractionFloat", + "int8WeightExtractionHalf", + ], + ) +except Exception as exception: + kernels = None + logger.warning("Failed to load cpm_kernels:" + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features,))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): + assert scale_list.dtype in [torch.half, torch.bfloat16] + assert weight.dtype in [torch.int8] + if source_bit_width == 8: + return weight.to(scale_list.dtype) * scale_list[:, None] + elif source_bit_width == 4: + func = ( + kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16 + ) + else: + assert False, "Unsupported bit-width" + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(torch.nn.Module): + def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args, + **kwargs): + super().__init__() + self.weight_bit_width = weight_bit_width + + shape = weight.shape + + if weight is None or empty_init: + self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device) + self.weight_scale = 
torch.empty(shape[0], dtype=dtype, device=device) + else: + self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) + self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter(self.weight.to(device), requires_grad=False) + self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False) + self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, device=None): + """Replace fp16 linear with quantized linear""" + for layer in model.layers: + layer.self_attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()), + bias=layer.self_attention.query_key_value.bias, + dtype=layer.self_attention.query_key_value.weight.dtype, + device=layer.self_attention.query_key_value.weight.device if device is None else device, + empty_init=empty_init + ) + layer.self_attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()), + bias=layer.self_attention.dense.bias, + dtype=layer.self_attention.dense.weight.dtype, + device=layer.self_attention.dense.weight.device if device is None else device, + empty_init=empty_init + ) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), + bias=layer.mlp.dense_h_to_4h.bias, + dtype=layer.mlp.dense_h_to_4h.weight.dtype, + device=layer.mlp.dense_h_to_4h.weight.device if device is None else device, + empty_init=empty_init + ) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), + bias=layer.mlp.dense_4h_to_h.bias, + dtype=layer.mlp.dense_4h_to_h.weight.dtype, + device=layer.mlp.dense_4h_to_h.weight.device if device is None else device, + empty_init=empty_init + ) + + return model diff --git a/mftcoder_accelerate/src/model/chatglm3/tokenization_chatglm.py b/mftcoder_accelerate/src/model/chatglm3/tokenization_chatglm.py new file mode 100644 index 0000000..e50d329 --- /dev/null +++ b/mftcoder_accelerate/src/model/chatglm3/tokenization_chatglm.py @@ -0,0 +1,283 @@ +import json +import os +import torch +from typing import List, Optional, Union, Dict +from sentencepiece import SentencePieceProcessor +from transformers import PreTrainedTokenizer +from transformers.utils import logging, PaddingStrategy +from transformers.tokenization_utils_base import EncodedInput, BatchEncoding + + +class SPTokenizer: + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.unk_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop", "<|system|>", "<|user|>", 
"<|assistant|>", + "<|observation|>"] + self.special_tokens = {} + self.index_special_tokens = {} + for token in special_tokens: + self.special_tokens[token] = self.n_words + self.index_special_tokens[self.n_words] = token + self.n_words += 1 + + def tokenize(self, s: str): + return self.sp_model.EncodeAsPieces(s) + + def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]: + assert type(s) is str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + text, buffer = "", [] + for token in t: + if token in self.index_special_tokens: + if buffer: + text += self.sp_model.decode(buffer) + buffer = [] + text += self.index_special_tokens[token] + else: + buffer.append(token) + if buffer: + text += self.sp_model.decode(buffer) + return text + + def decode_tokens(self, tokens: List[str]) -> str: + text = self.sp_model.DecodePieces(tokens) + return text + + def convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + if token in self.special_tokens: + return self.special_tokens[token] + return self.sp_model.PieceToId(token) + + def convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.index_special_tokens: + return self.index_special_tokens[index] + if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0: + return "" + return self.sp_model.IdToPiece(index) + + +class ChatGLMTokenizer(PreTrainedTokenizer): + vocab_files_names = {"vocab_file": "tokenizer.model"} + + model_input_names = ["input_ids", "attention_mask", "position_ids"] + + def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs): + self.name = "GLMTokenizer" + + self.vocab_file = vocab_file + self.tokenizer = SPTokenizer(vocab_file) + self.special_tokens = { + " ": self.tokenizer.bos_id, + " ": self.tokenizer.eos_id, + " ": self.tokenizer.pad_id + } + super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs) + + def get_command(self, token): + if token in self.special_tokens: + return self.special_tokens[token] + assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}" + return self.tokenizer.special_tokens[token] + + @property + def unk_token(self) -> str: + return " " + + @property + def pad_token(self) -> str: + return " " + + @property + def pad_token_id(self): + return self.get_command(" ") + + @property + def eos_token(self) -> str: + return " ") + + @property + def vocab_size(self): + return self.tokenizer.n_words + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + return self.tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.tokenizer.convert_token_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.tokenizer.decode_tokens(tokens) + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. 
+ + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, self.vocab_files_names["vocab_file"] + ) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, "wb") as writer: + writer.write(proto_str) + + return (vocab_file,) + + def get_prefix_tokens(self): + prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")] + return prefix_tokens + + def build_single_message(self, role, metadata, message): + assert role in ["system", "user", "assistant", "observation"], role + role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n") + message_tokens = self.tokenizer.encode(message) + tokens = role_tokens + message_tokens + return tokens + + def build_chat_input(self, query, history=None, role="user"): + if history is None: + history = [] + input_ids = [] + for item in history: + content = item["content"] + if item["role"] == "system" and "tools" in item: + content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False) + input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content)) + input_ids.extend(self.build_single_message(role, "", query)) + input_ids.extend([self.get_command("<|assistant|>")]) + return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + prefix_tokens = self.get_prefix_tokens() + token_ids_0 = prefix_tokens + token_ids_0 + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command(" ")] + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. 
+ + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * seq_length + + if "position_ids" not in encoded_inputs: + encoded_inputs["position_ids"] = list(range(seq_length)) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/mft_peft_hf/src/model/code_llama/__init__.py b/mftcoder_accelerate/src/model/code_llama/__init__.py similarity index 51% rename from mft_peft_hf/src/model/code_llama/__init__.py rename to mftcoder_accelerate/src/model/code_llama/__init__.py index cfeaa6b..c2d50fb 100644 --- a/mft_peft_hf/src/model/code_llama/__init__.py +++ b/mftcoder_accelerate/src/model/code_llama/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# Copyright 2023 MetaAI and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +13,10 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from transformers.utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_sentencepiece_available, - is_tokenizers_available, - is_torch_available, -) +from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_tokenizers_available -_import_structure = { - "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"], -} +_import_structure = {} try: if not is_sentencepiece_available(): @@ -32,7 +24,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["tokenization_llama"] = ["LlamaTokenizer"] + _import_structure["tokenization_code_llama"] = ["CodeLlamaTokenizer"] try: if not is_tokenizers_available(): @@ -40,32 +32,16 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"] - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_llama"] = [ - "LlamaForCausalLM", - "LlamaModel", - "LlamaPreTrainedModel", - "LlamaForSequenceClassification", - ] - + _import_structure["tokenization_code_llama_fast"] = ["CodeLlamaTokenizerFast"] if TYPE_CHECKING: - from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig - try: if not is_sentencepiece_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: pass else: - from .tokenization_llama import LlamaTokenizer + from .tokenization_code_llama import CodeLlamaTokenizer try: if not is_tokenizers_available(): @@ -73,16 +49,7 @@ except OptionalDependencyNotAvailable: pass else: - from .tokenization_llama_fast import LlamaTokenizerFast - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel - + from .tokenization_code_llama_fast import CodeLlamaTokenizerFast else: import sys diff --git a/mft_peft_hf/src/model/code_llama/configuration_llama.py b/mftcoder_accelerate/src/model/code_llama/configuration_llama.py similarity index 95% rename from mft_peft_hf/src/model/code_llama/configuration_llama.py rename to mftcoder_accelerate/src/model/code_llama/configuration_llama.py index 430bf9f..2c914c8 100644 --- a/mft_peft_hf/src/model/code_llama/configuration_llama.py +++ b/mftcoder_accelerate/src/model/code_llama/configuration_llama.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" LLaMA model configuration""" +"""4.33.1 LLaMA model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -66,8 +66,8 @@ class LlamaConfig(PretrainedConfig): hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. 
initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-12): @@ -77,6 +77,8 @@ class LlamaConfig(PretrainedConfig): relevant if `config.is_decoder=True`. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format @@ -121,6 +123,7 @@ def __init__( eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, + rope_theta=10000.0, rope_scaling=None, use_xformers=True, **kwargs, @@ -133,7 +136,6 @@ def __init__( self.num_attention_heads = num_attention_heads self.use_xformers = use_xformers - # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -144,6 +146,7 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.pretraining_tp = pretraining_tp self.use_cache = use_cache + self.rope_theta = rope_theta self.rope_scaling = rope_scaling self._rope_scaling_validation() @@ -164,14 +167,14 @@ def _rope_scaling_validation(self): if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " f"got {self.rope_scaling}" ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/mft_peft_hf/src/model/code_llama/convert_llama_weights_to_hf.py b/mftcoder_accelerate/src/model/code_llama/convert_llama_weights_to_hf.py similarity index 90% rename from mft_peft_hf/src/model/code_llama/convert_llama_weights_to_hf.py rename to mftcoder_accelerate/src/model/code_llama/convert_llama_weights_to_hf.py index 03c1eb4..acc4988 100644 --- a/mft_peft_hf/src/model/code_llama/convert_llama_weights_to_hf.py +++ b/mftcoder_accelerate/src/model/code_llama/convert_llama_weights_to_hf.py @@ -53,18 +53,12 @@ come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
""" -INTERMEDIATE_SIZE_MAP = { - "7B": 11008, - "13B": 13824, - "30B": 17920, - "65B": 22016, - "70B": 28672, -} NUM_SHARDS = { "7B": 1, "7Bf": 1, "13B": 2, "13Bf": 2, + "34B": 4, "30B": 4, "65B": 8, "70B": 8, @@ -86,7 +80,11 @@ def write_json(text, path): json.dump(text, f) -def write_model(model_path, input_base_path, model_size, safe_serialization=True): +def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True): + # for backward compatibility, before you needed the repo to be called `my_repo/model_size` + if not os.path.isfile(os.path.join(input_base_path, "params.json")): + input_base_path = os.path.join(input_base_path, model_size) + os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") os.makedirs(tmp_model_path, exist_ok=True) @@ -98,8 +96,18 @@ def write_model(model_path, input_base_path, model_size, safe_serialization=True n_heads_per_shard = n_heads // num_shards dim = params["dim"] dims_per_head = dim // n_heads - base = 10000.0 + base = params.get("rope_theta", 10000.0) inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + if base > 10000.0: + max_position_embeddings = 16384 + else: + max_position_embeddings = 2048 + + tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + if tokenizer_path is not None: + tokenizer = tokenizer_class(tokenizer_path) + tokenizer.save_pretrained(model_path) + vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000 if "n_kv_heads" in params: num_key_value_heads = params["n_kv_heads"] # for GQA / MQA @@ -247,6 +255,9 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): num_hidden_layers=params["n_layers"], rms_norm_eps=params["norm_eps"], num_key_value_heads=num_key_value_heads, + vocab_size=vocab_size, + rope_theta=base, + max_position_embeddings=max_position_embeddings, ) config.save_pretrained(tmp_model_path) @@ -256,10 +267,10 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): gc.collect() print("Loading the checkpoint in a Llama model.") - model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) # Avoid saving this as part of the config. del model.config._name_or_path - + model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) shutil.rmtree(tmp_model_path) @@ -281,7 +292,7 @@ def main(): ) parser.add_argument( "--model_size", - choices=["7B", "7Bf", "13B", "13Bf", "30B", "65B", "70B", "70Bf", "tokenizer_only"], + choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"], help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. 
For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", ) parser.add_argument( @@ -290,15 +301,17 @@ def main(): ) parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") args = parser.parse_args() + spm_path = os.path.join(args.input_dir, "tokenizer.model") if args.model_size != "tokenizer_only": write_model( model_path=args.output_dir, - input_base_path=os.path.join(args.input_dir, args.model_size), + input_base_path=args.input_dir, model_size=args.model_size, safe_serialization=args.safe_serialization, + tokenizer_path=spm_path, ) - spm_path = os.path.join(args.input_dir, "tokenizer.model") - write_tokenizer(args.output_dir, spm_path) + else: + write_tokenizer(args.output_dir, spm_path) if __name__ == "__main__": diff --git a/mft_peft_hf/src/model/code_llama/modeling_llama.py b/mftcoder_accelerate/src/model/code_llama/modeling_llama.py similarity index 97% rename from mft_peft_hf/src/model/code_llama/modeling_llama.py rename to mftcoder_accelerate/src/model/code_llama/modeling_llama.py index fde3a9c..7dea342 100644 --- a/mft_peft_hf/src/model/code_llama/modeling_llama.py +++ b/mftcoder_accelerate/src/model/code_llama/modeling_llama.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch LLaMA model.""" +"""4.33.1 PyTorch LLaMA model.""" import math from typing import List, Optional, Tuple, Union @@ -247,6 +247,7 @@ def __init__(self, config: LlamaConfig): self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta self.use_xformers = config.use_xformers if (self.head_dim * self.num_heads) != self.hidden_size: @@ -262,17 +263,27 @@ def __init__(self, config: LlamaConfig): def _init_rope(self): if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") @@ -339,8 +350,9 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask(), - op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) + attn_output = 
xformers.ops.memory_efficient_attention(query_states, key_states, value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) attn_output = attn_output.contiguous().view(bsz, q_len, -1) else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) diff --git a/mft_peft_hf/src/model/llama2/tokenization_llama.py b/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama.py similarity index 50% rename from mft_peft_hf/src/model/llama2/tokenization_llama.py rename to mftcoder_accelerate/src/model/code_llama/tokenization_code_llama.py index dd1c936..0c99a02 100644 --- a/mft_peft_hf/src/model/llama2/tokenization_llama.py +++ b/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama.py @@ -1,10 +1,6 @@ # coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2023 MetaAI and the HuggingFace Inc. team. All rights reserved. # -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,15 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for LLaMA.""" +"""4.33.1 Tokenization classes for Code LLaMA.""" import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import sentencepiece as spm +from transformers.convert_slow_tokenizer import import_protobuf from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer -from transformers.utils import logging +from transformers.utils import logging, requires_backends if TYPE_CHECKING: @@ -38,14 +35,14 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-code-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-code-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "hf-internal-testing/llama-tokenizer": 2048, + "hf-internal-testing/llama-code-tokenizer": 2048, } SPIECE_UNDERLINE = "▁" @@ -53,46 +50,71 @@ B_SYS, E_SYS = "< >\n", "\n< >\n\n" # fmt: off -DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\ +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ -that your responses are socially unbiased and positive in nature. + that your responses are socially unbiased and positive in nature. 
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\ +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ correct. If you don't know the answer to a question, please don't share false information.""" # fmt: on -class LlamaTokenizer(PreTrainedTokenizer): +class CodeLlamaTokenizer(PreTrainedTokenizer): """ - Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is - no padding token in the original model. + Construct a CodeLlama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as + there is no padding token in the original model. + + The default configuration match that of + [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json) + which supports prompt infilling. Args: vocab_file (`str`): Path to the vocabulary file. - legacy (`bool`, *optional*, defaults to `True`): - Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622 - which includes fixes to properly handle tokens that appear after special tokens. A simple example: - - - `legacy=True`: - ```python - >>> from transformers import T5Tokenizer - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True) - >>> tokenizer.encode("Hello.") - [8774, 32099, 3, 5, 1] - ``` - - `legacy=False`: - ```python - >>> from transformers import T5Tokenizer - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False) - >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here - [8774, 32099, 5, 1] - ``` - Checkout the pull request and the issue [here](https://github.com/huggingface/transformers/pull/24565) for - more details. - + eos_token (`str`, *optional*, defaults to `"
"`): + Prefix token used for infilling. + suffix_token (`str`, *optional*, defaults to `"▁"`): + Suffix token used for infilling. + middle_token (`str`, *optional*, defaults to `"▁ "`): + Middle token used for infilling. + eot_token (`str`, *optional*, defaults to `"▁ "`): + End of text token used for infilling. + fill_token (`str`, *optional*, defaults to `" "`): + The token used to split the input between the prefix and suffix. + suffix_first (`bool`, *optional*, default to `False`): + Whether the input prompt and suffix should be formatted with the suffix first. + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Llama should be used. """ vocab_files_names = VOCAB_FILES_NAMES @@ -106,99 +128,184 @@ def __init__( unk_token=" ", bos_token=" ", eos_token="", - pad_token=None, + prefix_token="▁", + middle_token="▁", + suffix_token="▁ ", + eot_token="▁ ", + fill_token=" ", + suffix_first=False, sp_model_kwargs: Optional[Dict[str, Any]] = None, add_bos_token=True, add_eos_token=False, clean_up_tokenization_spaces=False, - legacy=True, + additional_special_tokens=None, + use_default_system_prompt=False, **kwargs, ): + requires_backends(self, "protobuf") self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + self.use_default_system_prompt = use_default_system_prompt + # mark tokens special to skip them + additional_special_tokens = additional_special_tokens or [] + for token in [prefix_token, middle_token, suffix_token, eot_token]: + additional_special_tokens += [token] if token is not None else [] + super().__init__( bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, - pad_token=pad_token, add_bos_token=add_bos_token, add_eos_token=add_eos_token, + prefix_token=prefix_token, + middle_token=middle_token, + suffix_token=suffix_token, + eot_token=eot_token, + fill_token=fill_token, sp_model_kwargs=self.sp_model_kwargs, + suffix_first=suffix_first, clean_up_tokenization_spaces=clean_up_tokenization_spaces, - legacy=legacy, + additional_special_tokens=additional_special_tokens, + use_default_system_prompt=use_default_system_prompt, **kwargs, ) - if legacy: - logger.warning_once( - f"You are using the legacy 
behaviour of the {self.__class__}. This means that tokens that come after special tokens will not be properly handled. We recommend you to" - " read the related pull request available at https://github.com/huggingface/transformers/pull/24565" - ) - self.legacy = legacy self.vocab_file = vocab_file self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(vocab_file) + self._prefix_token = prefix_token + self._middle_token = middle_token + self._suffix_token = suffix_token + self._eot_token = eot_token + self.fill_token = fill_token + self.suffix_first = suffix_first + self.sp_model = self.get_spm_processor() - def __getstate__(self): - state = self.__dict__.copy() - state["sp_model"] = None - state["sp_model_proto"] = self.sp_model.serialized_model_proto() - return state + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + def get_spm_processor(self): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf() + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer - def __setstate__(self, d): - self.__dict__ = d - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + @property + def prefix_token(self): + return self._prefix_token + + @property + def prefix_id(self): + if self._prefix_token is None: + return None + return self.convert_tokens_to_ids(self.prefix_token) + + @property + def middle_token(self): + return self._middle_token + + @property + def middle_id(self): + if self._middle_token is None: + return None + return self.convert_tokens_to_ids(self.middle_token) + + @property + def suffix_token(self): + return self._suffix_token + + @property + def suffix_id(self): + if self._suffix_token is None: + return None + return self.convert_tokens_to_ids(self.suffix_token) + + @property + def eot_token(self): + return self._eot_token + + @property + def eot_id(self): + if self._eot_token is None: + return None + return self.convert_tokens_to_ids(self.eot_token) @property def vocab_size(self): """Returns vocab size""" return self.sp_model.get_piece_size() + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab def get_vocab(self): """Returns vocab as a dict""" vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} vocab.update(self.added_tokens_encoder) return vocab - # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize - def tokenize(self, text, **kwargs) -> List[str]: - # Replace the SPIECE_UNDERLINE with a space to make sure SPIECE_UNDERLINE is only used at - # the beginning of the text - if not self.legacy: - text = SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " ") - return super().tokenize(text, **kwargs) + def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> List[int]: + # add a prefix space to `prefix` + if self.fill_token is not None and self.fill_token in prefix and suffix is None: + prefix, suffix = prefix.split(self.fill_token) + + if len(prefix) > 0: + prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ") + + if 
suffix is None or len(suffix) < 1: + tokens = super().tokenize(prefix, **kwargs) + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + prefix_tokens = self._tokenize(prefix) # prefix has an extra `SPIECE_UNDERLINE` + + if None in (self.prefix_id, self.middle_id, self.suffix_id): + raise ValueError( + "The input either includes a `prefix` and a `suffix` used for the infilling task," + f" or can be split on the {self.fill_token} token, creating a suffix and prefix," + " but the model does not support `infilling`." + ) + suffix_tokens = self._tokenize(suffix) # make sure CodeLlama sp model does not mess up - # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize - def _tokenize(self, text): + suffix_first = suffix_first if suffix_first is not None else self.suffix_first + if suffix_first: + # format as " {suf} {pre}" + return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens + else: + # format as " {pre}{suf} " + return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token] + + def _tokenize(self, text, **kwargs): """ Returns a tokenized string. - Since the sentencepiece internal model always adds a SPIECE_UNDERLINE, at the beginning of the provided text, - we need to remove it by hand when the current text is a subsequence. This happens whenever the `self.tokenize` - function is called with specials tokens: the input is split on the special tokens, and each subsequence is - passed to `_tokenize`. Thus if a subsequence did not start with a `" "` or SPIECE_UNDERLINE, we have to remove - the extra `SPIECE_UNDERLINE` prepended. + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = " "` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. """ - if not self.legacy: - is_first = text.startswith(SPIECE_UNDERLINE) - if is_first: - text = text[1:] - tokens = self.sp_model.encode(text, out_type=str) + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens - if not self.legacy and not is_first and not text.startswith(" ") and tokens[0].startswith(SPIECE_UNDERLINE): - tokens = ([tokens[0][1:]] if len(tokens[0]) > 1 else []) + tokens[1:] - return tokens - + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" return self.sp_model.piece_to_id(token) + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" token = self.sp_model.IdToPiece(index) @@ -206,23 +313,23 @@ def _convert_id_to_token(self, index): def convert_tokens_to_string(self, tokens): """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE): + tokens[0] = tokens[0][1:] + current_sub_tokens = [] out_string = "" - prev_is_special = False - for i, token in enumerate(tokens): + for _, token in enumerate(tokens): # make sure that special tokens are not decoded using sentencepiece model if token in self.all_special_tokens: - if not prev_is_special and i != 0: - out_string += " " out_string += self.sp_model.decode(current_sub_tokens) + token - prev_is_special = True current_sub_tokens = [] else: current_sub_tokens.append(token) - prev_is_special = False out_string += self.sp_model.decode(current_sub_tokens) return out_string + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: """ Save the vocabulary and special tokens file to a directory. @@ -250,6 +357,7 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) return (out_vocab_file,) + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): bos_token_id = [self.bos_token_id] if self.add_bos_token else [] eos_token_id = [self.eos_token_id] if self.add_eos_token else [] @@ -261,6 +369,7 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): return output + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: @@ -298,6 +407,7 @@ def get_special_tokens_mask( + eos_token_id ) + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: @@ -332,7 +442,7 @@ def create_token_type_ids_from_sequences( return output def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - """Builds the input ids for a conversation. + r"""Builds the input ids for a conversation. This is the format used in the provided examples. System prompts should be manually added at the beginning of the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. 
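The `_tokenize` override above relies on a small trick: because `add_dummy_prefix` is disabled, sentencepiece would otherwise strip a leading `SPIECE_UNDERLINE`, so the text is encoded as `unk_token + text` and the pieces belonging to the unknown token are dropped afterwards. A plain-Python sketch of just that slicing step, with fake pieces standing in for real sentencepiece output (the actual pieces and `unk_token_length` depend on the sp model):

```python
def strip_unk_prefix(pieces, unk_token_length):
    """Drop the leading unk-token pieces that were only added to preserve the prefix space."""
    return pieces[unk_token_length:] if len(pieces) >= unk_token_length else pieces


# Fake encoding of f"{unk_token}▁Hey"; in this toy example the unk token spans 3 pieces.
pieces = ["<", "unk", ">", "▁Hey"]
print(strip_unk_prefix(pieces, unk_token_length=3))  # ['▁Hey']
```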
``` @@ -346,8 +456,8 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in >>> from transformers import Conversation >>> Conversation( - ... "< >\n Only answer with emojis, and charades\n< >\n\nHow can I build a house in 10 septs?" - ... ) + ... "<>\n Complete the functions without any documentation\n< >\n\n `def remove_non_ascii(s: str) -> str:`" + ... ) # doctest: +IGNORE_RESULT ``` Args: conversation (`Conversation`): @@ -356,6 +466,21 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in `List[int]`: Input ids for the conversation. """ + if self.use_default_system_prompt: + if len(conversation.past_user_inputs) > 0: + if ( + not conversation.past_user_inputs[0].startswith(B_SYS) + or E_SYS not in conversation.past_user_inputs[0] + ): + conversation.past_user_inputs[0] = ( + B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] + ) + elif conversation.new_user_input: + if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: + conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input + else: + raise ValueError("Last message must be from user") + dialogue = list(conversation.iter_texts()) if not all([is_user for is_user, msg in dialogue[::2]]) or not all( [not is_user for is_user, msg in dialogue[1::2]] @@ -365,14 +490,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in ) dialog_tokens: List[int] = [] - if len(conversation.past_user_inputs) > 0: - if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]: - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]: - dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1]) - dialog_tokens += sum( [ [self.bos_token_id] @@ -384,9 +501,18 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in ], [], ) - if not (dialogue[-1][0]): - raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}") dialog_tokens += [self.bos_token_id] + self.encode( f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False ) return dialog_tokens + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) diff --git a/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama_fast.py b/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama_fast.py new file mode 100644 index 0000000..b492b65 --- /dev/null +++ b/mftcoder_accelerate/src/model/code_llama/tokenization_code_llama_fast.py @@ -0,0 +1,428 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
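To make the `_build_conversation_input_ids` logic above easier to follow, here is a string-level sketch of the prompt layout it encodes: each (user, assistant) pair becomes `<s>[INST] user [/INST] assistant </s>`, and an optional system prompt is spliced into the first user turn between the standard Llama-2 system delimiters. The literal `<s>`/`</s>` markers stand in for the BOS/EOS ids the real method prepends and appends; this is a sketch, not the tokenizer code.

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"  # standard Llama-2 system-prompt delimiters


def render_dialogue(turns, system_prompt=None):
    """turns: list of (is_user, text), starting with a user turn and alternating roles."""
    if system_prompt is not None and not turns[0][1].startswith(B_SYS):
        turns = [(True, B_SYS + system_prompt + E_SYS + turns[0][1])] + turns[1:]
    pieces = [
        f"<s>{B_INST} {prompt.strip()} {E_INST} {answer.strip()} </s>"
        for (_, prompt), (_, answer) in zip(turns[::2], turns[1::2])
    ]
    if len(turns) % 2 == 1:  # trailing user turn the model is expected to answer
        pieces.append(f"<s>{B_INST} {turns[-1][1].strip()} {E_INST}")
    return "".join(pieces)


print(render_dialogue([(True, "Complete `def remove_non_ascii(s: str) -> str:`")],
                      system_prompt="Complete the functions without any documentation"))
```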
+# See the License for the specific language governing permissions and +# limitations under the License. +import os +from shutil import copyfile +from typing import TYPE_CHECKING, List, Optional, Tuple + +from tokenizers import normalizers, processors + +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import is_sentencepiece_available, logging +from transformers.utils.versions import require_version + + +if TYPE_CHECKING: + from transformers.pipelines.conversational import Conversation + +require_version("tokenizers>=0.13.3") + +if is_sentencepiece_available(): + from .tokenization_code_llama import CodeLlamaTokenizer +else: + CodeLlamaTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"} + +SPIECE_UNDERLINE = "▁" + + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n< >\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class CodeLlamaTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. + + This uses notably ByteFallback and no normalization. + + ```python + >>> from transformers import CodeLlamaTokenizerFast + + >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + >>> tokenizer.encode("Hello this is a test") + [1, 15043, 445, 338, 263, 1243] + ``` + + If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or + call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the + values of the first token and final token of an encoded sequence will not be correct). For more details, checkout + [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation. + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. The default configuration match that of + [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json) + which supports prompt infilling. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that + contains the vocabulary necessary to instantiate a tokenizer. + tokenizer_file (`str`): + [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that + contains everything needed to load the tokenizer. + clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`): + Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra + spaces. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. 
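The slow-tokenizer import at the top of this new file is guarded by `is_sentencepiece_available()`, so the fast class still loads when sentencepiece is missing and simply points `slow_tokenizer_class` at `None`. A generic sketch of the same optional-dependency idiom (the helper below is illustrative, not a transformers utility):

```python
import importlib
import importlib.util


def optional_attr(module_name, attr_name):
    """Return module_name.attr_name if the package is installed, else None."""
    if importlib.util.find_spec(module_name) is None:
        return None
    return getattr(importlib.import_module(module_name), attr_name, None)


# Plays the same role as `CodeLlamaTokenizer = None` when sentencepiece is absent.
SentencePieceProcessor = optional_attr("sentencepiece", "SentencePieceProcessor")
if SentencePieceProcessor is None:
    print("sentencepiece is not installed; the slow tokenizer cannot be constructed")
```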
+ eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + prefix_token (`str`, *optional*, defaults to `"▁ "`): + Prefix token used for infilling. + suffix_token (`str`, *optional*, defaults to `"▁"`): + Suffix token used for infilling. + middle_token (`str`, *optional*, defaults to `"▁ "`): + Middle token used for infilling. + eot_token (`str`, *optional*, defaults to `"▁ "`): + End of text token used for infilling. + fill_token (`str`, *optional*, defaults to `" "`): + The token used to split the input between the prefix and suffix. + suffix_first (`bool`, *optional*, default to `False`): + Whether the input prompt and suffix should be formatted with the suffix first. + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + use_default_system_prompt (`bool`, *optional*, defaults to `True`): + Whether or not the default system prompt for Llama should be used. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = CodeLlamaTokenizer + padding_side = "left" + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + clean_up_tokenization_spaces=False, + unk_token=" ", + bos_token=" ", + eos_token="", + prefix_token="▁", + middle_token="▁", + suffix_token="▁ ", + eot_token="▁ ", + fill_token=" ", + additional_special_tokens=None, + add_bos_token=True, + add_eos_token=False, + use_default_system_prompt=False, + **kwargs, + ): + # mark tokens special to skip them + additional_special_tokens = additional_special_tokens or [] + for token in [prefix_token, middle_token, suffix_token, eot_token]: + additional_special_tokens += [token] if token is not None else [] + self.use_default_system_prompt = use_default_system_prompt + + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + additional_special_tokens=additional_special_tokens, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + prefix_token=prefix_token, + middle_token=middle_token, + suffix_token=suffix_token, + eot_token=eot_token, + fill_token=fill_token, + use_default_system_prompt=use_default_system_prompt, + **kwargs, + ) + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + self.update_post_processor() + + self.vocab_file = vocab_file + + self._prefix_token = prefix_token + self._middle_token = middle_token + self._suffix_token = suffix_token + self._eot_token = eot_token + self.fill_token = fill_token + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor + def update_post_processor(self): + """ + Updates the underlying post processor with the current `bos_token` and `eos_token`. 
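`update_post_processor` rebuilds the tokenizers `TemplateProcessing` post-processor whenever `add_bos_token`/`add_eos_token` change. The `single`/`pair` templates it constructs (the f-strings appear just below) are easiest to read by evaluating them with concrete strings; `<s>`/`</s>` here merely stand in for the configured BOS/EOS tokens:

```python
def render_templates(add_bos_token=True, add_eos_token=False, bos="<s>", eos="</s>"):
    # Same f-strings as the method body; multiplying a string by a bool keeps it
    # when True and drops it when False.
    single = f"{(bos + ':0 ') * add_bos_token}$A:0{(' ' + eos + ':0') * add_eos_token}"
    pair = f"{single}{(' ' + bos + ':1') * add_bos_token} $B:1{(' ' + eos + ':1') * add_eos_token}"
    return single, pair


print(render_templates())
# ('<s>:0 $A:0', '<s>:0 $A:0 <s>:1 $B:1')
print(render_templates(add_eos_token=True))
# ('<s>:0 $A:0 </s>:0', '<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1')
```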
+ """ + bos = self.bos_token + bos_token_id = self.bos_token_id + + eos = self.eos_token + eos_token_id = self.eos_token_id + + single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}" + pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}" + + special_tokens = [] + if self.add_bos_token: + special_tokens.append((bos, bos_token_id)) + if self.add_eos_token: + special_tokens.append((eos, eos_token_id)) + self._tokenizer.post_processor = processors.TemplateProcessing( + single=single, pair=pair, special_tokens=special_tokens + ) + + @property + def prefix_token(self): + return self._prefix_token + + @property + def prefix_id(self): + if self._prefix_token is None: + return None + return self.convert_tokens_to_ids(self.prefix_token) + + @property + def middle_token(self): + return self._middle_token + + @property + def middle_id(self): + if self._middle_token is None: + return None + return self.convert_tokens_to_ids(self.middle_token) + + @property + def suffix_token(self): + return self._suffix_token + + @property + def suffix_id(self): + if self._suffix_token is None: + return None + return self.convert_tokens_to_ids(self.suffix_token) + + @property + def eot_id(self): + if self._eot_token is None: + return None + return self.convert_tokens_to_ids(self.eot_token) + + @property + def eot_token(self): + return self._eot_token + + @property + def add_eos_token(self): + return self._add_eos_token + + @property + def add_bos_token(self): + return self._add_bos_token + + @add_eos_token.setter + def add_eos_token(self, value): + self._add_eos_token = value + self.update_post_processor() + + @add_bos_token.setter + def add_bos_token(self, value): + self._add_bos_token = value + self.update_post_processor() + + def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True): + if reset: + self._tokenizer.normalizer = normalizers.Sequence( + [ + normalizers.Prepend(prepend="▁"), + normalizers.Replace(pattern=" ", content="▁"), + ] + ) + self.update_post_processor() + + self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁") + pair = [self.bos_token] if self.add_bos_token and add_special_tokens else [] + special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else [] + if suffix_first: + # format as " {suf} {pre}" + pair += [self.prefix_token, self.suffix_token, "$A", self.middle_token, "$B"] + special_tokens += [ + (self.prefix_token, self.prefix_id), + (self.suffix_token, self.suffix_id), + (self.middle_token, self.middle_id), + ] + else: + # format as " {pre}{suf} " + pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token] + special_tokens += [ + (self.prefix_token, self.prefix_id), + (self.suffix_token, self.suffix_id), + (self.middle_token, self.middle_id), + ] + + if self.add_eos_token and add_special_tokens: + pair += [self.eos_token] + special_tokens += [(self.eos_token, self.eos_token_id)] + self._tokenizer.post_processor = processors.TemplateProcessing( + single="$A", pair=pair, special_tokens=special_tokens + ) + + def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs): + # hack to make sure the input is pre-process but outside rust + text_pair = kwargs.pop("suffix", text_pair) + if self.fill_token in text and text_pair is None: + text, text_pair = text.split(self.fill_token) + + if text_pair is None or len(text_pair) < 1: + return super().encode_plus(text, text_pair, 
add_special_tokens=add_special_tokens, **kwargs) + + if None in (self.prefix_id, self.middle_id, self.suffix_id): + raise ValueError( + "Then input includes a `prefix` and a `suffix` used for the infilling task," + " the `prefix_id, middle_id, suffix_id` must all be initialized. Current" + f" values : {self.prefix_id, self.middle_id, self.suffix_id}" + ) + + self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens) + tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs) + self.set_infilling_processor(True) + return tokens + + # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. The special tokens depend on calling set_lang. + + An NLLB sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.bos_token_id + token_ids_0 + self.eos_token_id + return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id + + # Copied from transformers.models.code_llama.tokenization_code_llama.CodeLlamaTokenizer._build_conversation_input_ids + def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: + r"""Builds the input ids for a conversation. + This is the format used in the provided examples. System prompts should be manually added at the beginning of + the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. + ``` + [INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer + [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] + ``` + + If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following: + ```python + >>> from transformers import Conversation + + >>> Conversation( + ... "< >\n Complete the functions without any documentation\n< >\n\n `def remove_non_ascii(s: str) -> str:`" + ... ) # doctest: +IGNORE_RESULT + ``` + Args: + conversation (`Conversation`): + Conversation to build input ids for. 
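`encode_plus` above (like `tokenize` in the slow tokenizer) implements Code Llama's fill-in-the-middle prompting: the input is split on the fill token, and the prefix/suffix halves are rearranged around the sentinel tokens in one of two orders. A string-level sketch of those orderings; `<PRE>`/`<SUF>`/`<MID>`/`<FILL_ME>` are placeholders for whatever special pieces the tokenizer is actually configured with:

```python
PREFIX, SUFFIX, MIDDLE, FILL = "<PRE>", "<SUF>", "<MID>", "<FILL_ME>"


def fim_layout(prompt, suffix=None, suffix_first=False):
    """Return the sequence of segments the infilling post-processor arranges."""
    if suffix is None and FILL in prompt:
        prompt, suffix = prompt.split(FILL)
    if not suffix:
        return [prompt]                                  # ordinary, non-infilling encoding
    if suffix_first:
        return [PREFIX, SUFFIX, suffix, MIDDLE, prompt]  # "<PRE> <SUF>{suffix} <MID> {prefix}"
    return [PREFIX, prompt, SUFFIX, suffix, MIDDLE]      # "<PRE> {prefix} <SUF>{suffix} <MID>"


print(fim_layout("def add(a, b):\n    <FILL_ME>\n    return c"))
```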
+ Returns: + `List[int]`: + Input ids for the conversation. + """ + if self.use_default_system_prompt: + if len(conversation.past_user_inputs) > 0: + if ( + not conversation.past_user_inputs[0].startswith(B_SYS) + or E_SYS not in conversation.past_user_inputs[0] + ): + conversation.past_user_inputs[0] = ( + B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] + ) + elif conversation.new_user_input: + if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: + conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input + else: + raise ValueError("Last message must be from user") + + dialogue = list(conversation.iter_texts()) + if not all([is_user for is_user, msg in dialogue[::2]]) or not all( + [not is_user for is_user, msg in dialogue[1::2]] + ): + raise ValueError( + "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)" + ) + + dialog_tokens: List[int] = [] + dialog_tokens += sum( + [ + [self.bos_token_id] + + self.encode( + f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False + ) + + [self.eos_token_id] + for prompt, answer in zip(dialogue[::2], dialogue[1::2]) + ], + [], + ) + dialog_tokens += [self.bos_token_id] + self.encode( + f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False + ) + return dialog_tokens diff --git a/mft_peft_hf/src/model/code_llama/tokenization_llama.py b/mftcoder_accelerate/src/model/code_llama/tokenization_llama.py similarity index 92% rename from mft_peft_hf/src/model/code_llama/tokenization_llama.py rename to mftcoder_accelerate/src/model/code_llama/tokenization_llama.py index fb2cfce..98360d3 100644 --- a/mft_peft_hf/src/model/code_llama/tokenization_llama.py +++ b/mftcoder_accelerate/src/model/code_llama/tokenization_llama.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for LLaMA.""" +"""4.33.1 Tokenization classes for LLaMA.""" import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -141,7 +141,7 @@ def __init__( logger.warning_once( f"You are using the default legacy behaviour of the {self.__class__}. If you see this, DO NOT PANIC! This is" " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." - " If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it" + " If you want to use the new behaviour, set `legacy=False`. 
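Both conversation builders perform the same sanity check before any tokenization happens: turns must alternate user/assistant starting from a user turn, and the format implicitly expects the final turn to be a user prompt awaiting an answer. A condensed sketch of that check, assuming (is_user, text) tuples:

```python
def check_roles(dialogue):
    """dialogue: list of (is_user, text) tuples, alternating user/assistant."""
    if not all(is_user for is_user, _ in dialogue[::2]) or any(is_user for is_user, _ in dialogue[1::2]):
        raise ValueError(
            "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
        )
    if not dialogue[-1][0]:
        raise ValueError("Last message must be from user")


check_roles([(True, "hi"), (False, "hello"), (True, "write a quicksort in Python")])  # passes
```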
This should only be set if you understand what it" " means, and thouroughly read the reason why this was added as explained in" " https://github.com/huggingface/transformers/pull/24565" ) @@ -154,19 +154,25 @@ def __init__( self.use_default_system_prompt = use_default_system_prompt self.sp_model = self.get_spm_processor() - self.unk_token_length = len(self.sp_model.encode(str(self.unk_token))) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor def get_spm_processor(self): tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + with open(self.vocab_file, "rb") as f: sp_model = f.read() - model_pb2 = import_protobuf() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") model = model_pb2.ModelProto.FromString(sp_model) - if not self.legacy: - normalizer_spec = model_pb2.NormalizerSpec() - normalizer_spec.add_dummy_prefix = False - model.normalizer_spec.MergeFrom(normalizer_spec) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) sp_model = model.SerializeToString() tokenizer.LoadFromSerializedProto(sp_model) return tokenizer @@ -194,18 +200,17 @@ def get_vocab(self): return vocab # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize - def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]: """ Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the first token is special. """ - if self.legacy: + if self.legacy or len(text) == 0: return super().tokenize(text, **kwargs) - if len(text) > 0: - tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs) + tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs) - if tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: tokens = tokens[1:] return tokens @@ -220,13 +225,14 @@ def _tokenize(self, text, **kwargs): `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. """ - if self.legacy: - return self.sp_model.encode(text, out_type=str) - - unk_token_length = len(self.sp_model.encode(str(self.unk_token))) - text = self.unk_token + text tokens = self.sp_model.encode(text, out_type=str) - return tokens[unk_token_length:] + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" diff --git a/mft_peft_hf/src/model/code_llama/tokenization_llama_fast.py b/mftcoder_accelerate/src/model/code_llama/tokenization_llama_fast.py similarity index 98% rename from mft_peft_hf/src/model/code_llama/tokenization_llama_fast.py rename to mftcoder_accelerate/src/model/code_llama/tokenization_llama_fast.py index f7ad2ec..53a4b0b 100644 --- a/mft_peft_hf/src/model/code_llama/tokenization_llama_fast.py +++ b/mftcoder_accelerate/src/model/code_llama/tokenization_llama_fast.py @@ -128,7 +128,10 @@ def __init__( self.update_post_processor() self.use_default_system_prompt = use_default_system_prompt self.vocab_file = vocab_file - self.can_save_slow_tokenizer = False if not self.vocab_file else True + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False def update_post_processor(self): """ diff --git a/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py b/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py new file mode 100644 index 0000000..82e0f5d --- /dev/null +++ b/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py @@ -0,0 +1,206 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V2. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 102400): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. 
+ topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. 
The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import DeepseekV2Model, DeepseekV2Config + + >>> # Initializing a Deepseek-V2 style configuration + >>> configuration = DeepseekV2Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size = 1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts = None, + n_routed_experts = None, + ep_size = 1, + routed_scaling_factor = 1.0, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'gready', + n_group = None, + topk_group = None, + num_experts_per_tok = None, + moe_layer_freq = 1, + first_k_dense_replace = 0, + norm_topk_prob = False, + scoring_func = 'softmax', + aux_loss_alpha = 0.001, + seq_aux = True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git 
a/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py b/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py new file mode 100644 index 0000000..d1d5e88 --- /dev/null +++ b/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py @@ -0,0 +1,1925 @@ +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_deepseek import DeepseekV2Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
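`_prepare_4d_causal_attention_mask` is re-registered through `torch.fx.wrap` just below so that symbolic tracing records a single leaf call instead of tracing into its data-dependent control flow. A self-contained illustration of the same mechanism on a toy function (the toy names are made up for this sketch):

```python
import torch
import torch.fx


def add_causal_mask(x: torch.Tensor) -> torch.Tensor:
    # Shape-dependent logic like this is awkward to trace symbolically.
    size = x.shape[-1]
    return x + torch.tril(torch.ones(size, size))


# Register as a leaf: tracing emits a call_function node instead of tracing the body.
add_causal_mask = torch.fx.wrap(add_causal_mask)


class Toy(torch.nn.Module):
    def forward(self, x):
        return add_causal_mask(x)


gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)               # add_causal_mask appears as a single node in the graph
print(gm(torch.zeros(2, 2)))  # the traced module still executes normally
```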
+if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV2Config" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class DeepseekV2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) + + +class DeepseekV2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + 
beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
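+
+    Note:
+        In addition to the standard rotation, this variant first de-interleaves the last dimension
+        of `q` and `k` (the `view(...).transpose(4, 3).reshape(...)` step below) so that the rotary
+        pairs match the layout produced by `DeepseekV2RotaryEmbedding`.
+
+        A minimal sketch of how it is called from `DeepseekV2Attention.forward` (shapes are
+        illustrative): `q_pe` is `[bsz, num_heads, q_len, qk_rope_head_dim]`, `k_pe` is
+        `[bsz, 1, q_len, qk_rope_head_dim]`, and `cos`/`sin` are `[kv_seq_len, qk_rope_head_dim]`:
+
+        ```python
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+        ```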
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV2MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None + ) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1, dtype=torch.float32) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) + elif self.topk_method == "group_limited_greedy": + group_scores = ( + scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weight, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * 
self.routed_scaling_factor + ### expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, self.n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( + dim=1 + ).mean() * self.alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class DeepseekV2MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + # save dtype before computation + input_dtype = hidden_states.dtype + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave( + self.num_experts_per_tok, dim=0 + ) + y = torch.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: 
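+            # Inference path: moe_infer() below sorts tokens by expert id (exchanging them across
+            # expert-parallel ranks when ep_size > 1) so that each expert runs once over a
+            # contiguous slice of its own tokens. The training path above instead repeats every
+            # token num_experts_per_tok times, sends each copy through its selected expert, and
+            # attaches the load-balancing auxiliary loss via AddAuxiliaryLoss so it reaches the
+            # backward pass.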
+ y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + # keep dtype same after moe forward + return y.to(input_dtype) + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = ( + tokens_per_expert_group.view(self.ep_size, -1) + .sum(1) + .cpu() + .numpy() + .tolist() + ) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank + ).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = 
DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 +class DeepseekV2FlashAttention2(DeepseekV2Attention): + """ + DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
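+        # True for flash_attn < 2.1, whose causal mask is top-left aligned; _flash_attention_forward
+        # uses this flag to decide when the causal option has to be dropped (see below).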
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV2FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + # print(f"dtype of hidden_states: {hidden_states.dtype}") + # print(f"dtype of q_proj: {self.q_proj.weight.dtype}") + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
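+        # When q_head_dim != v_head_dim, value_states was zero-padded above to q_head_dim so that
+        # q, k and v share a single head_dim for the flash-attn call; the padded columns are sliced
+        # off again right after _flash_attention_forward returns.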
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. 
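+            # Dropping the causal flag is safe for single-token decoding: with query_length == 1 the
+            # query is the latest position and may attend to every cached key, so no future position
+            # can leak into the result.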
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV2Attention, + "flash_attention_2": DeepseekV2FlashAttention2, +} + + +class DeepseekV2DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV2MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV2MLP(config) + ) + self.input_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + # print(f"1. dtype of residual: {residual.dtype}") + + hidden_states = self.input_layernorm(hidden_states) + # print(f"2. dtype of hidden_states before attn: {hidden_states.dtype}") + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + # print(f"3. dtype of hidden_states after attn: {hidden_states.dtype}") + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + # print(f"4. dtype of hidden_states after post layernorm: {hidden_states.dtype}") + hidden_states = self.mlp(hidden_states) + # print(f"5. 
dtype of hidden_states after mlp: {hidden_states.dtype}") + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2PreTrainedModel(PreTrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2Model(DeepseekV2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." 
+ ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, 
config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM + + >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = 
past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." 
+ ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py b/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py new file mode 100644 index 0000000..d243771 --- /dev/null +++ b/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py @@ -0,0 +1,38 @@ +from typing import List, Optional, Union + + +from transformers.models.llama import LlamaTokenizerFast + + +class DeepseekTokenizerFast(LlamaTokenizerFast): + + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[str, List[str]]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `List[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `List[str]`: The decoded token(s). 
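+
+        Example (illustrative only; the returned token strings depend on the loaded
+        vocabulary, and `"path/to/tokenizer"` is a placeholder, not a real checkpoint):
+
+        ```python
+        >>> tokenizer = DeepseekTokenizerFast.from_pretrained("path/to/tokenizer")
+        >>> tokenizer.convert_ids_to_tokens([0, 1, 2], skip_special_tokens=True)
+        ```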
+ """ + if isinstance(ids, int): + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + token = self._tokenizer.id_to_token(index) + tokens.append(token if token is not None else "") + return tokens + + def _convert_id_to_token(self, index: int) -> Optional[str]: + token = self._tokenizer.id_to_token(int(index)) + return token if token is not None else "" diff --git a/mft_peft_hf/src/model/gpt_bigcode/__init__.py b/mftcoder_accelerate/src/model/gpt_bigcode/__init__.py similarity index 100% rename from mft_peft_hf/src/model/gpt_bigcode/__init__.py rename to mftcoder_accelerate/src/model/gpt_bigcode/__init__.py diff --git a/mft_peft_hf/src/model/gpt_bigcode/configuration_gpt_bigcode.py b/mftcoder_accelerate/src/model/gpt_bigcode/configuration_gpt_bigcode.py similarity index 100% rename from mft_peft_hf/src/model/gpt_bigcode/configuration_gpt_bigcode.py rename to mftcoder_accelerate/src/model/gpt_bigcode/configuration_gpt_bigcode.py diff --git a/mft_peft_hf/src/model/gpt_bigcode/modeling_gpt_bigcode.py b/mftcoder_accelerate/src/model/gpt_bigcode/modeling_gpt_bigcode.py similarity index 100% rename from mft_peft_hf/src/model/gpt_bigcode/modeling_gpt_bigcode.py rename to mftcoder_accelerate/src/model/gpt_bigcode/modeling_gpt_bigcode.py diff --git a/mft_atorch/model/gpt_neox/__init__.py b/mftcoder_accelerate/src/model/gpt_neox/__init__.py similarity index 100% rename from mft_atorch/model/gpt_neox/__init__.py rename to mftcoder_accelerate/src/model/gpt_neox/__init__.py diff --git a/mft_peft_hf/src/model/gpt_neox/config.json b/mftcoder_accelerate/src/model/gpt_neox/config.json similarity index 100% rename from mft_peft_hf/src/model/gpt_neox/config.json rename to mftcoder_accelerate/src/model/gpt_neox/config.json diff --git a/mft_peft_hf/src/model/gpt_neox/configuration_gpt_neox.py b/mftcoder_accelerate/src/model/gpt_neox/configuration_gpt_neox.py similarity index 100% rename from mft_peft_hf/src/model/gpt_neox/configuration_gpt_neox.py rename to mftcoder_accelerate/src/model/gpt_neox/configuration_gpt_neox.py diff --git a/mft_atorch/model/gpt_neox/generation_config.json b/mftcoder_accelerate/src/model/gpt_neox/generation_config.json similarity index 100% rename from mft_atorch/model/gpt_neox/generation_config.json rename to mftcoder_accelerate/src/model/gpt_neox/generation_config.json diff --git a/mft_peft_hf/src/model/gpt_neox/modeling_gpt_neox.py b/mftcoder_accelerate/src/model/gpt_neox/modeling_gpt_neox.py similarity index 100% rename from mft_peft_hf/src/model/gpt_neox/modeling_gpt_neox.py rename to mftcoder_accelerate/src/model/gpt_neox/modeling_gpt_neox.py diff --git a/mft_peft_hf/src/model/gpt_neox/tokenization_gpt_neox_fast.py b/mftcoder_accelerate/src/model/gpt_neox/tokenization_gpt_neox_fast.py similarity index 100% rename from mft_peft_hf/src/model/gpt_neox/tokenization_gpt_neox_fast.py rename to mftcoder_accelerate/src/model/gpt_neox/tokenization_gpt_neox_fast.py diff --git a/mftcoder_accelerate/src/model/phi/configuration_mixformer_sequential.py b/mftcoder_accelerate/src/model/phi/configuration_mixformer_sequential.py new file mode 100644 index 0000000..8cc2d51 --- /dev/null +++ b/mftcoder_accelerate/src/model/phi/configuration_mixformer_sequential.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import math +from typing import Any, Dict, List, Optional, Union + +from transformers import PretrainedConfig + + +class MixFormerSequentialConfig(PretrainedConfig): + """MixFormer (sequential for DeepSpeed) configuration.""" + + model_type = "mixformer-sequential" + + attribute_map = { + "max_position_embeddings": "n_positions", + "hidden_size": "n_embd", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size: Optional[int] = 50304, + n_positions: Optional[int] = 2048, + n_embd: Optional[int] = 1024, + n_layer: Optional[int] = 20, + n_inner: Optional[int] = None, + n_head: Optional[int] = 16, + rotary_dim: Optional[int] = 32, + activation_function: Optional[str] = "gelu_new", + embd_pdrop: Optional[float] = 0.0, + resid_pdrop: Optional[float] = 0.0, + layer_norm_epsilon: Optional[float] = 1e-5, + initializer_range: Optional[float] = 0.02, + tie_word_embeddings: Optional[bool] = False, + pad_vocab_size_multiple: Optional[int] = 64, + **kwargs + ) -> None: + self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_inner = n_inner + self.n_head = n_head + self.rotary_dim = min(rotary_dim, n_embd // n_head) + self.activation_function = activation_function + self.embd_pdrop = embd_pdrop + self.resid_pdrop = resid_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/mftcoder_accelerate/src/model/phi/modeling_mixformer_sequential.py b/mftcoder_accelerate/src/model/phi/modeling_mixformer_sequential.py new file mode 100644 index 0000000..4b26a55 --- /dev/null +++ b/mftcoder_accelerate/src/model/phi/modeling_mixformer_sequential.py @@ -0,0 +1,764 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# +# BSD 3-Clause License +# +# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import math +import copy +from typing import Any, Dict, Optional, Tuple, Union +from dataclasses import dataclass, field + +import torch +import torch.nn as nn + +from einops import rearrange +from transformers.activations import ACT2FN +from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from .configuration_mixformer_sequential import MixFormerSequentialConfig + +import xformers.ops + +@dataclass +class InferenceParams: + """Inference parameters passed to model to efficiently calculate + and store context during inference. + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py. + Args: + max_sequence_len: Maximum sequence length. + max_batch_size: Maximum batch size. + sequence_len_offset: Sequence length offset. + batch_size_offset: Batch size offset. + key_value_memory_dict: Key value memory dictionary. + fused_ft_kernel: Whether to use fused kernel for fast inference. + lengths_per_sample: Lengths per sample. + """ + + max_sequence_len: int = field(metadata={"help": "Maximum sequence length."}) + + max_batch_size: int = field(metadata={"help": "Maximum batch size."}) + + sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."}) + + batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."}) + + key_value_memory_dict: Dict[str, Any] = field( + default_factory=dict, metadata={"help": "Key value memory dictionary."} + ) + + fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."}) + + lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."}) + + +class Embedding(nn.Module): + """Token embedding with dropout.""" + + def __init__(self, config: PretrainedConfig) -> None: + super().__init__() + + self.wte = nn.Embedding(config.vocab_size, config.n_embd) + self.drop = nn.Dropout(config.embd_pdrop) + + def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.wte(input_ids) + hidden_states = self.drop(hidden_states) + + return hidden_states + + +class RotaryEmbedding(nn.Module): + """Rotary embeddings. + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py. 
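+
+    Operates on the packed `qkv` tensor of shape `(batch, seqlen, 3, nheads, headdim)`:
+    the first `rotary_dim` dimensions of the query and key slices are rotated, the value
+    slice is passed through unchanged, and the cos/sin tables are cached and regenerated
+    only when the cached sequence length, device, or dtype no longer matches the input.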
+ + """ + + def __init__( + self, + dim: int, + base: int = 10000, + scale_base: Optional[float] = None, + device: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + + if scale_base is not None: + raise NotImplementedError + + # Generate and save the inverse frequency buffer (non-trainable) + self.dim = dim + self.base = base + self.scale_base = scale_base + self.device = device + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq) + + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None: + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + seqlen = x.shape[1] + seqlen_offset + + # Re-generate the inverse frequency buffer if it's not fp32 + # (for instance if model.half() was called) + if self.inv_freq.dtype != "torch.float32": + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim) + ) + + if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype: + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=x.device, dtype=torch.float32) + + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32)) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(x.dtype) + self._sin_cached = torch.sin(freqs).to(x.dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) + + def _apply_rotary_emb_qkv( + self, + qkv: torch.FloatTensor, + sin: torch.FloatTensor, + cos: torch.FloatTensor, + sin_k: Optional[torch.FloatTensor] = None, + cos_k: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + _, seqlen, three, _, headdim = qkv.shape + assert three == 3 + + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) + + q_rot = qkv[:, :, 0, :, :rotary_dim] + q_pass = qkv[:, :, 0, :, rotary_dim:] + + k_rot = qkv[:, :, 1, :, :rotary_dim] + k_pass = qkv[:, :, 1, :, rotary_dim:] + + # Splits the queries and keys in half + q1, q2 = q_rot.chunk(2, dim=-1) + k1, k2 = k_rot.chunk(2, dim=-1) + c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d") + + # Casts to fp32 are necessary to prevent fp16 overflow issues + q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]] + + # 
Computes the new keys and queries, recasting to original dtype + q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype) + k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype) + + return torch.cat( + [ + torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2), + torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), + qkv[:, :, 2:3, :, :], + ], + axis=2, + ) + + def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: + # `qkv` is of shape (batch, seqlen, 3, nheads, headdim) + self._update_cos_sin_cache(qkv, seqlen_offset) + return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]) + + +class MLP(nn.Module): + """Multi-Layer Perceptron. + Reference: + Attention Is All You Need. + https://arxiv.org/pdf/1706.03762.pdf. + """ + + def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None: + super().__init__() + + act_fn = config.activation_function if act_fn is None else act_fn + assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}." + + n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner + n_inner = n_inner if n_inner is not None else 4 * config.n_embd + + self.fc1 = nn.Linear(config.n_embd, n_inner) + self.fc2 = nn.Linear(n_inner, config.n_embd) + self.act = ACT2FN[act_fn] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + + return hidden_states + + +class SelfAttention(nn.Module): + """Self-attention layer (compatible with PyTorch). + + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. + """ + + def __init__( + self, + causal: bool = True, + softmax_scale: Optional[float] = None, + attention_dropout: float = 0.0, + ) -> None: + super().__init__() + + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward( + self, + qkv: torch.FloatTensor, + causal: bool = None, + attention_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + causal = self.causal if causal is None else causal + batch_size, seq_len = qkv.shape[0], qkv.shape[1] + q, k, v = qkv.unbind(dim=2) + + # flash attention + output = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask(), + op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp) + + # softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + # scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + + # if attention_mask is not None: + # padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device) + # padding_mask.masked_fill_(attention_mask, 0.0) + + # scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + + # if causal: + # causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1) + # scores = scores + causal_mask.to(dtype=scores.dtype) + + # attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + # attention = self.drop(attention) + + # output = torch.einsum("bhts,bshd->bthd", attention, v) + + return output + + +class CrossAttention(nn.Module): + """Cross-attention layer (compatible with PyTorch). + + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. 
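+
+    Computes scaled dot-product attention between a query tensor `q` and a packed
+    key/value tensor `kv`; in this module it is used by `MHA` for incremental decoding
+    against the cached key/value states.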
+ + """ + + def __init__( + self, + causal: bool = True, + softmax_scale: Optional[float] = None, + attention_dropout: float = 0.0, + ) -> None: + super().__init__() + + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward( + self, + q: torch.FloatTensor, + kv: torch.FloatTensor, + causal: bool = None, + attention_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + causal = self.causal if causal is None else causal + batch_size, seq_len_q = q.shape[0], q.shape[1] + assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3] + + seq_len_k = kv.shape[1] + k, v = kv.unbind(dim=2) + + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + + if attention_mask is not None: + padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(attention_mask, 0.0) + + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + + if causal: + causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1) + scores = scores + causal_mask.to(dtype=scores.dtype) + + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention = self.drop(attention) + + output = torch.einsum("bhts,bshd->bthd", attention, v) + + return output + + +def find_mha_dims( + config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None +) -> Tuple[int, int]: + """Validate and return the number of heads and head dimension for multi-head attention. + Args: + config: Model configuration. + n_head: Number of heads. + head_dim: Head dimension. + Returns: + Number of heads and head dimension. + """ + + assert all( + hasattr(config, attr) for attr in ["n_embd", "n_head"] + ), "`config` must have `n_embd` and `n_head` attributes." + + if head_dim is None: + assert ( + config.n_embd % config.n_head == 0 + ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})." + + if n_head is None and head_dim is None: + head_dim = config.n_embd // config.n_head + n_head = config.n_head + elif n_head is None or head_dim is None: + raise ValueError("`n_head` and `head_dim` must be both specified or `None`.") + + return n_head, head_dim + + +def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor: + """Update the key-value cache for inference. + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. + Args: + kv: Key-value tensor. + inference_params: Inference parameters. + layer_idx: Layer index. + Returns: + Updated key-value tensor. 
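+
+    In the default (non-fused) path, the per-layer cache is lazily allocated with shape
+    (max_batch_size, max_sequence_len, 2, num_heads, head_dim), the incoming `kv` slice is
+    written at the current batch and sequence offsets, and the cached contents up to the
+    current position are returned.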
+ """ + + num_heads, head_dim = kv.shape[-2:] + + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, + inference_params.max_sequence_len, + 2, + num_heads, + head_dim, + dtype=kv.dtype, + device=kv.device, + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + if not inference_params.fused_ft_kernel: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + else: + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + kv_cache = None + + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) + + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) + + if not inference_params.fused_ft_kernel: + assert kv_cache is not None + + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = kv_cache[batch_start:batch_end, :sequence_end, ...] + + return kv + + assert inference_params.sequence_len_offset == 0 + assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] + + packsize = 4 if kv.dtype == torch.float32 else 8 + + if kv_cache is not None: + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous() + v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous() + inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) + else: + k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( + kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize + ) + v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d") + + return kv + + +class MHA(nn.Module): + """Multi-head attention layer.""" + + def __init__( + self, + config: PretrainedConfig, + dtype: Optional[torch.dtype] = None, + device: Optional[str] = None, + rotary_dim: Optional[int] = None, + rotary_emb_scale_base: Optional[float] = None, + n_head: Optional[int] = None, + head_dim: Optional[int] = None, + bias: bool = True, + causal: bool = True, + softmax_scale: Optional[float] = None, + dropout: float = 0.0, + layer_idx: Optional[int] = None, + return_residual: bool = False, + checkpointing: bool = False, + ) -> None: + super().__init__() + + # Rotary embedding + self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0) + if self.rotary_emb_dim > 0: + rotary_kwargs = {"device": device} + if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0: + rotary_kwargs["scale_base"] = rotary_emb_scale_base + self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs) + + # MLP + self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim) + op_size = self.n_head * self.head_dim + hidden_size = config.n_embd + + self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype) + self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype) + + # Attention + self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + + self.layer_idx = layer_idx + self.return_residual = 
return_residual + # self.checkpointing = checkpointing + self.checkpointing = True + + def forward( + self, + x: torch.FloatTensor, + past_key_values: Optional[InferenceParams] = None, + attention_mask: Optional[torch.BoolTensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + max_seqlen: Optional[int] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + qkv = self.Wqkv(x) + qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim) + + seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0 + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset) + + if past_key_values is not None: + kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx) + + if attention_mask is not None: + attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask + attention_mask = attention_mask.bool().to(qkv.device) + + attention_kwargs = {"attention_mask": attention_mask} + + if past_key_values is None or seqlen_offset == 0: + if self.checkpointing: + # attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs) + attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv) + else: + attn_output = self.inner_attn(qkv, **attention_kwargs) + else: + q = qkv[:, :, 0] + causal = None if past_key_values.sequence_len_offset == 0 else False + attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs) + + output = rearrange(attn_output, "... h d -> ... (h d)") + output = self.out_proj(output) + + return output if not self.return_residual else (output, x) + + +class ParallelBlock(nn.Module): + """Parallel block. + This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen). + """ + + def __init__( + self, + config: PretrainedConfig, + block_idx: Optional[int] = None, + checkpointing: bool = False, + ) -> None: + super().__init__() + + self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.block_idx = block_idx + + self.mixer = MHA(config, layer_idx=block_idx, checkpointing=checkpointing) + self.mlp = MLP(config) + + def forward( + self, + hidden_states: torch.FloatTensor, + past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, + attention_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.ln(hidden_states) + + attn_outputs = self.mixer(hidden_states, past_key_values=past_key_values, attention_mask=attention_mask) + if isinstance(attn_outputs, tuple): + attn_outputs = attn_outputs[0] + + attn_outputs = self.resid_dropout(attn_outputs) + feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) + + hidden_states = attn_outputs + feed_forward_hidden_states + residual + + return hidden_states + + +class CausalLMHead(nn.Module): + """Causal Language Modeling head. + Reference: + Improving Language Understanding by Generative Pre-Training. + https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. 
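+
+    Applies a final LayerNorm followed by a linear projection to the vocabulary and
+    returns the logits cast to float32.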
+ """ + + def __init__(self, config: PretrainedConfig) -> None: + super().__init__() + + self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.linear = nn.Linear(config.n_embd, config.vocab_size) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + hidden_states = self.ln(hidden_states) + logits = self.linear(hidden_states).to(torch.float32) + + return logits + + +class CausalLMLoss(nn.Module): + """Causal Language Modeling loss. + Reference: + Improving Language Understanding by Generative Pre-Training. + https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. + """ + + def __init__(self, shift_labels: bool = True) -> None: + super().__init__() + + self.shift_labels = shift_labels + self.loss_fct = nn.CrossEntropyLoss() + + def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor: + if self.shift_labels: + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + + loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + + return loss + + +class MixFormerSequentialPreTrainedModel(PreTrainedModel): + """MixFormer (sequential for DeepSpeed) pre-trained model.""" + + config_class = MixFormerSequentialConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + + def __init__(self, *inputs, **kwargs) -> None: + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, (nn.Linear,)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, + attention_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> Dict[str, Any]: + if attention_mask is not None and torch.any(~attention_mask.bool()): + total_seq_len = torch.sum(attention_mask, dim=1) + max_seq_len = torch.max(total_seq_len) + + total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1) + cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32) + attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item()) + else: + attention_mask = None + + if past_key_values is None or not (isinstance(past_key_values, InferenceParams)): + past_key_values = InferenceParams( + max_batch_size=input_ids.shape[0], + max_sequence_len=self.config.n_positions, + sequence_len_offset=0, + batch_size_offset=0, + fused_ft_kernel=False, + key_value_memory_dict={}, + ) + else: + # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids` + past_key_values.sequence_len_offset = len(input_ids[0]) - 1 + input_ids = input_ids[:, -1].unsqueeze(-1) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "attention_mask": attention_mask, + } + + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, 
PreTrainedModel): + module.gradient_checkpointing = value + + +class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel): + """MixFormer (sequential for DeepSpeed) for Causal Language Modeling.""" + + _keys_to_ignore_on_load_missing = [""] + _keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] + _no_split_modules = ["ParallelBlock"] + + def __init__(self, config: MixFormerSequentialConfig) -> None: + super().__init__(config) + + modules = [Embedding(config)] + modules += [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)] + modules.append(CausalLMHead(config)) + + self.layers = nn.Sequential(*modules) + self.loss = CausalLMLoss() + + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + return self.layers[0].wte + + def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: + self.layers[0].wte = new_embeddings + + def get_output_embeddings(self) -> nn.Linear: + return self.layers[-1].linear + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.layers[-1].linear = new_embeddings + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, + attention_mask: Optional[torch.BoolTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + if attention_mask is not None and self.training: + print("`attention_mask` is not supported during training. Using it might lead to unexpected results.") + + if past_key_values is None and attention_mask is None: + lm_logits = self.layers(input_ids) + else: + hidden_layer = self.layers[0](input_ids) + for module in self.layers[1:-1]: + hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask) + lm_logits = self.layers[-1](hidden_layer) + + loss = None + if labels is not None: + loss = self.loss(lm_logits, labels) + + return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values) diff --git a/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp new file mode 100644 index 0000000..8458a9b --- /dev/null +++ b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp @@ -0,0 +1,198 @@ +#include +#include +#include + +// adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_256.cpp +void vecquant8matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant8matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant8matmul_batched_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant8matmul_batched_column_compression_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void 
vecquant8matmul_batched_column_compression( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant4matmul_batched_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_batched( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_batched_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant4matmul_batched_column_compression_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_batched_column_compression( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant8matmul_batched_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_old_cuda(vec, mat, mul, scales, zeros); +} + + +void vecquant4matmul_batched_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_batched_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_batched_old_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant8matmul_batched_column_compression_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_column_compression_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant4matmul_batched_column_compression_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_batched_column_compression_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros); +} + + + +void vecquant8matmul_batched_faster_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_faster( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_faster_cuda(vec, mat, mul, scales, zeros); +} + + +void vecquant8matmul_batched_faster_old_cuda( + torch::Tensor vec, torch::Tensor 
mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_faster_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_faster_old_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant8matmul_batched_column_compression_faster_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_column_compression_faster( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_column_compression_faster_cuda(vec, mat, mul, scales, zeros); +} + + +void vecquant8matmul_batched_column_compression_faster_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_column_compression_faster_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_column_compression_faster_old_cuda(vec, mat, mul, scales, zeros); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched", &vecquant8matmul_batched, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched_old", &vecquant8matmul_batched_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched_faster", &vecquant8matmul_batched_faster, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched_faster_old", &vecquant8matmul_batched_faster_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant4matmul_batched_old", &vecquant4matmul_batched_old, "Vector 4-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched_column_compression", &vecquant8matmul_batched_column_compression, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant8matmul_batched_column_compression_old", &vecquant8matmul_batched_column_compression_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant8matmul_batched_column_compression_faster", &vecquant8matmul_batched_column_compression_faster, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant8matmul_batched_column_compression_faster_old", &vecquant8matmul_batched_column_compression_faster_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant4matmul_batched_column_compression_old", &vecquant4matmul_batched_column_compression_old, "Vector old 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant4matmul_batched", &vecquant4matmul_batched, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + 
m.def("vecquant4matmul_batched_column_compression", &vecquant4matmul_batched_column_compression, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); +} diff --git a/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu new file mode 100644 index 0000000..b7932cd --- /dev/null +++ b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu @@ -0,0 +1,1708 @@ +#define _CRT_SECURE_NO_WARNINGS +#include +#include +#include +#include +#include +#include + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM) +// adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu +__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { + unsigned int *address_as_ui = reinterpret_cast (reinterpret_cast (address) - (reinterpret_cast (address) & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + unsigned short hsum = reinterpret_cast (address) & 2 ? (old >> 16) : (old & 0xffff); + hsum += val; + old = reinterpret_cast (address) & 2 + ? (old & 0xffff) | (hsum << 16) + : (old & 0xffff0000) | hsum; + old = atomicCAS(address_as_ui, assumed, old); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); +} +__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) { + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} +#endif + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant8BatchMatMulColumnCompressionKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + +template +__global__ void VecQuant4BatchMatMulColumnCompressionKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + +template +__global__ void VecQuant8BatchMatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant4BatchMatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + + + +template +__global__ void VecQuant8BatchMatMulKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + +__global__ void VecQuant8BatchMatMulKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + + + +__global__ void VecQuant8BatchMatMulKernel_faster_old( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width +); + + +template +__global__ void VecQuant4BatchMatMulKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + + +template +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster( + const half* __restrict__ vec, + const 
uint8_t* __restrict__ mat,
+    half* __restrict__ mul,
+    const half* __restrict__ scales,
+    const half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old(
+    const half* __restrict__ vec,
+    const uint8_t* __restrict__ mat,
+    half* __restrict__ mul,
+    const half* __restrict__ scales,
+    const half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+
+template <typename scalar_t>
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel_old(
+    const scalar_t* __restrict__ vec,
+    const uint8_t* __restrict__ mat,
+    scalar_t* __restrict__ mul,
+    const scalar_t* __restrict__ scales,
+    const scalar_t* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster(
+    const half* __restrict__ vec,
+    const uint8_t* __restrict__ mat,
+    half* __restrict__ mul,
+    const half* __restrict__ scales,
+    const half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width
+);
+
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
+    const half* __restrict__ vec,
+    const uint8_t* __restrict__ mat,
+    half* __restrict__ mul,
+    const half* __restrict__ scales,
+    const half* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+);
+
+const int BLOCKWIDTH = 128;
+const int BLOCKHEIGHT8 = 32;
+const int BLOCKHEIGHT4 = 16;
+const int BLOCKHEIGHT_OLD4 = 128;
+//const int BLOCKHEIGHT_OLD8 = 128;
+
+__device__ inline unsigned int as_unsigned(int i) {
+    return *reinterpret_cast<unsigned int*>(&i);
+}
+
+__device__ inline int as_int(int i) {
+    return *reinterpret_cast<int*>(&i);
+}
+
+void vecquant8matmul_batched_column_compression_cuda(
+    torch::Tensor vec,
+    torch::Tensor mat,
+    torch::Tensor mul,
+    torch::Tensor scales,
+    torch::Tensor zeros
+) {
+    int batch = vec.size(0);
+    int heads = vec.size(1);
+    int vec_row = vec.size(2);
+    int height = vec.size(3);
+    int width = mat.size(3) * 4;
+
+    dim3 blocks(
+        (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+        (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+    );
+    dim3 threads(BLOCKWIDTH);
+
+    AT_DISPATCH_FLOATING_TYPES(
+        vec.type(), "vecquant8matmul_batched_cuda", ([&] {
+            VecQuant8BatchMatMulColumnCompressionKernel<<<blocks, threads>>>(
+                vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+                scales.data<scalar_t>(), zeros.data<int>(),
+                batch, heads, vec_row, height, width
+            );
+        })
+    );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel(
+    const scalar_t* __restrict__ vec,
+    const int* __restrict__ mat,
+    scalar_t* __restrict__ mul,
+    const scalar_t* __restrict__ scales,
+    const int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int height,
+    int width
+) {
+    int weight_total = batch * heads * height * width / 4;
+    int input_total = batch * heads * vec_row * height;
+    int out_total = batch * heads * vec_row * width;
+    int tid = threadIdx.x;
+    // h is index of height with step being BLOCKWIDTH
+    int h = BLOCKWIDTH * blockIdx.x;
+    // w is index of width with step being 1
+    int w = BLOCKWIDTH * blockIdx.y + tid;
+    if (w >= width && tid >= height) {
+        return;
+    }
+
+    __shared__ scalar_t blockvec[BLOCKWIDTH];
+    int k;
+    scalar_t w_tmp;
+
+    float weight[BLOCKWIDTH];
+
+    for (int b = 0; b < batch; ++b){
+        for (int head = 0; head < heads; ++head){
+            int batch_shift = b * heads + head;
+            for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
+                int i_w = (w / 4);
+                int w_bit = (w % 4) * 8;
+
+                int w_index = (batch_shift * height + h + k) * width / 4 + i_w;
+                if (w_index >= weight_total || w >= width) {
+                    weight[k] = 0;
+                } else {
+                    scalar_t scale = scales[batch_shift * height + h + k];
+                    scalar_t zero = zeros[batch_shift * height + h + k];
+                    w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xFF);
+                    weight[k] = scale * (w_tmp - zero);
+                }
+            }
+
+            scalar_t res;
+            for (int vr = 0; vr < vec_row; ++vr){
+                res = 0;
+                int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+                if (vec_index < input_total) {
+                    blockvec[tid] = vec[vec_index];
+                } else {
+                    blockvec[tid] = 0;
+                }
+
+                __syncthreads();
+                for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
+                    // res is the dot product of BLOCKWIDTH elements (part of width)
+                    res += weight[k] * blockvec[k];
+                }
+                // add res to the final result, final matrix shape: (batch, vec_row, width)
+                int out_index = (batch_shift * vec_row + vr) * width + w;
+                if (out_index < out_total) {
+                    atomicAdd(&mul[out_index], res);
+                }
+                __syncthreads();
+            }
+        }
+    }
+}
+
+void vecquant8matmul_batched_cuda(
+    torch::Tensor vec,
+    torch::Tensor mat,
+    torch::Tensor mul,
+    torch::Tensor scales,
+    torch::Tensor zeros
+) {
+    int batch = vec.size(0);
+    int heads = vec.size(1);
+    int vec_row = vec.size(2);
+    int vec_height = vec.size(3);
+    int height = mat.size(2);
+    int width = mat.size(3);
+    int zero_width = zeros.size(2);
+
+    dim3 blocks(
+        (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
+        (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+    );
+    dim3 threads(BLOCKWIDTH);
+
+    AT_DISPATCH_FLOATING_TYPES(
+        vec.type(), "vecquant8matmul_batched_cuda", ([&] {
+            VecQuant8BatchMatMulKernel<<<blocks, threads>>>(
+                vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+                scales.data<scalar_t>(), zeros.data<int>(),
+                batch, heads, vec_row, vec_height, height, width, zero_width
+            );
+        })
+    );
+
+}
+
+template <typename scalar_t>
+__global__ void VecQuant8BatchMatMulKernel(
+    const scalar_t* __restrict__ vec,
+    const int* __restrict__ mat,
+    scalar_t* __restrict__ mul,
+    const scalar_t* __restrict__ scales,
+    const int* __restrict__ zeros,
+    int batch,
+    int heads,
+    int vec_row,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width
+) {
+    int weight_total = batch * heads * height * width;
+    int input_total = batch * heads * vec_row * vec_height;
+    int out_total = batch * heads * vec_row * width;
+    int tid = threadIdx.x;
+    // h is index of height with step being BLOCKHEIGHT8
+    int h = BLOCKHEIGHT8 * blockIdx.x;
+    // w is index of width with step being 1
+    int w = BLOCKWIDTH * blockIdx.y + tid;
+    if (w >= width && tid >= vec_height) {
+        return;
+    }
+
+    __shared__ scalar_t blockvec[BLOCKWIDTH];
+    // i is index of mat of block first row
+    int i = width * h + w;
+    // if (i >= width * height) {
+    //     return;
+    // }
+    int k;
+    scalar_t w_tmp;
+
+    int z_w = w / 4;
+    int z_mod = (w % 4) * 8;
+
+    float weight[BLOCKWIDTH];
+
+    for (int b = 0; b < batch; ++b){
+        for (int head = 0; head < heads; ++head){
+            int batch_shift = b * heads + head;
+            for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){
+                int k_w = (k / 4);
+                int k_bit = (k % 4) * 8;
+
+                int w_index = batch_shift * height * width + i + (k_w * width);
+                if (w_index >= weight_total || w >= width) {
+                    weight[k] = 0;
+                } else {
+                    scalar_t scale = scales[batch_shift * width + w];
+                    scalar_t zero;
+                    if (zero_width == width) {
+                        zero = zeros[batch_shift * width + w];
+                    } else {
+                        zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+                    }
+                    w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xFF);
+                    weight[k] = scale * (w_tmp - zero);
+                }
+            }
+
+            scalar_t
res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + +void vecquant8matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_cuda", ([&] { + VecQuant8MatMulKernel<< >>( + vec.data (), mat.data (), mul.data (), + scales.data (), zeros.data (), g_idx.data (), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT8 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 4; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); + int k_bit = (k % 4) * 8; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); + + weight[k] = scale * (w_tmp - zero); + } + + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} + + + +void vecquant4matmul_batched_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_batched_cuda", ([&] { + VecQuant4BatchMatMulKernel<< >>( + vec.data (), mat.data (), mul.data (), + scales.data (), zeros.data (), + batch, heads, vec_row, vec_height, height, width, zero_width + ); + }) + ); + +} + +template +__global__ void VecQuant4BatchMatMulKernel( + const scalar_t* 
__restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKHEIGHT4 + int h = BLOCKHEIGHT4 * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= vec_height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + // i is index of mat of block first row + int i = width * h + w; + int k; + scalar_t w_tmp; + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){ + int k_w = (k / 8); + int k_bit = (k % 8) * 4; + + int w_index = batch_shift * height * width + i + (k_w * width); + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * width + w]; + scalar_t zero; + if (zero_width == width) { + zero = zeros[batch_shift * width + w]; + } else { + zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xF)); + } + w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + + +void vecquant4matmul_batched_column_compression_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3) * 8; + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_batched_cuda", ([&] { + VecQuant4BatchMatMulColumnCompressionKernel<< >>( + vec.data (), mat.data (), mul.data (), + scales.data (), zeros.data (), + batch, heads, vec_row, height, width + ); + }) + ); + +} + +template +__global__ void VecQuant4BatchMatMulColumnCompressionKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + int weight_total = batch * heads * height * width / 8; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height 
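+    // Layout note: in this column-compressed 4-bit variant, each 32-bit word of
+    // `mat` packs eight 4-bit weights along the width dimension (i_w = w / 8,
+    // w_bit = (w % 8) * 4), and `scales`/`zeros` hold one value per
+    // (batch, head, height) position. Each thread owns a single output column w
+    // and accumulates its partial dot products into `mul` with atomicAdd, giving
+    // a result of shape (batch, heads, vec_row, width).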
with step being BLOCKWIDTH + int h = BLOCKWIDTH * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + int i_w = (w / 8); + int w_bit = (w % 8) * 4; + + int w_index = (batch_shift * height + h + k) * width / 8 + i_w; + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * height + h + k]; + scalar_t zero = zeros[batch_shift * height + h + k]; + w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + +void vecquant8matmul_batched_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_batched_old_cuda", ([&] { + VecQuant8BatchMatMulKernel_old<< >>( + vec.data (), mat.data (), mul.data (), + scales.data (), zeros.data (), + batch, heads, vec_row, vec_height, height, width, zero_width + ); + }) + ); +} + + +template +__global__ void VecQuant8BatchMatMulKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKHEIGHT8 + int h = BLOCKWIDTH * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= vec_height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + // i is index of mat of block first row + int i = width * h + w; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ + int k_w = k; + int w_index = batch_shift * height * width + i + (k_w * width); + if (w_index >= 
weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * width + w]; + scalar_t zero = zeros[batch_shift * width + w]; + w_tmp = as_unsigned(mat[w_index]); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + + +void vecquant8matmul_batched_faster_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + VecQuant8BatchMatMulKernel_faster<< >>( + (half*) vec.data_ptr(), + (uint8_t*) mat.data_ptr(), + (half*) mul.data_ptr(), + (half*) scales.data_ptr(), + (half*) zeros.data_ptr(), + batch, heads, vec_row, vec_height, height, width, zero_width + ); +} + + + +__global__ void VecQuant8BatchMatMulKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + //int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + int h = BLOCKWIDTH * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ float blockvec[BLOCKWIDTH]; + int i = width * h + w; + int k; + float w_tmp; + + float weight[BLOCKWIDTH]; + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ + int k_w = k; + int w_index = batch_shift * height * width + i + (k_w * width); + float scale = __half2float(scales[batch_shift * width + w]); + float zero = __half2float(zeros[batch_shift * width + w]); + w_tmp = as_unsigned(mat[w_index]); + weight[k] = scale *(w_tmp-zero); + } + + float res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = __half2float(vec[vec_index]); + } else { + blockvec[tid] = 0; + } + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ + float temp_res = weight[k]*blockvec[k]; + res += temp_res; + } + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], __float2half(res)); + } + __syncthreads(); + } + } + } +} + + + + +void 
vecquant8matmul_batched_column_compression_faster_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + VecQuant8BatchMatMulColumnCompressionKernel_faster<< >>( + (half*) vec.data_ptr(), + (uint8_t*) mat.data_ptr(), + (half*) mul.data_ptr(), + (half*) scales.data_ptr(), + (half*) zeros.data_ptr(), + batch, heads, vec_row, height, width + ); + +} + +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + //int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + int h = BLOCKWIDTH * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ float blockvec[BLOCKWIDTH]; + int k; + float w_tmp; + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH; ++k){ + int w_index = (batch_shift * height + h + k) * width + w; + float scale = __half2float(scales[batch_shift * height + h + k]); + float zero = __half2float(zeros[batch_shift * height + h + k]); + w_tmp = mat[w_index]; + weight[k] = scale * (w_tmp-zero); + } + + float res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = __half2float(vec[vec_index]); + } else { + blockvec[tid] = 0; + } + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k]*blockvec[k]; + } + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], __float2half(res)); + } + __syncthreads(); + } + } + } +} + + + +void vecquant8matmul_batched_column_compression_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_batched_column_compression_old_cuda", ([&] { + VecQuant8BatchMatMulColumnCompressionKernel_old<< >>( + vec.data (), mat.data (), mul.data (), + scales.data (), zeros.data (), + batch, heads, vec_row, height, width + ); + }) + ); + +} + +template +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * height; + int out_total = 
batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKWIDTH + int h = BLOCKWIDTH * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + int w_index = (batch_shift * height + h + k) * width + w; + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * height + h + k]; + scalar_t zero = zeros[batch_shift * height + h + k]; + w_tmp = mat[w_index]; + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + +void vecquant4matmul_batched_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_batched_old_cuda", ([&] { + VecQuant4BatchMatMulKernel_old<< >>( + vec.data (), mat.data (), mul.data (), + scales.data (), zeros.data (), + batch, heads, vec_row, vec_height, height, width, zero_width + ); + }) + ); + +} + +template +__global__ void VecQuant4BatchMatMulKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKHEIGHT_OLD4 + int h = BLOCKHEIGHT_OLD4 * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= vec_height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + // i is index of mat of block first row + int i = width * h + w; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){ + int k_w = (k / 2); + int k_bit = (k % 2) * 4; + int w_index = batch_shift * height * 
width + i + (k_w * width); + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * width + w]; + scalar_t zero = zeros[batch_shift * width + w]; + w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + + + + +void vecquant4matmul_batched_column_compression_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3); + + dim3 blocks( + (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_batched_column_compression_old_cuda", ([&] { + VecQuant4BatchMatMulColumnCompressionKernel_old<< >>( + vec.data (), mat.data (), mul.data (), + scales.data (), zeros.data (), + batch, heads, vec_row, height, width + ); + }) + ); + +} + +template +__global__ void VecQuant4BatchMatMulColumnCompressionKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKWIDTH + int h = BLOCKHEIGHT_OLD4 * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){ + int k_w = (k / 2); + int k_bit = (k % 2) * 4; + int w_index = (batch_shift * height + h + k) * width + k_w; + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * height + h + k]; + scalar_t zero = zeros[batch_shift * height + h + k]; + w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){ + // res is the dot product of 
BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + + + + +void vecquant8matmul_batched_faster_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + VecQuant8BatchMatMulKernel_faster_old<< >>( + (half*) vec.data_ptr(), + (uint8_t*) mat.data_ptr(), + (half*) mul.data_ptr(), + (half*) scales.data_ptr(), + (half*) zeros.data_ptr(), + batch, heads, vec_row, vec_height, height, width + ); +} + + +__global__ void VecQuant8BatchMatMulKernel_faster_old( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + const int BLOCKWIDTH_half = BLOCKWIDTH/2; + + int h = BLOCKWIDTH * blockIdx.x; //head_dim, dim=-1 + int w = BLOCKWIDTH * blockIdx.y + tid; //seq-len, +0-256 ,dim=-2 + /* + if (w >= width && tid >= vec_height) { + return; + } + */ + __shared__ half blockvec[BLOCKWIDTH]; //256 + int i = width * h + w; + int k; + + half w_tmp1 = __float2half(0); + half w_tmp2 = __float2half(0); + + half2 weight[BLOCKWIDTH_half]; + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + //int zero_index = batch_shift; + for (k = 0; k < BLOCKWIDTH_half; ++k){ + int w_index1 = batch_shift * height * width + i + (2 * k * width); // [batch,head,h+k, w] + int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width); + int zero_index = batch_shift * width + w; // [batch,head, w] + if (w_index1 >= weight_total || w >= width || (2 * k + h) >= height) { + weight[k] = __float2half2_rn(0); + } else { + float zero_f=__half2float(zeros[zero_index]); + float scale_f= __half2float(scales[zero_index]); + if (w_index2 >= weight_total){ + w_tmp1 = __float2half((as_unsigned(mat[w_index1]) -zero_f)*scale_f); + w_tmp2 = __float2half(0); + weight[k] = __halves2half2(w_tmp1,w_tmp2); + //printf("zero_index is %d w is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,w,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k])); + }else{ + w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1])); + w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2])); + + //weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero,zero)),__halves2half2(scale,scale)); + weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f))); + //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is 
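+                    // Dequantization note: __hfma2 computes w * scale + (-(scale * zero))
+                    // element-wise on the packed half2, which equals scale * (w - zero) for
+                    // both uint8 weights at once; scale_f and zero_f can be folded this way
+                    // because this kernel stores a single (scale, zero) pair per output
+                    // column, indexed by zero_index = batch_shift * width + w.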
%d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k])); + } + } + } + + + for (int vr = 0; vr < vec_row; ++vr){ + float res=0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + int out_index = (batch_shift * vec_row + vr) * width + w; + if (vec_index < input_total) { + //blockvec[tid] = __half2float(vec[vec_index]);// [batch, head, vr, tid(seq_len dim+)] + blockvec[tid] = vec[vec_index]; + //printf("width is %d height is %d h is %d w is %d vec_index is %d out_index is %d vec_row is %d vec_height is %d,vr is %d tid is %d blockvec is %f\n",width,height, h,w,vec_index,out_index,vec_row,vec_height,vr,tid,blockvec[tid]); + } else { + blockvec[tid] = __float2half(0); + } + __syncthreads(); + if (out_index < out_total) { + for (k = 0; k < BLOCKWIDTH_half; ++k){ + half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1])); + res += __low2float(res2) + __high2float(res2); + } + atomicAdd(&mul[out_index], __float2half(res)); + } + __syncthreads(); + } + } + } +} + + +void vecquant8matmul_batched_column_compression_faster_old_cuda( + torch::Tensor vec, // [batch,heads, seq_q, seq_v] + torch::Tensor mat, // [batch,heads, seq_v, head_dim] + torch::Tensor mul, // [batch,heads, seq_q,head_dim] + torch::Tensor scales, // [batch,heads, head_dim] + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); //ql + int height = mat.size(2); //vl + int width = mat.size(3); //head_dim + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + VecQuant8BatchMatMulColumnCompressionKernel_faster_old<< >>( + (half*) vec.data_ptr(), + (uint8_t*) mat.data_ptr(), + (half*) mul.data_ptr(), + (half*) scales.data_ptr(), + (half*) zeros.data_ptr(), + batch, heads, vec_row, height, width + ); + +} + + +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old( + const half* __restrict__ vec, // [batch,heads, seq_q, seq_v] + const uint8_t* __restrict__ mat, // [batch,heads, seq_v, head_dim] + half* __restrict__ mul, // [batch,heads, seq_q,head_dim] + const half* __restrict__ scales, // [batch,heads, seq_v] + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, //seq_q + int height, //seq_v + int width //head_dim +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + int h = BLOCKWIDTH * blockIdx.x; // vl + int w = BLOCKWIDTH * blockIdx.y + tid; //head_dim + block + if (w >= width && tid >= height) { + return; + } + __shared__ half blockvec[BLOCKWIDTH]; + int k; + half w_tmp1 = __float2half(0); + half w_tmp2 = __float2half(0); + int i = width * h + w; + const int BLOCKWIDTH_half = BLOCKWIDTH/2; + half2 weight[BLOCKWIDTH_half]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + //int zero_index = batch_shift; + for (k = 0; k < BLOCKWIDTH_half; ++k){ + int w_index1 = batch_shift * height * width + i + (2 * k) * width; // [batch,head, h+k, w] + int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width); + int zero_index1 = batch_shift * height + h + 2*k; // [batch,head, w] + int zero_index2 = batch_shift * height + h + 2*k+1; // [batch,head, w] + + 
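+                // Unlike the per-column kernel above, this column-compressed variant keeps
+                // one (scale, zero) pair per key/value position along `height`, so the two
+                // uint8 weights packed into a half2 are dequantized with their own pairs
+                // (zero_index1 / zero_index2) via __hsub2 followed by __hmul2.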
if (w_index1 >= weight_total || (2 * k + h)>=height) { + weight[k]=__float2half2_rn(0); + } else{ + //int zero_index = batch_shift + h; // [batch,head, w] + //float scale_f1 = __half2float(scales[zero_index1]); + //float zero_f1 = __half2float(zeros[zero_index1]); + if (w_index2>=weight_total){ + w_tmp1 = __float2half((as_unsigned(mat[w_index1]) - __half2float(zeros[zero_index1]))* __half2float(scales[zero_index1])); + w_tmp2 = __float2half(0); + weight[k] = __halves2half2(w_tmp1,w_tmp2); + //printf("zero_index is %d k is %d w is %d head is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,k,w,head,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k])); + }else{ + w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1])); + w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2])); + half zero1=zeros[zero_index1]; + half zero2=zeros[zero_index2]; + half scale1=scales[zero_index1]; + half scale2=scales[zero_index2]; + weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero1,zero2)),__halves2half2(scale1,scale2)); + //weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f))); + //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k])); + } + } + } + + + for (int vr = 0; vr < vec_row; ++vr){ + float res=0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + int out_index = (batch_shift * vec_row + vr) * width + w; + + if (vec_index < input_total) { + //blockvec[tid] = __half2float(vec[vec_index]); + blockvec[tid] = vec[vec_index]; + //printf("vec_index is %d out_index is %d vec_row is %d ,vr is %d tid is %d blockvec is %f\n",vec_index,out_index,vec_row,vr,tid,blockvec[tid]); + } else { + blockvec[tid] = __float2half(0); + //blockvec[tid] = 0; + } + __syncthreads(); + if (out_index < out_total) { + for (k = 0; k < BLOCKWIDTH_half; ++k){ + half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1])); + res += __low2float(res2) + __high2float(res2); + } + atomicAdd(&mul[out_index], __float2half(res)); + } + __syncthreads(); + } + } + } +} diff --git a/mft_peft_hf/src/model/qwen/configuration_qwen.py b/mftcoder_accelerate/src/model/qwen/configuration_qwen.py similarity index 53% rename from mft_peft_hf/src/model/qwen/configuration_qwen.py rename to mftcoder_accelerate/src/model/qwen/configuration_qwen.py index 502ef6c..f8fe2cb 100644 --- a/mft_peft_hf/src/model/qwen/configuration_qwen.py +++ b/mftcoder_accelerate/src/model/qwen/configuration_qwen.py @@ -9,61 +9,49 @@ class QWenConfig(PretrainedConfig): model_type = "qwen" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = { - "hidden_size": "n_embd", - "num_attention_heads": "n_head", - "max_position_embeddings": "n_positions", - "num_hidden_layers": "n_layer", - } def __init__( self, - vocab_size=151851, - n_embd=4096, - n_layer=32, - n_head=32, - n_inner=None, - embd_pdrop=0.0, - attn_pdrop=0.0, - layer_norm_epsilon=1e-5, + vocab_size=151936, + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + emb_dropout_prob=0.0, + attn_dropout_prob=0.0, + 
layer_norm_epsilon=1e-6, initializer_range=0.02, + max_position_embeddings=8192, scale_attn_weights=True, use_cache=True, - eos_token_id=151643, - apply_residual_connection_post_layernorm=False, bf16=False, fp16=False, fp32=False, kv_channels=128, rotary_pct=1.0, rotary_emb_base=10000, - use_dynamic_ntk=False, - use_logn_attn=False, - use_flash_attn=True, - ffn_hidden_size=22016, + use_dynamic_ntk=True, + use_logn_attn=True, + use_flash_attn="auto", + intermediate_size=22016, no_bias=True, tie_word_embeddings=False, + use_cache_quantization=False, + use_cache_kernel=False, + softmax_in_fp32=False, **kwargs, ): - self.eos_token_id = eos_token_id - super().__init__( - eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs - ) - self.vocab_size = vocab_size - self.n_embd = n_embd - self.n_layer = n_layer - self.n_head = n_head - self.n_inner = n_inner - self.embd_pdrop = embd_pdrop - self.attn_pdrop = attn_pdrop + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.emb_dropout_prob = emb_dropout_prob + self.attn_dropout_prob = attn_dropout_prob self.layer_norm_epsilon = layer_norm_epsilon self.initializer_range = initializer_range self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache - self.apply_residual_connection_post_layernorm = ( - apply_residual_connection_post_layernorm - ) + self.max_position_embeddings = max_position_embeddings self.bf16 = bf16 self.fp16 = fp16 self.fp32 = fp32 @@ -73,6 +61,11 @@ def __init__( self.use_dynamic_ntk = use_dynamic_ntk self.use_logn_attn = use_logn_attn self.use_flash_attn = use_flash_attn - self.ffn_hidden_size = ffn_hidden_size self.no_bias = no_bias - self.tie_word_embeddings = tie_word_embeddings + self.use_cache_quantization = use_cache_quantization + self.use_cache_kernel = use_cache_kernel + self.softmax_in_fp32 = softmax_in_fp32 + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) diff --git a/mftcoder_accelerate/src/model/qwen/cpp_kernels.py b/mftcoder_accelerate/src/model/qwen/cpp_kernels.py new file mode 100644 index 0000000..d9cee70 --- /dev/null +++ b/mftcoder_accelerate/src/model/qwen/cpp_kernels.py @@ -0,0 +1,55 @@ +from torch.utils import cpp_extension +import pathlib +import os +import subprocess + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") + +# Check if cuda 11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) +if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + if int(bare_metal_minor) >= 7: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_90,code=sm_90') + +# Build path +srcpath = pathlib.Path(__file__).parent.absolute() +buildpath = srcpath / 'build' +_create_build_dir(buildpath) + +def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return 
cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3', ], + extra_cuda_cflags=['-O3', + '-gencode', 'arch=compute_70,code=sm_70', + '--use_fast_math'] + extra_cuda_flags + cc_flag, + verbose=1 + ) + +extra_flags = [] + +cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp", + "./cache_autogptq_cuda_kernel_256.cu"] +cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags) diff --git a/mft_peft_hf/src/model/qwen/modeling_qwen.py b/mftcoder_accelerate/src/model/qwen/modeling_qwen.py similarity index 62% rename from mft_peft_hf/src/model/qwen/modeling_qwen.py rename to mftcoder_accelerate/src/model/qwen/modeling_qwen.py index 238fe91..45c0d16 100644 --- a/mft_peft_hf/src/model/qwen/modeling_qwen.py +++ b/mftcoder_accelerate/src/model/qwen/modeling_qwen.py @@ -3,14 +3,16 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import copy import importlib import math +import pathlib from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator import torch import torch.nn.functional as F import torch.utils.checkpoint -from torch.cuda.amp import autocast +import warnings from torch.nn import CrossEntropyLoss from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList @@ -35,6 +37,8 @@ SUPPORT_CUDA = torch.cuda.is_available() SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 +SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 + from .configuration_qwen import QWenConfig from .qwen_generation_utils import ( @@ -66,13 +70,18 @@ 向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 """ +_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\ +We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained). 
+检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。 +""" + apply_rotary_emb_func = None rms_norm = None flash_attn_unpadded_func = None - +flash_attn_func = None def _import_flash_attn(): - global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func + global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func try: from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func apply_rotary_emb_func = __apply_rotary_emb_func @@ -93,20 +102,49 @@ def _import_flash_attn(): try: import flash_attn + _flash_attn_func = None if not hasattr(flash_attn, '__version__'): from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func else: if int(flash_attn.__version__.split(".")[0]) >= 2: + if int(flash_attn.__version__.split(".")[1]) >= 1: + from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func else: from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func flash_attn_unpadded_func = __flash_attn_unpadded_func + flash_attn_func = _flash_attn_func except ImportError: logger.warn( "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency " "https://github.com/Dao-AILab/flash-attention" ) +def quantize_cache_v(fdata, bits, qmax, qmin): + # b, s, head, h-dim->b, head, s, h-dim + qtype = torch.uint8 + device = fdata.device + shape = fdata.shape + + fdata_cal = torch.flatten(fdata, 2) + fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) + fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) + # Compute params + if qmax.device != fmax.device: + qmax = qmax.to(device) + qmin = qmin.to(device) + scale = (fmax - fmin) / (qmax - qmin) + zero = qmin - fmin / scale + scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + # Quantize + res_data = fdata / scale + zero + qdata = torch.clamp(res_data, qmin, qmax).to(qtype) + return qdata.contiguous(), scale, zero + +def dequantize_cache_torch(qdata, scale, zero): + data = scale * (qdata - zero) + return data class FlashSelfAttention(torch.nn.Module): def __init__( @@ -126,11 +164,33 @@ def __init__( self.softmax_scale = softmax_scale self.dropout_p = attention_dropout - def forward(self, q, k, v): + def unpad_input(self, hidden_states, attention_mask): + valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0) + seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + hidden_states = hidden_states[indices] + return hidden_states, indices, cu_seqlens, max_seqlen_in_batch + + def pad_input(self, hidden_states, indices, batch, seqlen): + output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device, + dtype=hidden_states.dtype) + output[indices] = hidden_states + return rearrange(output, '(b s) ... 
-> b s ...', b=batch) + + def forward(self, q, k, v, attention_mask=None): assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) assert all((i.is_cuda for i in (q, k, v))) batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = k.shape[1] + seqlen_out = seqlen_q + + if flash_attn_func is not None and batch_size == 1: + dropout_p = self.dropout_p if self.training else 0 + output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal) + return output + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] cu_seqlens_q = torch.arange( 0, @@ -140,13 +200,14 @@ def forward(self, q, k, v): device=q.device, ) - if self.training: - assert seqlen_k == seqlen_q - - is_causal = self.causal - cu_seqlens_k = cu_seqlens_q + if batch_size > 1 and attention_mask is not None: + k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask) + if q.size(0) == v.size(0): + q = q[indices_k] + cu_seqlens_q = cu_seqlens_k + seqlen_q = seqlen_k + v = v[indices_k] else: - is_causal = seqlen_q == seqlen_k cu_seqlens_k = torch.arange( 0, (batch_size + 1) * seqlen_k, @@ -154,7 +215,15 @@ def forward(self, q, k, v): dtype=torch.int32, device=q.device, ) - self.dropout_p = 0 + + if self.training: + assert seqlen_k == seqlen_q + is_causal = self.causal + dropout_p = self.dropout_p + else: + is_causal = seqlen_q == seqlen_k + dropout_p = 0 + output = flash_attn_unpadded_func( q, k, @@ -163,30 +232,23 @@ def forward(self, q, k, v): cu_seqlens_k, seqlen_q, seqlen_k, - self.dropout_p, + dropout_p, softmax_scale=self.softmax_scale, causal=is_causal, ) - - output = rearrange(output, "(b s) ... -> b s ...", b=batch_size) + if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k: + output = self.pad_input(output, indices_k, batch_size, seqlen_out) + else: + new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:] + output = output.view(new_shape) return output class QWenAttention(nn.Module): - def __init__(self, config, layer_number=None): + def __init__(self, config): super().__init__() - max_positions = config.max_position_embeddings - self.register_buffer( - "bias", - torch.tril( - torch.ones((max_positions, max_positions), dtype=torch.bool) - ).view(1, 1, max_positions, max_positions), - persistent=False, - ) self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) - self.layer_number = max(1, layer_number) - self.params_dtype = config.params_dtype self.seq_length = config.seq_length self.hidden_size = config.hidden_size @@ -197,8 +259,6 @@ def __init__(self, config, layer_number=None): self.use_flash_attn = config.use_flash_attn self.scale_attn_weights = True - self.layer_idx = None - self.projection_size = config.kv_channels * config.num_attention_heads assert self.projection_size % config.num_attention_heads == 0 @@ -219,25 +279,10 @@ def __init__(self, config, layer_number=None): and not self.is_fp32 ): self.core_attention_flash = FlashSelfAttention( - causal=True, attention_dropout=config.attn_pdrop + causal=True, attention_dropout=config.attn_dropout_prob ) - self.bf16 = config.bf16 - if config.rotary_pct == 1.0: - self.rotary_ndims = None - else: - assert config.rotary_pct < 1 - self.rotary_ndims = int( - self.hidden_size_per_attention_head * config.rotary_pct - ) - dim = ( - self.rotary_ndims - if self.rotary_ndims is not None - else self.hidden_size_per_attention_head - ) - self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) - self.use_dynamic_ntk = 
config.use_dynamic_ntk self.use_logn_attn = config.use_logn_attn @@ -245,104 +290,104 @@ def __init__(self, config, layer_number=None): math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768) ] - self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None] - self._ntk_cached = 1.0 - - self.attn_dropout = nn.Dropout(config.attn_pdrop) - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) + logn_tensor = torch.tensor(logn_list)[None, :, None, None] + self.register_buffer("logn_tensor", logn_tensor, persistent=False) + + self.attn_dropout = nn.Dropout(config.attn_dropout_prob) + self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False + self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False + self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False + cache_dtype = torch.float + if self.bf16: + cache_dtype=torch.bfloat16 + elif config.fp16: + cache_dtype = torch.float16 + self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) + self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) + + if config.use_cache_quantization and config.use_cache_kernel: + # pre check if the support files existing + module_root = pathlib.Path(__file__).parent + src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu") + if any(not (module_root/src).is_file() for src in src_files): + warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") + self.cache_kernels = None + else: + try: + from .cpp_kernels import cache_autogptq_cuda_256 + self.cache_kernels = cache_autogptq_cuda_256 + except ImportError: + warnings.warn("Failed to import KV cache kernels.") + self.cache_kernels = None + + def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None): + device = query.device + if self.use_cache_quantization: + qk, qk_scale, qk_zero = key + if self.use_cache_kernel and self.cache_kernels is not None: + shape = query.shape[:-1] + (qk.shape[-2],) + attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_faster_old( + query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(), + qk.transpose(-1, -2).contiguous(), + attn_weights, + qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(), + qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous()) + # attn_weights = attn_weights.to(query.dtype).contiguous() + else: + key = dequantize_cache_torch(qk, qk_scale, qk_zero) + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + else: + attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], - value.size(-1) ** 0.5, - dtype=attn_weights.dtype, - device=attn_weights.device, - ) + if self.use_cache_quantization: + size_temp = value[0].size(-1) + else: + size_temp = value.size(-1) + attn_weights = attn_weights / (size_temp ** 0.5) - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[ - :, :, key_length - query_length : key_length, :key_length - ] mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( - attn_weights.device 
- ) - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - - def _upcast_and_reordered_attn( - self, query, key, value, attention_mask=None, head_mask=None - ): - bsz, num_heads, q_seq_len, dk = query.size() - _, _, k_seq_len, _ = key.size() - - attn_weights = torch.empty( - bsz * num_heads, - q_seq_len, - k_seq_len, - dtype=torch.float32, - device=query.device, - ) - - scale_factor = 1.0 - if self.scale_attn_weights: - scale_factor /= float(value.size(-1)) ** 0.5 - - with autocast(enabled=False): - q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape( - -1, dk, k_seq_len - ) - attn_weights = torch.baddbmm( - attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor + if causal_mask is not None: + attn_weights = torch.where( + causal_mask, attn_weights.to(attn_weights.dtype), mask_value ) - attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) - - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[ - :, :, key_length - query_length : key_length, :key_length - ] - mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to( - attn_weights.device - ) - attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1) + if self.softmax_in_fp32: + attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if attn_weights.dtype != torch.float32: - raise RuntimeError( - "Error with upcasting, attn_weights does not have dtype torch.float32" - ) - attn_weights = attn_weights.type(value.dtype) + attn_weights = attn_weights.type(query.dtype) attn_weights = self.attn_dropout(attn_weights) if head_mask is not None: attn_weights = attn_weights * head_mask - attn_output = torch.matmul(attn_weights, value) + if self.use_cache_quantization: + qv, qv_scale, qv_zero = value + if self.use_cache_kernel and self.cache_kernels is not None: + shape = attn_weights.shape[:-1] + (query.shape[-1],) + attn_output = torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( + attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(), + qv.contiguous(), # dtype: int32 + attn_output, + qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(), + qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous()) + if attn_output.dtype != query.dtype: + attn_output = attn_output.to(query.dtype) + attn_weights = attn_weights.to(query.dtype) + else: + value = dequantize_cache_torch(qv, qv_scale, qv_zero) + attn_output = torch.matmul(attn_weights, value) + else: + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2) return 
attn_output, attn_weights @@ -359,6 +404,7 @@ def _merge_heads(self, tensor, num_heads, attn_head_size): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -367,64 +413,80 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, ): - mixed_x_layer = self.c_attn(hidden_states) + query, key, value = mixed_x_layer.split(self.split_size, dim=2) query = self._split_heads(query, self.num_heads, self.head_dim) key = self._split_heads(key, self.num_heads, self.head_dim) value = self._split_heads(value, self.num_heads, self.head_dim) - kv_seq_len = hidden_states.size()[1] - if layer_past: - # layer past[0] shape: bs * seq_len * head_num * dim - kv_seq_len += layer_past[0].shape[1] - if ( - self.use_dynamic_ntk - and kv_seq_len == hidden_states.size()[1] - and not self.training - ): - context_value = math.log(kv_seq_len / self.seq_length, 2) + 1 - ntk_alpha = 2 ** math.ceil(context_value) - 1 - ntk_alpha = max(ntk_alpha, 1) - self._ntk_cached = ntk_alpha - else: - ntk_alpha = self._ntk_cached - rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to( - hidden_states.device - ) - - if rotary_pos_emb is not None: - if isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = rotary_pos_emb - else: + if rotary_pos_emb_list is not None: + cur_len = query.shape[1] + if len(rotary_pos_emb_list) == 1: + rotary_pos_emb = rotary_pos_emb_list[0] + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query = apply_rotary_pos_emb(query, q_pos_emb) + key = apply_rotary_pos_emb(key, k_pos_emb) + else: + query_list = [] + key_list = [] + for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] + key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] + query = torch.cat(query_list, dim=0) + key = torch.cat(key_list, dim=0) + + if self.use_cache_quantization: + key = quantize_cache_v(key.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + value = quantize_cache_v(value.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - cur_len = query.shape[1] - q_pos_emb = q_pos_emb[:, -cur_len:, :, :] - k_pos_emb = k_pos_emb[:, -cur_len:, :, :] - query = apply_rotary_pos_emb(query, q_pos_emb) - key = apply_rotary_pos_emb(key, k_pos_emb) if layer_past is not None: past_key, past_value = layer_past[0], layer_past[1] - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) + if self.use_cache_quantization: + # use_cache_quantization: + # present=((q_key,key_scale,key_zero_point), + # (q_value,value_scale,value_zero_point)) + key = (torch.cat((past_key[0], key[0]), dim=2), + torch.cat((past_key[1], key[1]), dim=2), + torch.cat((past_key[2], key[2]), dim=2)) + value = (torch.cat((past_value[0], value[0]), dim=2), + torch.cat((past_value[1], value[1]), 
dim=2), + torch.cat((past_value[2], value[2]), dim=2)) + else: + # not use_cache_quantization: + # present=(key,value) + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) if use_cache: present = (key, value) else: present = None - if self.use_logn_attn and not self.training: - if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype: - self.logn_tensor = self.logn_tensor.to(query.device).type_as(query) - seq_start = key.size(1) - query.size(1) - seq_end = key.size(1) - logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :] + key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) + if key_size > self.seq_length and self.use_logn_attn and not self.training: + if self.use_cache_quantization: + seq_start = key[0].size(2) - query.size(1) + seq_end = key[0].size(2) + else: + seq_start = key.size(1) - query.size(1) + seq_end = key.size(1) + logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) query = query * logn_tensor.expand_as(query) if ( @@ -434,23 +496,49 @@ def forward( and query.is_cuda ): q, k, v = query, key, value - context_layer = self.core_attention_flash(q, k, v) - - context_layer = rearrange( - context_layer, "b s h d -> b s (h d)" - ).contiguous() + attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) else: + key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) + if query.size(1) == key_size: + causal_mask = torch.tril( + torch.ones((key_size, key_size), dtype=torch.bool, device=query.device) + ).view(1, 1, key_size, key_size) + else: + causal_mask = None query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - attn_output, attn_weight = self._attn( - query, key, value, attention_mask, head_mask - ) - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) + if not self.use_cache_quantization: + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + if ( + causal_mask is None + and self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + and not query.is_cuda + ): + raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED) + + if not self.use_cache_quantization and SUPPORT_TORCH2: + if attention_mask is not None: + attention_mask = attention_mask.expand(-1, -1, query.size(2), -1) + if causal_mask is not None: + attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) + else: + attention_mask = causal_mask + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ).transpose(1, 2) + attn_weight = None + else: + attn_output, attn_weight = self._attn( + query, key, value, causal_mask, attention_mask, head_mask + ) + context_layer = self._merge_heads( + attn_output, self.num_heads, self.head_dim + ) attn_output = self.c_proj(context_layer) + outputs = (attn_output, present) if output_attentions: if ( @@ -459,6 +547,8 @@ def forward( and not self.is_fp32 ): raise ValueError("Cannot output attentions while using flash-attn") + elif not self.use_cache_quantization and SUPPORT_TORCH2: + raise ValueError("Cannot output attentions while using scaled_dot_product_attention") else: outputs += (attn_weight,) @@ -469,12 +559,12 @@ class QWenMLP(nn.Module): def __init__(self, config): super().__init__() self.w1 = nn.Linear( - config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias + config.hidden_size, config.intermediate_size // 2, 
bias=not config.no_bias ) self.w2 = nn.Linear( - config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias + config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias ) - ff_dim_in = config.ffn_hidden_size // 2 + ff_dim_in = config.intermediate_size // 2 self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) def forward(self, hidden_states): @@ -486,24 +576,16 @@ def forward(self, hidden_states): class QWenBlock(nn.Module): - def __init__(self, config, layer_idx=None, num_expert=1): + def __init__(self, config): super().__init__() - self.num_expert = num_expert - self.layer_number = layer_idx - self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm - ) hidden_size = config.hidden_size - self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm - ) self.bf16 = config.bf16 self.ln_1 = RMSNorm( hidden_size, eps=config.layer_norm_epsilon, ) - self.attn = QWenAttention(config, layer_number=layer_idx) + self.attn = QWenAttention(config) self.ln_2 = RMSNorm( hidden_size, eps=config.layer_norm_epsilon, @@ -514,6 +596,7 @@ def __init__(self, config, layer_idx=None, num_expert=1): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -526,6 +609,7 @@ def forward( attn_outputs = self.attn( layernorm_output, + rotary_pos_emb_list, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, @@ -536,19 +620,12 @@ def forward( outputs = attn_outputs[1:] - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states + residual = hidden_states layernorm_input = attn_output + residual layernorm_output = self.ln_2(layernorm_input) - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - + residual = layernorm_input mlp_output = self.mlp(layernorm_output) hidden_states = residual + mlp_output @@ -566,6 +643,7 @@ class QWenPreTrainedModel(PreTrainedModel): is_parallelizable = False supports_gradient_checkpointing = True _no_split_modules = ["QWenBlock"] + _skip_keys_device_placement = "past_key_values" def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -589,7 +667,7 @@ def _init_weights(self, module): mean=0.0, std=( self.config.initializer_range - / math.sqrt(2 * self.config.n_layer) + / math.sqrt(2 * self.config.num_hidden_layers) ), ) @@ -603,31 +681,40 @@ class QWenModel(QWenPreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.padded_vocab_size + self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.embed_dim = config.hidden_size + self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False - max_sequence_length = config.max_position_embeddings - self.position_embedding_type = config.pos_emb self.gradient_checkpointing = False + self.use_dynamic_ntk = config.use_dynamic_ntk + self.seq_length = config.seq_length + + self.wte = nn.Embedding(self.vocab_size, self.embed_dim) - if self.position_embedding_type == "learned": - self.wpe = nn.Embedding(max_sequence_length, self.embed_dim) - self.init_method(self.position_embeddings.weight) - 
self._position_embeddings_key = "position_embeddings" - self.init_method(self.position_embeddings.weight) + self.drop = nn.Dropout(config.emb_dropout_prob) + + if config.rotary_pct == 1.0: + self.rotary_ndims = None else: - self.wpe = None - self._position_embeddings_key = "" + assert config.rotary_pct < 1 + self.rotary_ndims = int( + config.kv_channels * config.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else config.kv_channels + ) + self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) - self.wte = nn.Embedding(self.vocab_size, self.embed_dim) + self.use_flash_attn = config.use_flash_attn + self.is_fp32 = not (config.bf16 or config.fp16) - self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList( [ QWenBlock( - config, - layer_idx=i, + config ) for i in range(config.num_hidden_layers) ] @@ -645,6 +732,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings + def get_ntk_alpha(self, true_seq_len): + context_value = math.log(true_seq_len / self.seq_length, 2) + 1 + ntk_alpha = 2 ** math.ceil(context_value) - 1 + ntk_alpha = max(ntk_alpha, 1) + return ntk_alpha + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -701,8 +794,10 @@ def forward( past_length = 0 past_key_values = tuple([None] * len(self.h)) else: - past_length = past_key_values[0][0].size(-2) - + if self.use_cache_quantization: + past_length = past_key_values[0][0][0].size(2) + else: + past_length = past_key_values[0][0].size(-2) if position_ids is None: position_ids = torch.arange( past_length, @@ -719,17 +814,41 @@ def forward( attention_mask = attention_mask[:, None, None, :] attention_mask = attention_mask.to(dtype=self.dtype) attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - # attention_mask中mask掉的部分是-inf, 看到的部分是0 encoder_attention_mask = None - head_mask = self.get_head_mask(head_mask, self.config.n_layer) + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) hidden_states = inputs_embeds - if self.wpe is not None: - position_embeds = self.wpe(position_ids) - hidden_states = hidden_states + position_embeds + + kv_seq_len = hidden_states.size()[1] + if past_key_values[0] is not None: + # past key values[0][0] shape: bs * seq_len * head_num * dim + if self.use_cache_quantization: + kv_seq_len += past_key_values[0][0][0].shape[2] + else: + kv_seq_len += past_key_values[0][0].shape[1] + + if self.training or not self.use_dynamic_ntk: + ntk_alpha_list = [1.0] + elif kv_seq_len != hidden_states.size()[1]: + ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list + else: + ntk_alpha_list = [] + if attention_mask is not None and kv_seq_len > self.seq_length: + true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) + for i in range(hidden_states.size()[0]): + true_seq_len = true_seq_lens[i].item() + ntk_alpha = self.get_ntk_alpha(true_seq_len) + ntk_alpha_list.append(ntk_alpha) + else: + ntk_alpha = self.get_ntk_alpha(kv_seq_len) + ntk_alpha_list.append(ntk_alpha) + self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list + rotary_pos_emb_list = [ + self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list + ] hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) @@ -761,6 +880,7 @@ def custom_forward(*inputs): outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, 
+ rotary_pos_emb_list, None, attention_mask, head_mask[i], @@ -771,6 +891,7 @@ def custom_forward(*inputs): outputs = block( hidden_states, layer_past=layer_past, + rotary_pos_emb_list=rotary_pos_emb_list, attention_mask=attention_mask, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, @@ -781,13 +902,16 @@ def custom_forward(*inputs): hidden_states = outputs[0] if use_cache is True: - presents = presents + (outputs[2 if output_attentions else 1],) + presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[1],) + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( @@ -839,7 +963,7 @@ def __init__(self, config): logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") elif SUPPORT_FP16: logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") - + if config.use_flash_attn == "auto": if config.bf16 or config.fp16: logger.warn("Try importing flash-attention for faster inference...") @@ -853,7 +977,7 @@ def __init__(self, config): _import_flash_attn() self.transformer = QWenModel(config) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) if config.bf16: self.transformer.bfloat16() @@ -872,22 +996,13 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs ): - token_type_ids = kwargs.get("token_type_ids", None) if past_key_values: input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + if input_ids.size(0) == 1: + attention_mask = None else: - position_ids = None + attention_mask = kwargs.get("attention_mask", None) if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -898,9 +1013,7 @@ def prepare_inputs_for_generation( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, "attention_mask": attention_mask, - "token_type_ids": token_type_ids, } ) return model_inputs @@ -987,35 +1100,45 @@ def chat( query: str, history: Optional[HistoryType], system: str = "You are a helpful assistant.", - append_history: bool = True, stream: Optional[bool] = _SENTINEL, stop_words_ids: Optional[List[List[int]]] = None, + generation_config: Optional[GenerationConfig] = None, **kwargs, ) -> Tuple[str, HistoryType]: + generation_config = generation_config if generation_config is not None else self.generation_config + assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT - assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT if history is None: history = [] + 
else: + # make a copy of the user's input such that is is left untouched + history = copy.deepcopy(history) + if stop_words_ids is None: stop_words_ids = [] + max_window_size = kwargs.get('max_window_size', None) + if max_window_size is None: + max_window_size = generation_config.max_window_size raw_text, context_tokens = make_context( tokenizer, query, history=history, system=system, - max_window_size=6144, - chat_format=self.generation_config.chat_format, + max_window_size=max_window_size, + chat_format=generation_config.chat_format, ) stop_words_ids.extend(get_stop_words_ids( - self.generation_config.chat_format, tokenizer + generation_config.chat_format, tokenizer )) input_ids = torch.tensor([context_tokens]).to(self.device) outputs = self.generate( input_ids, - stop_words_ids = stop_words_ids, - return_dict_in_generate = False, + stop_words_ids=stop_words_ids, + return_dict_in_generate=False, + generation_config=generation_config, **kwargs, ) @@ -1024,13 +1147,16 @@ def chat( tokenizer, raw_text_len=len(raw_text), context_length=len(context_tokens), - chat_format=self.generation_config.chat_format, + chat_format=generation_config.chat_format, verbose=False, errors='replace' ) - if append_history: - history.append((query, response)) + # as history is a copy of the user inputs, + # we can always return the new turn to the user. + # separating input history and output history also enables the user + # to implement more complex history management + history.append((query, response)) return response, history @@ -1042,30 +1168,35 @@ def chat_stream( system: str = "You are a helpful assistant.", stop_words_ids: Optional[List[List[int]]] = None, logits_processor: Optional[LogitsProcessorList] = None, + generation_config: Optional[GenerationConfig] = None, **kwargs, ) -> Generator[str, Any, None]: - assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + generation_config = generation_config if generation_config is not None else self.generation_config + assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT if history is None: history = [] if stop_words_ids is None: stop_words_ids = [] + max_window_size = kwargs.get('max_window_size', None) + if max_window_size is None: + max_window_size = generation_config.max_window_size raw_text, context_tokens = make_context( tokenizer, query, history=history, system=system, - max_window_size=6144, - chat_format=self.generation_config.chat_format, + max_window_size=max_window_size, + chat_format=generation_config.chat_format, ) stop_words_ids.extend(get_stop_words_ids( - self.generation_config.chat_format, tokenizer + generation_config.chat_format, tokenizer )) if stop_words_ids is not None: stop_words_logits_processor = StopWordsLogitsProcessor( stop_words_ids=stop_words_ids, - eos_token_id=self.generation_config.eos_token_id, + eos_token_id=generation_config.eos_token_id, ) if logits_processor is None: logits_processor = LogitsProcessorList([stop_words_logits_processor]) @@ -1076,7 +1207,8 @@ def chat_stream( from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig self.__class__.generate_stream = NewGenerationMixin.generate self.__class__.sample_stream = NewGenerationMixin.sample_stream - stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True) + stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) + def stream_generator(): outputs = [] for token in self.generate_stream( @@ -1105,17 +1237,19 @@ def generate( 
streamer: Optional["BaseStreamer"] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: + generation_config = generation_config if generation_config is not None else self.generation_config + # Process stop_words_ids. stop_words_ids = kwargs.pop("stop_words_ids", None) if stop_words_ids is None and generation_config is not None: stop_words_ids = getattr(generation_config, "stop_words_ids", None) if stop_words_ids is None: - stop_words_ids = getattr(self.generation_config, "stop_words_ids", None) + stop_words_ids = getattr(generation_config, "stop_words_ids", None) if stop_words_ids is not None: stop_words_logits_processor = StopWordsLogitsProcessor( stop_words_ids=stop_words_ids, - eos_token_id=self.generation_config.eos_token_id, + eos_token_id=generation_config.eos_token_id, ) if logits_processor is None: logits_processor = LogitsProcessorList([stop_words_logits_processor]) @@ -1140,16 +1274,17 @@ def __init__(self, dim, base=10000): super().__init__() self.dim = dim self.base = base - self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) if importlib.util.find_spec("einops") is None: raise RuntimeError("einops is required for Rotary Embedding") self._rotary_pos_emb_cache = None self._seq_len_cached = 0 self._ntk_alpha_cached = 1.0 + self._ntk_alpha_cached_list = [1.0] - def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): - seqlen = max_seq_len + offset + def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0): if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) self.inv_freq = 1.0 / ( @@ -1163,14 +1298,19 @@ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): self._ntk_alpha_cached = ntk_alpha seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) from einops import rearrange - self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d") + emb = rearrange(emb, "n d -> 1 n 1 d") - def forward(self, max_seq_len, offset=0, ntk_alpha=1.0): - self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha) - return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len] + cos, sin = emb.cos(), emb.sin() + self._rotary_pos_emb_cache = [cos, sin] + + def forward(self, max_seq_len, ntk_alpha=1.0): + self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha) + cos, sin = self._rotary_pos_emb_cache + return [cos[:, :max_seq_len], sin[:, :max_seq_len]] def _rotate_half(x): @@ -1182,20 +1322,28 @@ def _rotate_half(x): def apply_rotary_pos_emb(t, freqs): - if apply_rotary_emb_func is not None: - t_ = t.float() - freqs = freqs.squeeze(0).squeeze(1) - cos = freqs[:, : freqs.shape[-1] // 2].cos() - sin = freqs[:, : freqs.shape[-1] // 2].sin() - output = apply_rotary_emb_func(t_, cos, sin).type_as(t) - return output + """ Apply rotary embedding to the first rotary_dim of the iput + + Arguments: + t (tensor(batch_size, seq_len, n_head, head_dim)): + the input embedding/hidden states + freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]): + the cached cos/sin position embeddings + """ + rot_dim = freqs[0].shape[-1] + cos, sin = freqs + t_float = t.float() + if apply_rotary_emb_func is not None and t.is_cuda: + # apply_rotary_emb in flash_attn requires cos/sin to 
be of + # shape (seqlen, rotary_dim / 2) and apply rotary embedding + # to the first rotary_dim of the input + cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2] + sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2] + return apply_rotary_emb_func(t_float, cos, sin).type_as(t) else: - rot_dim = freqs.shape[-1] - t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] - t_ = t_.float() - t_pass_ = t_pass_.float() - t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin()) - return torch.cat((t_, t_pass_), dim=-1).type_as(t) + t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:] + t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin) + return torch.cat((t_rot, t_pass), dim=-1).type_as(t) class RMSNorm(torch.nn.Module): diff --git a/mft_peft_hf/src/model/qwen/qwen_generation_utils.py b/mftcoder_accelerate/src/model/qwen/qwen_generation_utils.py similarity index 100% rename from mft_peft_hf/src/model/qwen/qwen_generation_utils.py rename to mftcoder_accelerate/src/model/qwen/qwen_generation_utils.py diff --git a/mft_peft_hf/src/model/qwen/tokenization_qwen.py b/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py similarity index 76% rename from mft_peft_hf/src/model/qwen/tokenization_qwen.py rename to mftcoder_accelerate/src/model/qwen/tokenization_qwen.py index 4a66a09..2a526d6 100644 --- a/mft_peft_hf/src/model/qwen/tokenization_qwen.py +++ b/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py @@ -27,11 +27,22 @@ # regular texts, the surface forms of special tokens need to be # as different as possible to minimize the impact EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) -SPECIAL_TOKENS = ( - ENDOFTEXT, - IMSTART, - IMEND, -) + EXTRAS +# changed to use actual index to avoid misconfiguration with vocabulary expansion +SPECIAL_START_ID = 151643 +SPECIAL_TOKENS = tuple( + enumerate( + ( + ( + ENDOFTEXT, + IMSTART, + IMEND, + ) + + EXTRAS + ), + start=SPECIAL_START_ID, + ) +) +SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: @@ -42,6 +53,7 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: for token, rank in (line.split() for line in contents.splitlines() if line) } + class QWenTokenizer(PreTrainedTokenizer): """QWen tokenizer.""" @@ -51,20 +63,35 @@ def __init__( self, vocab_file, errors="replace", + extra_vocab_file=None, **kwargs, ): super().__init__(**kwargs) - self.errors = errors # how to handle errors in decoding + # how to handle errors in decoding UTF-8 byte sequences + # use ignore if you are in streaming inference + self.errors = errors - self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int] + self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int] self.special_tokens = { token: index - for index, token in enumerate( - SPECIAL_TOKENS, start=len(self.mergeable_ranks) - ) + for index, token in SPECIAL_TOKENS } + # try load extra vocab from file + if extra_vocab_file is not None: + used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values()) + extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file) + for token, index in extra_mergeable_ranks.items(): + if token in self.mergeable_ranks: + logger.info(f"extra token {token} exists, skipping") + continue + if index in used_ids: + logger.info(f'the index {index} for extra token {token} exists, skipping') + continue + self.mergeable_ranks[token] = index + # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this + enc = 
tiktoken.Encoding( "Qwen", pat_str=PAT_STR, @@ -86,6 +113,23 @@ def __init__( self.im_start_id = self.special_tokens[IMSTART] self.im_end_id = self.special_tokens[IMEND] + def __getstate__(self): + # for pickle lovers + state = self.__dict__.copy() + del state["tokenizer"] + return state + + def __setstate__(self, state): + # tokenizer is not python native; don't pass it; rebuild it + self.__dict__.update(state) + enc = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + self.tokenizer = enc + def __len__(self) -> int: return self.tokenizer.n_vocab @@ -108,13 +152,17 @@ def convert_tokens_to_ids( ids.append(self.mergeable_ranks.get(token)) return ids - def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: + def _add_tokens( + self, + new_tokens: Union[List[str], List[AddedToken]], + special_tokens: bool = False, + ) -> int: if not special_tokens and new_tokens: - raise ValueError('Adding regular tokens is not supported') + raise ValueError("Adding regular tokens is not supported") for token in new_tokens: surface_form = token.content if isinstance(token, AddedToken) else token - if surface_form not in SPECIAL_TOKENS: - raise ValueError('Adding unknown special tokens is not supported') + if surface_form not in SPECIAL_TOKENS_SET: + raise ValueError("Adding unknown special tokens is not supported") return 0 def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: diff --git a/mftcoder_accelerate/src/model/qwen/tokenizer_config.json b/mftcoder_accelerate/src/model/qwen/tokenizer_config.json new file mode 100644 index 0000000..9c37cac --- /dev/null +++ b/mftcoder_accelerate/src/model/qwen/tokenizer_config.json @@ -0,0 +1,10 @@ +{ + "model_max_length": 8192, + "tokenizer_class": "QWenTokenizer", + "auto_map": { + "AutoTokenizer": [ + "tokenization_qwen.QWenTokenizer", + null + ] + } +} diff --git a/mftcoder_accelerate/src/mpt/mpt_accelerate.py b/mftcoder_accelerate/src/mpt/mpt_accelerate.py new file mode 100644 index 0000000..5d187c9 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_accelerate.py @@ -0,0 +1,494 @@ +""" +# @author Chaoyu Chen +# @date 2024/6/1 +# @module mpt_accelerate.py + +Accelerate + DeepSpeed + Full-parameter + Multi-task + Pre-training/Continue Training/Finetuning + +Entry +""" + +import os +import sys +import argparse +import math +import logging +import json +import time +from tqdm.auto import tqdm +import transformers +import numpy as np +import torch +from torch import nn +from dataclasses import dataclass +from datasets import Dataset +import datasets +from torch.utils.data import DataLoader +from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + get_linear_schedule_with_warmup, + set_seed, + BitsAndBytesConfig, + get_scheduler, +) + +from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration +from accelerate.logging import get_logger +from datetime import timedelta +from accelerate.utils import InitProcessGroupKwargs +from transformers.optimization import Adafactor + +# insert src as import path +current_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_path)) +sys.path.insert(0, parent_dir) + +from tokenizer import build_tokenizer +from data.multi_task_dataset import load_dataset_from_jsonl, compile_helper 
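# [Illustrative sketch, not part of the patch] The refactored apply_rotary_pos_emb above now
# receives the cached rotary embeddings as a [cos, sin] pair of shape (1, seq_len, 1, rotary_dim)
# and rotates only the first rotary_dim channels of the input. The standalone snippet below only
# demonstrates that fallback (non flash-attn) path and its expected shapes; the function names,
# toy sizes, and default base are illustrative, not the patch's API.
import torch

def rotate_half(x):
    # split the rotary channels in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_cache(seq_len, rotary_dim, base=10000.0):
    # inverse frequencies over even channel indices, as in RotaryEmbedding
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)          # (seq, rotary_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)[None, :, None, :]             # (1, seq, 1, rotary_dim)
    return emb.cos(), emb.sin()

def apply_rotary(t, cos, sin):
    # t: (batch, seq, heads, head_dim); only the first rotary_dim channels are rotated
    rot_dim = cos.shape[-1]
    t_rot, t_pass = t[..., :rot_dim].float(), t[..., rot_dim:].float()
    t_rot = t_rot * cos + rotate_half(t_rot) * sin
    return torch.cat((t_rot, t_pass), dim=-1).type_as(t)

if __name__ == "__main__":
    q = torch.randn(2, 16, 8, 128)                     # (batch, seq, heads, head_dim)
    cos, sin = rotary_cache(seq_len=16, rotary_dim=128)
    assert apply_rotary(q, cos, sin).shape == q.shape  # shape is preserved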
+from data.data_utils import load_dataset_from_bin +from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK +from mpt.mpt_trainer import MptTrainer +from mpt.mpt_arguments import MptTrainArgs +from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS + + +logger = get_logger(__name__) + + +def get_task_mask(args, task_id): + task_num = len(TASK2ID) + task_mask = torch.zeros(task_id.shape[0], task_num) + task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1 + + return task_mask + + +def get_attention_mask_and_position_ids(data): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + batch_size, seq_length = data.size() + + attention_mask = torch.ones((batch_size, seq_length), device=data.device) + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data).clone() + + return attention_mask, position_ids + + +@dataclass +class DataCollatorForMFTDataset(object): + args: None + + def __call__(self, instances): + (input_ids, loss_mask, weights, task_id) = tuple( + [instance.get(key, None) for instance in instances] + for key in ("input_ids", "loss_mask", "weight", "task_id") + ) + + result_batch = {} + """ + outputs = model( + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']), + # labels=(batch['labels'], batch['loss_mask']), + position_ids=batch['position_ids']) + """ + + # if loss_mask is not None: + loss_mask = torch.tensor(np.array(loss_mask)).long() + last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1) + if self.args.use_dynamic_padding: + # get last non-padding position + max_pos = last_one_pos.max().item() + 1 + else: + max_pos = loss_mask.shape[-1] + + if self.args.tokenize_mode == "sst" and self.args.padding_mode == "pack": + # 兼容sst + pack tokenization, 最后一位是脏数据,需要去掉 + result_batch["loss_mask"] = loss_mask.float()[:, 1 : max_pos - 1].contiguous() + input_ids = torch.tensor(np.array(input_ids)).long() + result_batch["input_ids"] = input_ids[:, : max_pos - 2].contiguous() + result_batch["labels"] = input_ids[:, 1 : max_pos - 1].contiguous() + else: + result_batch["loss_mask"] = loss_mask.float()[:, 1:max_pos].contiguous() + input_ids = torch.tensor(np.array(input_ids)).long() + # print(f"shape of input_ids: {input_ids.shape}") + result_batch["input_ids"] = input_ids[:, : max_pos - 1].contiguous() + result_batch["labels"] = input_ids[:, 1:max_pos].contiguous() + + # Get the masks and position ids. + + # if you want to be compatible with non-gpt models, something you can do here + if self.args.model_type in ["antglm"]: + (result_batch["attention_mask"], result_batch["position_ids"]) = get_attention_mask_and_position_ids( + data=result_batch["input_ids"] + ) + elif self.args.model_type in ["mixtral", "mtx-qwen2", "qwen2_moe"]: + batch_size, seq_length = result_batch["input_ids"].shape + # bsz * seq_length + range_tensor = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1) + # attention_mask for padding tokens + attention_mask = (range_tensor <= last_one_pos.reshape(batch_size, 1)).long() + result_batch["attention_mask"], result_batch["position_ids"] = attention_mask, None + else: + # For decoder-only models, transformers will create them. 
+ result_batch["attention_mask"], result_batch["position_ids"] = None, None + + if task_id is not None: + task_id = torch.tensor(np.array(task_id)) + result_batch["task_mask"] = get_task_mask(self.args, task_id) # bsz * task_num + result_batch["task_id"] = task_id + + return result_batch + + +def pprint_args(args, accelerator): + # 计算所有键的最大字符串长度 + max_key_length = max(len(str(key)) for key in vars(args).keys()) + + message = "" + message += "====" * 60 + "\n" + message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n" + message += "====" * 60 + "\n" + accelerator.print(message) + accelerator.print("GPU: {}".format(torch.cuda.current_device())) + + +def prepare_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_config", type=str, default=None) + + parser.add_argument("--data_paths", type=str, default=None) + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument("--tb_dir", type=str, default=None) + parser.add_argument("--pretrained_model_path", type=str, default=None) + parser.add_argument("--micro_batch_size", type=int, default=None) + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--distributed_type", type=str, default="deepspeed") + + parsed = parser.parse_args() + # get json configs + with open(parsed.train_config, "r") as f: + train_config = json.load(f) + + # parse args from cofig.json + # args = argparse.Namespace(**train_config) + args = MptTrainArgs(**train_config) + + # override args by cli arguments + if parsed.data_paths: + args.data_paths = parsed.data_paths + if parsed.output_dir: + args.output_dir = parsed.output_dir + if parsed.tb_dir: + args.tb_dir = parsed.tb_dir + if parsed.pretrained_model_path: + args.pretrained_model_path = parsed.pretrained_model_path + args.vocab_file = parsed.pretrained_model_path + if parsed.micro_batch_size: + args.per_device_train_batch_size = parsed.micro_batch_size + args.per_device_eval_batch_size = parsed.micro_batch_size + if parsed.model_type: + args.model_type = parsed.model_type + + args.distributed_type = parsed.distributed_type + + # refactor args + + args.vocab_file = args.pretrained_model_path + + args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(","))) + "]" + + # generate TASK2ID, ID2TASK + generate_task_id(args.data_paths) + + if args.weighted_loss_mode == "coba": + args.task_weights = [1.0] * len(ID2TASK) + elif args.task_weights is not None: + args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")] + assert len(args.task_weights) == len(ID2TASK), f"length of task_weights must equal to length of data_paths" + else: + args.task_weights = [1.0] * len(ID2TASK) + + return args + + +def main(): + t0 = time.time() + os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["HF_HUB_OFFLINE"] = "false" + # get input args, set TASK2ID, ID2TASK, refactor args + args = prepare_args() + + # fix randomness + if args.seed is not None: + set_seed(args.seed) + + # define accelerator + init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds)) + + if args.distributed_type and args.distributed_type.lower() == "fsdp": + fsdp_plugin = FullyShardedDataParallelPlugin( + # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + limit_all_gathers=True, + sync_module_states=True, + use_orig_params=True, + cpu_offload=False, + ) + 
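# [Illustrative sketch, not part of the patch] A toy walk-through of the dynamic-padding /
# label-shift logic in DataCollatorForMFTDataset above (sft tokenize_mode with
# use_dynamic_padding=True): the batch is truncated at the last position whose loss_mask is 1,
# labels are the inputs shifted left by one, and loss_mask is shifted the same way so that only
# completion tokens contribute to the loss. The token values below are invented for demonstration.
import torch

input_ids = torch.tensor([[11, 12, 13, 14, 0, 0],
                          [21, 22, 23, 24, 25, 0]])
loss_mask = torch.tensor([[0, 1, 1, 0, 0, 0],
                          [0, 0, 1, 1, 1, 0]])

last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1)  # last 1 per row -> [2, 4]
max_pos = last_one_pos.max().item() + 1                             # 5: last useful position + 1

batch = {
    "input_ids": input_ids[:, : max_pos - 1],     # positions 0..3
    "labels":    input_ids[:, 1:max_pos],         # positions 1..4, shifted by one
    "loss_mask": loss_mask.float()[:, 1:max_pos], # aligned with labels
}
assert batch["input_ids"].shape == batch["labels"].shape == (2, 4)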
accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + fsdp_plugin=fsdp_plugin, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + else: + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + + # print key infos + accelerator.print("In mft_accelerate.py, sys path:", sys.path) + accelerator.print(f"transformers.__version__: {transformers.__version__}") + + # get world_size + args.world_size = accelerator.num_processes + + # backup args + pprint_args(args, accelerator) + if accelerator.is_main_process: + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "args.json"), "w") as f: + json.dump(args.dict(), f, indent=2) + + # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest + latest = None + if os.path.exists(os.path.join(args.output_dir, "latest")): + with open(os.path.join(args.output_dir, "latest"), "r") as fl: + latest = json.load(fl) + accelerator.print(f"[INFO] Existing latest: {latest}") + + if args.auto_resume and args.resume_from_checkpoint is None and latest: + args.resume_from_checkpoint = latest["latest_ckpt"] + + # logger + logging.basicConfig( + format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + # compile Cpp helper + compile_helper() + time.sleep(10) + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # get global_rank and local rank for current process + global_rank = accelerator.process_index + local_rank = accelerator.local_process_index + print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}") + + # TASK2ID, ID2TASK + # generate_task_id(args.data_paths) + + # multi task blendable dataset(sharded) + if args.load_raw_dataset: + print_rank_0("> load raw jsonl dataset") + train_dataset, valid_dataset = load_dataset_from_jsonl( + args=args, shard_data=True, world_size=args.world_size, global_rank=global_rank, local_rank=local_rank + ) + else: + print_rank_0("> load tokenized bin dataset, refer to gpt_neox indexed dataset") + train_dataset, valid_dataset, _ = load_dataset_from_bin(args=args) + + t1 = time.time() + logger.info(f"dataset loading time: {t1 - t0:.4f}") + + # cuda memory + free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3) + max_memory = f"{free_in_GB - 2}GB" + n_gpus = torch.cuda.device_count() + max_memory = {i: max_memory for i in range(n_gpus)} + accelerator.print("max memory: ", max_memory, n_gpus) + + # # 是否要加入新的special tokens + # num_added_toks = tokenizer.tokenizer.add_special_tokens([" ", " "]) + # accelerator.print("We have added", num_added_toks, "tokens") + # accelerator.print(f"role marker tokens {tokenizer.convert_tokens_to_ids(' ')} {tokenizer.convert_tokens_to_ids(' ')}, resized tokenizer_size: {len(tokenizer)}") + + # creating model + ModelClass = MODEL_TYPES[args.model_type] + if args.model_type in SUPPORT_IN_TRANSFORMERS: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers") + model = 
ModelClass.from_pretrained( + args.pretrained_model_path, + attn_implementation=args.attn_implementation, + torch_dtype=torch.bfloat16, + ) + else: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code") + model = ModelClass.from_pretrained( + args.pretrained_model_path, + torch_dtype=torch.bfloat16, + ) + + # build a tokenizer for possible resizing or saving + tokenizer = build_tokenizer(args) + # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, + # i.e. the length of the tokenizer. + # 如果新增special tokens, 需要resize input embedding 和output embedding + # model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) + + model.gradient_checkpointing_enable() + + if args.saving_limit is None or not isinstance(args.saving_limit, int) or args.saving_limit < 1: + # saving_limit is set automatically if needed + args.saving_limit = 2 + accelerator.print( + "[WARNING]saving_limit must be a integer greater than 1 in Full-Parameters Training, we set it to 2" + ) + + t2 = time.time() + if accelerator.is_main_process: + logging.info(f"model loading time: {t2 - t1:.4f}") + + model.config.use_cache = False # silence the warnings. Please re-enable for inference! + if hasattr(model.config, "use_logn_attn"): + model.config.use_logn_attn = False # special for qwen model + # load balance for moe training + if hasattr(model.config, "output_router_logits"): + model.config.output_router_logits = True + model_config = model.config + accelerator.print(model.config) + + # dataloader + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_train_batch_size, + pin_memory=True, + drop_last=True, + ) + if valid_dataset: + valid_dataloader = DataLoader( + valid_dataset, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_eval_batch_size, + pin_memory=True, + drop_last=True, + ) + else: + valid_dataloader = None + + # optimizer + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.print("DISTRIBUTED TRAINING USING DEEPSPEED") + # from deepspeed.ops.adam import FusedAdam as Adam + # adam_optimizer = Adam + adam_optimizer = torch.optim.AdamW + elif accelerator.distributed_type == DistributedType.FSDP: + accelerator.print("DISTRIBUTED TRAINING USING FSDP") + model = accelerator.prepare(model) + adam_optimizer = torch.optim.AdamW + else: + raise ValueError("Only support DeepSpeed and FSDP") + + optimizer = adam_optimizer( + model.parameters(), + weight_decay=args.weight_decay, + lr=args.learning_rate, + betas=(0.9, 0.999), + ) + # for group in optimizer.param_groups: + # group.setdefault("initial_lr", group["lr"]) + + # Scheduler and math around the number of training steps. 
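# [Illustrative arithmetic, not part of the patch] A worked example of the scheduler/step math
# set up in the lines that follow, under invented toy values: 10_000 micro-batches per epoch per
# process, gradient_accumulation_steps=4, num_train_epochs=4, 8 processes, and num_warmup_steps
# given as the fractional default 0.05.
import math

len_train_dataloader = 10_000
gradient_accumulation_steps = 4
num_train_epochs = 4
num_processes = 8
num_warmup_steps = 0.05

num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 2500
max_train_steps = num_train_epochs * num_update_steps_per_epoch                             # 10000
if isinstance(num_warmup_steps, float) and num_warmup_steps < 1.0:
    num_warmup_steps = int(max_train_steps * num_warmup_steps) // num_processes             # 500 // 8 = 62
# get_scheduler is then handed the warmup steps scaled back up by
# gradient_accumulation_steps * num_processes (62 * 4 * 8 = 1984) and the training steps scaled
# by gradient_accumulation_steps (10000 * 4 = 40000), matching the code that follows.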
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + if isinstance(args.num_warmup_steps, float) and args.num_warmup_steps < 1.0: + args.num_warmup_steps = int(args.max_train_steps * args.num_warmup_steps) // accelerator.num_processes + accelerator.print(f"num_warmup_steps: {args.num_warmup_steps}") + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + # scheduler_specific_kwargs={"last_epoch": scheduler_last_ep} + ) + # prepare all + if accelerator.distributed_type == DistributedType.DEEPSPEED: + if valid_dataloader: + (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, valid_dataloader, optimizer, lr_scheduler + ) + else: + (model, train_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, optimizer, lr_scheduler + ) + + # prepare all except model, which is prepared before + elif accelerator.distributed_type == DistributedType.FSDP: + if valid_dataloader: + (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, valid_dataloader, lr_scheduler + ) + else: + (optimizer, train_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, lr_scheduler + ) + print(model.device) + accelerator.print(model) + # accelerator.print(model.config) + + # Recalculate our total training steps as the size of the training dataloader may have changed. 
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterward we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # zero 3 flag + is_ds_zero_3 = False + if getattr(accelerator.state, "deepspeed_plugin", None): + is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3 + accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}") + elif getattr(accelerator.state, "fsdp_plugin", None): + accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}") + + trainer = MptTrainer( + accelerator=accelerator, + model=model, + model_config=model_config, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + tokenizer=tokenizer, + num_update_steps_per_epoch=num_update_steps_per_epoch, + total_train_dataset_size=len(train_dataset), + args=args, + ) + trainer.accelerate_train() + + +if __name__ == "__main__": + main() diff --git a/mftcoder_accelerate/src/mpt/mpt_arguments.py b/mftcoder_accelerate/src/mpt/mpt_arguments.py new file mode 100644 index 0000000..8045421 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_arguments.py @@ -0,0 +1,161 @@ +""" +# @author Chaoyu Chen +# @date 2024/6/1 + +MPT training arguments +""" + +from dataclasses import dataclass, asdict +from typing import List, Union + + +@dataclass +class MptTrainArgs: + # train data paths on shared FS + data_paths: Union[str, List[str]] + + # output dir for saving adaptors in peft or full ckpts in full-parameter training + output_dir: str + + # tensorboard dir for saving tensorboard logs + tb_dir: str + + # pretrained_model_path, on which is the model you want to train + pretrained_model_path: str + + # model type of pretrained_model_path, support llama|qwen|starcoder|baichuan|chatglm2 + model_type: str + + # load from raw jsonl file or tokenized binary file + load_raw_dataset: bool = True + + # weights of loss calculation for each task, None means equal weights + task_weights: Union[None, str] = None + + # weights of data sampling, leave it None + data_weights: Union[None, str] = None + + # hf loading model low_cpu_mem_usage + low_cpu_mem_usage: bool = True + + # train/valid/test split + data_split: str = "98,2,0" + + # padding or pack or concat + padding_mode: str = "padding" + + # sft or sst + tokenize_mode: str = "sft" + + # case3 or case4 + weighted_loss_mode: str = "case3" + + # mircro train batch size + per_device_train_batch_size: int = 8 + + # micro eval batch size, always same as micro train batch size + per_device_eval_batch_size: int = 8 + + # HF AutoTokenizer is supported, maybe more types + tokenizer_type: str = "AutoTokenizer" + + # initial lr + learning_rate: float = 5e-5 + + # minimum lr + min_lr: float = 5e-6 + + # weight decay + weight_decay: float = 0.01 + + # gradient_accumulation_steps + gradient_accumulation_steps: int = 1 + + # lr_scheduler_type + lr_scheduler_type: str = "cosine" + + # num_warmup_steps + num_warmup_steps: Union[int, float] = 0.05 + + # num_train_epochs + num_train_epochs: int = 4 + + # seed for reproducing + seed: int = 1234 + + # seq_length, context length + seq_length: int = 4096 + + # path of adaptor which is resumed from, None for not resuming training + resume_from_checkpoint: Union[None, str] = None + + # auto resume from latest ckpt if job restarted + 
auto_resume: bool = True + + # num of steps for logging training loss + log_interval: int = 10 + + # num of steps for saving ckpt + checkpointing_steps: int = 100 + + # num of steps for evaluation(eval_loss), better same as checkpointing steps + evaluation_steps: int = 100 + + # max train steps, if None, depends on num_train_epochs + max_train_steps: Union[None, int] = None + + # if checkpointing every epoch, maybe True in sst + epoch_checkpointing: bool = False + + # save transformers model(safetensors) + save_transformers_model: bool = False + + # shuffle before train/valid split + shuffle_before_split: bool = True + + # DDP random sampler + use_random_sampler: bool = True + + # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point + early_stopping: bool = True + early_stopping_stall_num: int = 5 + + # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota. + saving_limit: Union[None, int] = None + + # if dynamic padding + use_dynamic_padding: bool = True + + # warm-up steps for CoBa, recommand the number of valid batches + coba_warmup_steps: int = 100 + # history length of sample valid loss used to fit the slope curve in CoBa, recommand [2*coba_warmup_steps,5*coba_warmup_steps] + coba_history_length: int = 200 + # temperature for divergence factor in CoBa + coba_tau: int = 5 + # iteration interval of update per task train weight in CoBa + coba_update_interval: int = 1 + # the number of mini valid batches sampled at each updated iteration interval + coba_sample_valid_num: int = 1 + + # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2} + attn_implementation: str = "flash_attention_2" + + # role markers, which are prompt template before each role: system, user and assistant + # role_markers: {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} + role_markers: Union[None, dict] = None + + distributed_type: Union[None, str] = None + + init_timeout_seconds: Union[None, int] = 3600 + + # legacy, leave them + use_xformers: bool = True + trust_remote_code: bool = True + weight_by_num_documents: bool = True + make_vocab_size_divisible_by: int = 32 + model_parallel_size: int = 1 + use_slow_tokenizer: bool = False + world_size: int = 8 + + def dict(self): + return {k: str(v) for k, v in asdict(self).items()} diff --git a/mftcoder_accelerate/src/mpt/mpt_trainer.py b/mftcoder_accelerate/src/mpt/mpt_trainer.py new file mode 100644 index 0000000..b5e2da8 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_trainer.py @@ -0,0 +1,606 @@ +""" +# @author qumu +# @date 2024/6/6 +# @module mpt_trainer.py + +MPT/MCT/MFT Full-parameter Trainer +""" + +import gc +import os +import sys +import threading +import argparse +import math +import logging +import json +import time +import transformers +import numpy as np +import psutil +import shutil +import torch +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from typing import List, Optional, Tuple, Union +from tqdm.auto import tqdm +from accelerate.logging import get_logger +from accelerate import Accelerator +from transformers import set_seed + +# sys.path.append("..") +from utils.common_utils import generate_task_id, TASK2ID, ID2TASK +from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func + +logger = get_logger(__name__) + + +def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): + # create path if not exist + if 
not os.path.exists(save_path): + os.makedirs(save_path) + + # copy each file in files_list to save_path + for filename in files_list: + src_file = os.path.join(mode_path, filename) + + # copy only if src exists + if os.path.exists(src_file): + dest_file = os.path.join(save_path, filename) + + # copy + shutil.copy(src_file, dest_file) + print(f"Copied {filename} to {save_path}") + else: + print(f"File {filename} does not exist in {mode_path}") + + +def check_existing_ckpts(output_dir): + prefix = "step_" + + if not os.path.exists(output_dir): + return [] + # list all files and dirs + contents = os.listdir(output_dir) + + # find dirs starts with "step_" + matching_folders = [ + folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix) + ] + + return matching_folders + + +def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps): + """ + extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training + """ + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + logger.info(f"Resume from exact Epoch {starting_epoch}: completed_steps {completed_steps}") + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + completed_steps = int(training_difference.replace("step_", "")) + starting_epoch = completed_steps // num_update_steps_per_epoch + resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps + logger.info(f"Resume from Epoch {starting_epoch} + step {resume_step}: completed_steps {completed_steps}") + + return starting_epoch, completed_steps, resume_step + + +def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): + for key, value in log_dict.items(): + summary_writer.add_scalar(f"{key}", value, completed_steps) + + +def delete_ckpts_over_limits(output_dir, saving_limit, best_step): + """delete ckpts more than saving_limits except for the best_step ckpt""" + existing_ckpts = check_existing_ckpts(output_dir) + logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}") + # sorted only step num ascendingly + ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts]) + # delete the oldest steps except for the best step at present + if len(ckpt_steps) > saving_limit: + deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step] + # print(deletable_steps[:len(ckpt_steps) - saving_limit]) + for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]: + shutil.rmtree(os.path.join(output_dir, f"step_{del_step}")) + logger.info(f"Removed ckpt step_{del_step}") + + +class MptTrainer: + """ + Multitask Pre-train/Continue-train Trainer with Full-parameters training. 
+ """ + + def __init__( + self, + accelerator: Accelerator, + model, + model_config, + train_dataloader, + valid_dataloader, + optimizer, + lr_scheduler, + tokenizer, + num_update_steps_per_epoch, + total_train_dataset_size, + args, + ): + self.accelerator = accelerator + self.model = model + # hf model config + self.model_config = model_config + self.train_dataloader = train_dataloader + self.valid_dataloader = valid_dataloader + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.tokenizer = tokenizer + self.num_update_steps_per_epoch = num_update_steps_per_epoch + self.total_train_dataset_size = total_train_dataset_size + # training arguments + self.args = args + # tensorboard writer + self.summary_writer = SummaryWriter(log_dir=args.tb_dir) + + def print(self, msg: str): + """ + accelerator print, default on main process + Args: + msg: + + Returns: + + """ + self.accelerator.print(msg) + + def touch(self, batch, num_tokens=10): + """touch first and last tokens and labels for debugging usage""" + self.print( + f"step 1 batch shape: {batch['input_ids'].shape},\n" + f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}" + f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}" + ) + self.print(f"first {num_tokens} input_ids and loss_mask") + for pt in range(1): + self.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") + self.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") + + @staticmethod + def format_tensor(tensor, n): + return list(map(lambda x: round(x, n), tensor.tolist())) + + def accelerate_saving_states(self, output_dir: str, completed_steps: int): + """ + Saving lora adaptor or full checkpoint using accelerator + Args: + output_dir: exact dir for saving ckpt + completed_steps: + + Returns: + + """ + self.accelerator.wait_for_everyone() + logger.info(f"[CHECKPOINT] Saving checkpoint states") + self.accelerator.save_state(output_dir) + self.accelerator.wait_for_everyone() + + # save safetensors for direct inference if needed + if self.args.save_transformers_model: + logger.info(f"[CHECKPOINT] Saving transformers(hf) model", main_process_only=True) + unwrapped_model = self.accelerator.unwrap_model(self.model) + # self.print(f"unwrapped model type {type(unwrapped_model)}") + unwrapped_model.save_pretrained( + output_dir, + is_main_process=self.accelerator.is_main_process, + save_function=self.accelerator.save, + state_dict=self.accelerator.get_state_dict(self.model), + ) + self.accelerator.wait_for_everyone() + + # tokenizer saving and bug dummy ckpt cleaning. 
+ if self.accelerator.is_main_process: + if self.args.model_type.lower() == "deepseek": + copy_tokenizer_files( + self.args.pretrained_model_path, ["tokenizer.json", "tokenizer_config.json"], output_dir + ) + else: + self.tokenizer.save_pretrained(output_dir) + + sf = os.path.join(output_dir, "model.safetensors") + index_file = os.path.join(output_dir, "model.safetensors.index.json") + if os.path.isfile(sf) and os.path.isfile(index_file): + self.print(f"Remove bug dummy ckpt {sf}") + os.remove(sf) + + # save latest info + if self.accelerator.is_main_process: + latest = { + "latest_ckpt": output_dir, + "lr": self.optimizer.param_groups[0]["lr"], + } + with open(os.path.join(self.args.output_dir, "latest"), "w") as f: + json.dump(latest, f, indent=2) + + logger.info( + f"[CHECKPOINT][complete_steps={completed_steps}], states {output_dir} saved, latest: {latest}", + main_process_only=True, + ) + self.accelerator.wait_for_everyone() + + def accelerate_monitor( + self, + reduce_loss, + reduce_task_loss, + reduce_task_exist, + completed_steps, + coba_status=None, + ): + """ + gather reduce_loss and reduce_task_loss from all N devices. + train logging and tensorboarding. + """ + # gather reduce_loss and reduce_task_loss from all N devices + reduce_losses = self.accelerator.gather(reduce_loss).detach().float() + reduce_task_losses = self.accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK)) + reduce_task_exists = self.accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK)) + + # get train loss and per-task train loss + train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps) + # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (self.args.log_interval * self.args.gradient_accumulation_steps) + train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0) + + # logging and writing tensorboard + logger.info( + f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]" + f"[train_task_loss={self.format_tensor(train_task_loss, 4)}]" + f"[gather shape={list(reduce_losses.shape)}]" + f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]", + main_process_only=True, + ) + if coba_status is not None: + if completed_steps > coba_status.coba_warmup_steps: + coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum( + coba_status.log_per_task_weight + ) + else: + coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) + logger.info( + f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True + ) + train_log_dict = {"Loss/train": train_loss} + for i in range(len(ID2TASK)): + train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i] + if coba_status is not None: + train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item() + + if self.accelerator.is_main_process: + write_tensorboard(self.summary_writer, train_log_dict, completed_steps) + + if coba_status is not None: + coba_status.log_per_task_weight = torch.zeros(len(ID2TASK)) + + def accelerate_evaluate( + self, + completed_steps, + step, + min_eval_loss, + stall_num, + best_step, + ): + """ + evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices. + eval logging and tensorboarding. 
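+ Returns: (eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step)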
+ """ + losses = [] + accumulated_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + accumulated_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + for valid_step, valid_batch in enumerate(self.valid_dataloader): + with torch.no_grad(): + outputs = self.model( + input_ids=valid_batch["input_ids"], + attention_mask=valid_batch["attention_mask"], + position_ids=valid_batch["position_ids"], + return_dict=True, + ) + + loss, task_loss, _ = loss_func_mft( + outputs=outputs, + labels=valid_batch["labels"], + task_mask=valid_batch["task_mask"], + task_id=valid_batch["task_id"], + weighted_loss_mode=self.args.weighted_loss_mode, + loss_mask=valid_batch["loss_mask"], + task_weights=self.args.task_weights, + ) + + losses.append(self.accelerator.gather(loss.repeat(self.args.per_device_eval_batch_size))) + accumulated_task_loss += task_loss.detach().float() + accumulated_task_exist += (task_loss != 0.0).detach().float() + + self.accelerator.wait_for_everyone() + valid_batch_num = len(losses) + gathered_size = losses[0].shape + losses = torch.cat(losses) + # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK)) + task_losses = self.accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK)) + task_exists = self.accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK)) + + try: + eval_loss = torch.mean(losses) + # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num + eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0) + if eval_loss <= min_eval_loss: + min_eval_loss = eval_loss + stall_num = 0 + best_step = completed_steps + else: + stall_num += 1 + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info( + f"[EVAL][completed_steps={completed_steps}]" + f"[eval_loss={eval_loss:.6f}][eval_task_loss={self.format_tensor(eval_task_loss, 4)}]" + f"[perplexity={perplexity:.4f}][valid_batch_num={valid_batch_num}]" + f"[gather_size={list(gathered_size)}]", + main_process_only=True, + ) + eval_log_dict = { + "Loss/valid": eval_loss, + "Perplexity/valid": perplexity, + "Epochs": round(completed_steps / self.num_update_steps_per_epoch, 2), + } + for i in range(len(ID2TASK)): + eval_log_dict[f"{ID2TASK[i]}_loss/valid"] = eval_task_loss[i] + + if self.accelerator.is_main_process: + write_tensorboard(self.summary_writer, eval_log_dict, completed_steps) + + return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step + + def accelerate_train(self): + # Train! + if self.args.seed is not None: + set_seed(self.args.seed) + + global_batch_size = ( + self.args.per_device_train_batch_size + * self.accelerator.num_processes + * self.args.gradient_accumulation_steps + ) + logger.info("************************************** Running training ****************************************") + logger.info(f" Num examples = {self.total_train_dataset_size}") + logger.info(f" Num Epochs = {self.args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") + logger.info(f" Total global train batch size (w. 
parallel, distributed & accumulation) = {global_batch_size}") + logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") + logger.info(f" Total optimization (update/completed) steps = {self.args.max_train_steps}") + logger.info(f" Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_train_epochs}") + logger.info("************************************************************************************************") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process) + + # set starting_epoch, completed_steps and resume_step of train_dataloader + completed_steps = 0 + starting_epoch = 0 + resume_step = None + + if self.args.resume_from_checkpoint: + self.accelerator.load_state(self.args.resume_from_checkpoint) + self.accelerator.print(f"Resumed from checkpoint: {self.args.resume_from_checkpoint}") + path = os.path.basename(self.args.resume_from_checkpoint) + starting_epoch, completed_steps, resume_step = extract_epochs_and_steps( + path, self.num_update_steps_per_epoch, self.args.gradient_accumulation_steps + ) + + # update the progress_bar when loading from a checkpoint + progress_bar.update(completed_steps) + + # monitor minimum eval_loss, stalling num, and best_step + min_eval_loss = float("inf") + stall_num = 0 + best_step = None + + # monitor train loss + reduce_loss = torch.tensor(0.0).to(self.model.device) + reduce_aux_loss = torch.tensor(0.0).to(self.model.device) + reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + per_task_weight = self.args.task_weights + + if self.args.weighted_loss_mode == "coba": + self.model.eval() + eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate( + completed_steps, + 0, + min_eval_loss, + stall_num, + best_step, + ) + self.model.train() + coba_status = CoBaStatus( + self.args.coba_warmup_steps, + self.args.coba_history_length, + self.args.coba_tau, + self.args.coba_update_interval, + self.args.coba_sample_valid_num, + self.valid_dataloader, + ) + coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device) + coba_status.sample_valid_batch(self.model, completed_steps) + logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) + else: + coba_status = None + + # Training Loop! + for epoch in range(starting_epoch, self.args.num_train_epochs): + # set_epoch + # self.train_dataloader.set_epoch(epoch) + + # stop if eval loss has stalled for early_stopping_stall_num evaluations (early stopping) + if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: + break + + if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = self.accelerator.skip_first_batches(self.train_dataloader, resume_step) + else: + active_dataloader = self.train_dataloader + tail_num = len(active_dataloader) - len(active_dataloader) % self.args.gradient_accumulation_steps + print(f"length of dataloader: {len(active_dataloader)}") + + self.model.train() + # Inner Loop! 
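+ # Each micro-batch below: forward pass -> (optional) CoBa per-task weight update -> weighted MFT loss (plus MoE aux loss when router logits are returned) -> backward -> optimizer/scheduler step (effective once per gradient accumulation cycle), with losses accumulated for logging.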
+ for step, batch in enumerate(active_dataloader): + if step == tail_num: + break + with self.accelerator.accumulate(self.model): + if step == 0: + self.touch(batch, num_tokens=10) + # forward + outputs = self.model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + position_ids=batch["position_ids"], + return_dict=True, + ) + + if ( + self.args.weighted_loss_mode == "coba" + and self.accelerator.sync_gradients + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps >= self.args.coba_warmup_steps + ): + with torch.no_grad(): + per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps) + coba_status.log_per_task_weight += per_task_weight + # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True) + + # loss + loss, task_loss, _ = loss_func_mft( + outputs=outputs, + labels=batch["labels"], + task_mask=batch["task_mask"], + task_id=batch["task_id"], + weighted_loss_mode=self.args.weighted_loss_mode, + loss_mask=batch["loss_mask"], + task_weights=per_task_weight, + ) + + # accelerator.print(len(outputs.router_logits), outputs.router_logits[0], outputs.router_logits[-1]) + # accelerator.print(batch['attention_mask'].shape, batch['attention_mask']) + aux_loss = None + if hasattr(self.model_config, "output_router_logits") and self.model_config.output_router_logits: + if hasattr(self.model_config, "num_local_experts"): + num_experts = self.model_config.num_local_experts + elif hasattr(self.model_config, "num_experts"): + num_experts = self.model_config.num_experts + else: + raise ValueError("model has no attribute num_local_experts or num_experts") + aux_loss = load_balancing_loss_func( + outputs.router_logits, + num_experts, + self.model_config.num_experts_per_tok, + batch["attention_mask"], + ) + aux_loss = self.model_config.router_aux_loss_coef * aux_loss.to(loss.device) + loss += aux_loss # make sure both reside on the same device + + # backward + self.accelerator.backward(loss) + # print(self.lr_scheduler.state_dict(), self.accelerator.process_index) + # update(sync_gradients) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + # enforce args.min_lr as a lower bound on the learning rate + if self.optimizer.param_groups[0]["lr"] <= self.args.min_lr: + self.optimizer.param_groups[0]["lr"] = self.args.min_lr + + # accumulate reduce_loss and reduce_task_loss within a log_interval + if not torch.isnan(loss): + reduce_loss += loss.detach().float() + if aux_loss and not torch.isnan(aux_loss): + reduce_aux_loss += aux_loss.detach().float() + # self.print("task loss devices: ", reduce_task_loss.device, task_loss.device) + reduce_task_loss += task_loss.detach().float() + reduce_task_exist += (task_loss != 0).detach().float() + + # If the accelerator has performed an optimization step behind the scenes, a completed_step is done. 
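+ # With gradient accumulation, sync_gradients is True only once every gradient_accumulation_steps micro-batches, so the block below (CoBa sampling, logging, checkpointing, evaluation, early stopping) runs once per optimizer update.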
+ if self.accelerator.sync_gradients: + if ( + self.args.weighted_loss_mode == "coba" + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps >= 1 + ): + coba_status.sample_valid_batch(self.model, completed_steps) + # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) + + # progress_bar.update(1) + completed_steps += 1 + # monitor the training process: logging and tensorboard + if completed_steps % self.args.log_interval == 0: + progress_bar.update(self.args.log_interval) + if reduce_aux_loss > 0.0: + self.print(f"[INFO] aux_loss: {reduce_aux_loss/self.args.log_interval}") + self.accelerate_monitor( + reduce_loss, + reduce_task_loss, + reduce_task_exist, + completed_steps, + coba_status, + ) + # reset reduce_loss + reduce_loss = torch.tensor(0.0).to(self.model.device) + reduce_aux_loss = torch.tensor(0.0).to(self.model.device) + reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + + # steps checkpointing + if self.args.checkpointing_steps and completed_steps % self.args.checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if self.args.output_dir is not None: + output_dir = os.path.join(self.args.output_dir, output_dir) + self.accelerate_saving_states(output_dir, completed_steps) + + # steps evaluation + if completed_steps % self.args.evaluation_steps == 0 and self.valid_dataloader: + self.model.eval() + eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = self.accelerate_evaluate( + completed_steps, + step, + min_eval_loss, + stall_num, + best_step, + ) + self.model.train() + + # delete ckpts over args.saving_limit + if self.accelerator.is_main_process and self.args.saving_limit: + delete_ckpts_over_limits(self.args.output_dir, self.args.saving_limit, best_step) + + # early stopping when eval loss has stalled for args.early_stopping_stall_num evaluations + if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: + self.print(f"[WARNING] Early stopping at {completed_steps}") + break + + if completed_steps >= self.args.max_train_steps: + break + self.accelerator.wait_for_everyone() + + # epoch checkpointing + if self.args.epoch_checkpointing: + output_dir = f"epoch_{epoch + 1}" + if self.args.output_dir is not None: + output_dir = os.path.join(self.args.output_dir, output_dir) + self.accelerate_saving_states(output_dir, completed_steps) + + self.summary_writer.close() diff --git a/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py b/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py new file mode 100644 index 0000000..ca4347e --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- + +import argparse +import multiprocessing +import os +import sys +import random +import time +import tqdm +import glob +import json +import numpy as np + + +# add the grandparent directory to sys.path +current_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_path)) +grandparent_dir = os.path.dirname(parent_dir) +sys.path.append(grandparent_dir) + +from tokenizer import init_tokenizer +from pack_encoder import PackSSTBinEncoder, load_tokenizer +from data import indexed_dataset + +from threading import Semaphore +from colorama import Fore +import lm_fmt as lmd + + +def yield_from_files(files: list, semaphore): + """ + Iterator over input documents + + :param files: list of 
filenames + """ + def yielder(fname, semaphore): + with open(fname, 'r') as f: + for line in f: + semaphore.acquire() + yield json.loads(line) + + for fname in files: + semaphore.acquire() + yield from yielder(fname, semaphore) + +def yield_from_files2(fnames: list, semaphore, sample_percent): + """ + Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts / + other compressed formats. Also filters out empty documents. + + :param fnames: list of filenames + """ + def yielder(fname, semaphore): + try: + sample_interval = int(1/sample_percent) + for f in filter(lambda x: x, lmd.Reader(fname).stream_data(key=None)): + rand_value = random.randint(1, sample_interval*100) + if rand_value % sample_interval != 0: + continue + semaphore.acquire() + + #rand_value = random.randint(1, sample_interval*100) + #if rand_value % sample_interval != 0: + # yield None + + yield f + except Exception as e: + print('####Exception:', e.args) + yield None + + for fname in fnames: + semaphore.acquire() + + yield from yielder(fname, semaphore) + + +def print_example_doc(input_ids, tokenizer): + print(Fore.YELLOW + f'INPUT IDS len: {len(input_ids)}') + print(Fore.BLUE + f'INPUT IDS:\n {input_ids}\n\n') + + print(Fore.RED + f'DETOKENIZED INPUT:\n{tokenizer.decode(input_ids)}') + + +def core_process(encoded_docs, semaphore, seq_length, tokenizer, encoder, builder, output_idx_file): + """ + core of Data Pack SFT processing + """ + input_ids_key = 'input_ids' + + proc_start = time.time() + total_bytes_processed = 0 + pbar = tqdm.tqdm() + sentence_dropped = 0 + loss_token_cnt = 0 + + print("PRINT BEFORE STREAM PROCESS DATA") + + print_example_count = 0 + for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + + # release semaphore so `yield_from_files` can add another file to the buffer + semaphore.release() + + # add each tokenized document / sentence, + # For sft, each document has only one sample + input_ids_sentence = doc[input_ids_key][0] + if len(input_ids_sentence) < 1: + sentence_dropped += 1 + continue + + builder.add_item(np.array(input_ids_sentence, dtype=builder.dtype)) + builder.end_document() + #builder.finalize_without_close(output_idx_file) + #builder.add_item_and_end_document_and_finalize(np.array(input_ids_sentence, dtype=builder.dtype), output_idx_file) + + # print the first packed sample as an example + if print_example_count < 1: + print_example_doc(input_ids_sentence, tokenizer) + print_example_count += 1 + + # log progress + if i % 100 == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + pbar.set_description( + f"Processed {i} documents ({i / elapsed} docs/s, {mbs} MB/s)." + ) + if i != 0: + pbar.update(100) + + # finalize the builder after the last document + builder.finalize(output_idx_file) + + print(Fore.RED + "\ndropped docs: {}".format(sentence_dropped)) + + +def process_dataset(dataset_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent): + """ + Re-organize samples in the given data path into a Data Pack file. 
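+ dataset_path is scanned recursively for *.jsonl files, model_path is used to initialize the tokenizer, and sample_percent controls the approximate fraction of documents kept by yield_from_files2.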
+ """ + + # get all jsonl files and corresponding reading handler + files = glob.glob(os.path.join(dataset_path, '**/*.jsonl'), recursive=True) + + # build a semaphore object to stop `yield_from_files` from getting ahead + # of encoder.encode and hence building up memory + semaphore = Semaphore(1000 + parallel_num) + + # build sample iterator + sample_iterator = yield_from_files2(files, semaphore, sample_percent) + + # load tokenizer + # tokenizer = load_tokenizer(model_path, tokenizer_type) + tokenizer = init_tokenizer(model_path) + print('TOKEN of id=2:', tokenizer.convert_ids_to_tokens(2)) + print('ID of