Support running pipeline parallel inference by vertically partitioning model to different devices (#10392)
* support pipeline parallel inference
* fix logging
* remove benchmark file
* fic
* need to warmup twice
* support qwen and qwen2
* fix lint
* remove genxir
* refine
parent 66b4bb5c5d
commit 9e763b049c
6 changed files with 907 additions and 3 deletions
							
								
								
									
python/llm/example/GPU/Pipeline-Parallel-Inference/README.md (new file, 78 additions)
# Run BigDL-LLM on Multiple Intel GPUs in a pipeline-parallel fashion

This example demonstrates how to run a BigDL-LLM optimized low-bit model, vertically partitioned across two [Intel GPUs](../README.md).

## Requirements

To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.

## Example

### 1.1 Install BigDL-LLM

```bash
conda create -n llm python=3.9
conda activate llm
# the below command will install intel_extension_for_pytorch==2.1.10+xpu by default
# you can install a specific ipex/torch version for your needs
pip install --pre --upgrade bigdl-llm[xpu_2.1] -f https://developer.intel.com/ipex-whl-stable-xpu
# configure OneAPI environment variables
source /opt/intel/oneapi/setvars.sh

conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
```

### 1.2 Build and install patched version of Intel Extension for PyTorch (IPEX)

```bash
conda activate llm
source /opt/intel/oneapi/setvars.sh
git clone https://github.com/intel/intel-extension-for-pytorch.git
cd intel-extension-for-pytorch
git checkout v2.1.10+xpu
git submodule update --init --recursive
git cherry-pick be8ea24078d8a271e53d2946ac533383f7a2aa78
export USE_AOT_DEVLIST='ats-m150,pvc'
python setup.py install
```

> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.

### 2. Run pipeline parallel inference on multiple GPUs

Here, we provide example usage for different models and hardware. Please refer to the appropriate script based on your model and device.

### 3. Run

For optimal performance on Arc, it is recommended to set several environment variables.

```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```

```bash
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'meta-llama/Llama-2-7b-chat-hf'`.
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.

#### Sample Output
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
```log
Inference time: xxxx s
-------------------- Prompt --------------------
<s>[INST] <<SYS>>

<</SYS>>

What is AI? [/INST]
-------------------- Output --------------------
[INST] <<SYS>>

<</SYS>>

What is AI? [/INST]  Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence,
```
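Before running the example below, it can help to confirm that at least two XPU devices are actually visible to PyTorch. A quick check, assuming BigDL-LLM and IPEX are installed as described above:

```python
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  # importing IPEX registers the XPU backend

# generate.py expects at least two devices, xpu:0 and xpu:1
print(torch.xpu.device_count())
```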
							
								
								
									
python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py (new file, 116 additions)
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import intel_extension_for_pytorch as ipex
import time
import argparse

from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
DEFAULT_SYSTEM_PROMPT = """\
"""

def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    # Load model in 4 bit,
    # which converts the relevant layers in the model into INT4 format
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    first_half = ['model.embed_tokens', 'model.layers.0', 'model.layers.1', 'model.layers.2',
                  'model.layers.3', 'model.layers.4', 'model.layers.5', 'model.layers.6',
                  'model.layers.7', 'model.layers.8', 'model.layers.9', 'model.layers.10',
                  'model.layers.11', 'model.layers.12', 'model.layers.13', 'model.layers.14',
                  'model.layers.15']
    second_half = ['model.layers.16', 'model.layers.17', 'model.layers.18', 'model.layers.19',
                   'model.layers.20', 'model.layers.21', 'model.layers.22', 'model.layers.23',
                   'model.layers.24', 'model.layers.25', 'model.layers.26', 'model.layers.27',
                   'model.layers.28', 'model.layers.29', 'model.layers.30', 'model.layers.31',
                   'model.norm', 'lm_head']

    device_map = ({key: 'xpu:0' for key in first_half})
    device_map.update({key: 'xpu:1' for key in second_half})
    from accelerate import dispatch_model
    model = dispatch_model(
        model,
        device_map=device_map,
        offload_dir=None,
        skip_keys=["past_key_value", "past_key_values"],
    )

    # Load tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0')
        # ipex model needs a warmup, then inference time can be accurate
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)

        # start inference
        st = time.time()
        # if your selected model is capable of utilizing previous key/value attentions
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
        # it is important to set `use_cache=True` explicitly in the `generate` function
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()
        output = output.cpu()
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Prompt', '-'*20)
        print(prompt)
        print('-'*20, 'Output', '-'*20)
        print(output_str)
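The `first_half`/`second_half` split in `generate.py` is hard-coded for the 32 decoder layers of Llama-2-7b. For a model with a different depth, the same kind of `device_map` can be built programmatically; the sketch below assumes the usual `model.layers.N` naming and a `num_hidden_layers` field in the config, and the helper name `make_two_gpu_device_map` is illustrative rather than part of this commit:

```python
def make_two_gpu_device_map(model, first_device="xpu:0", second_device="xpu:1"):
    """Assign the embeddings and the first half of the decoder layers to one GPU,
    and the remaining layers plus the final norm and lm_head to the other."""
    num_layers = model.config.num_hidden_layers
    split = num_layers // 2
    device_map = {"model.embed_tokens": first_device,
                  "model.norm": second_device,
                  "lm_head": second_device}
    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = first_device if i < split else second_device
    return device_map

# Usage, replacing the hard-coded lists above:
# device_map = make_two_gpu_device_map(model)
# model = dispatch_model(model, device_map=device_map, offload_dir=None,
#                        skip_keys=["past_key_value", "past_key_values"])
```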
@@ -770,6 +770,7 @@ def _optimize_post(model, lightweight_bmm=False):
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from bigdl.llm.transformers.models.llama import llama_mlp_forward
     from bigdl.llm.transformers.models.llama import llama_decoder_forward
+    from bigdl.llm.transformers.models.llama import llama_model_forward
     from transformers.modeling_utils import PreTrainedModel
 
     # All huggingface format models are inherited from `PreTrainedModel`
@@ -823,6 +824,11 @@ def _optimize_post(model, lightweight_bmm=False):
                     transformers.models.llama.modeling_llama.LlamaAttention,
                     llama_attention_selective_batching_forward_4_31,
                 )
+            else:
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaModel,
+                    llama_model_forward)
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
@@ -1058,6 +1064,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from bigdl.llm.transformers.models.qwen import qwen_attention_forward
             from bigdl.llm.transformers.models.qwen import qwen_mlp_forward
             from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
+            from bigdl.llm.transformers.models.qwen import qwen_model_forward
             convert_forward(model,
                             module.QWenAttention,
                             qwen_attention_forward
@@ -1068,6 +1075,9 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.QWenMLP,
                             qwen_mlp_forward)
+            convert_forward(model,
+                            module.QWenModel,
+                            qwen_model_forward)
     elif model.config.model_type == "qwen2":
         # for Qwen1.5-7B
         modeling_module_name = model.__class__.__module__
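The hunks above extend `_optimize_post` so that the whole-model forwards added in this commit (`llama_model_forward`, `qwen_model_forward`) are registered alongside the existing attention/MLP replacements. Conceptually, `convert_forward` rebinds the supplied function as the `forward` of every submodule of the target class; a rough sketch of that idea, offered as an illustration rather than the library's exact implementation:

```python
def convert_forward_sketch(model, target_cls, new_forward):
    # Walk the module tree and rebind `new_forward` as a bound method on
    # every submodule that is an instance of `target_cls`.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, module.__class__)
```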
@@ -53,10 +53,15 @@ from transformers.models.llama.modeling_llama import LlamaModel
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
 try:
-    from transformers.cache_utils import Cache
+    from transformers.cache_utils import Cache, DynamicCache
 except ImportError:
     Cache = Tuple[torch.Tensor]
+from transformers import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -106,7 +111,7 @@ def llama_model_forward_4_36(
     if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    return LlamaModel.forward(
+    return llama_model_forward_4_36_internal(
         self=self,
         input_ids=input_ids,
         attention_mask=attention_mask,
@@ -1605,3 +1610,311 @@ def llama_attention_fast_forward(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def llama_model_forward_4_36_internal(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else \
+        self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else
+        self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(False,
+                          "You cannot specify both input_ids and inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape[:2]
+    elif inputs_embeds is not None:
+        batch_size, seq_length = inputs_embeds.shape[:2]
+    else:
+        invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
+
+    past_key_values_length = 0
+    if use_cache:
+        use_legacy_cache = not isinstance(past_key_values, Cache)
+        if use_legacy_cache:
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length,
+            dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0)
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    if self._use_flash_attention_2:
+        # 2d mask is passed through the layers
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \
+            else None
+    elif self._use_sdpa and not output_attentions:
+        # output_attentions=True can not be supported when using SDPA, and we fall back on
+        # the manual implementation that requires a 4D causal mask in all cases.
+        from transformers.models.llama.modeling_llama import \
+            _prepare_4d_causal_attention_mask_for_sdpa
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+        )
+    else:
+        # 4d mask is passed through the layers
+        from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing."
+                " Setting `use_cache=False`..."
+            )
+            use_cache = False
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+            )
+        else:
+            # bigdl-llm changes:
+            curr_device = decoder_layer.input_layernorm.weight.device
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(curr_device)
+            if position_ids is not None:
+                position_ids = position_ids.to(curr_device)
+            # bigdl-llm changes end
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = None
+    if use_cache:
+        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache \
+            else next_decoder_cache
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache,
+                                 all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
+def llama_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else
+        self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(False,
+                          "You cannot specify both input_ids and inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
+
+    seq_length_with_past = seq_length
+    past_key_values_length = 0
+
+    if past_key_values is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+        seq_length_with_past = seq_length_with_past + past_key_values_length
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length,
+            dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    else:
+        position_ids = position_ids.view(-1, seq_length).long()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+    # embed positions
+    if attention_mask is None:
+        attention_mask = torch.ones(
+            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+        )
+        padding_mask = None
+    else:
+        if 0 in attention_mask:
+            padding_mask = attention_mask
+        else:
+            padding_mask = None
+
+    attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+    )
+
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing."
+                " Setting `use_cache=False`..."
+            )
+            use_cache = False
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    for idx, decoder_layer in enumerate(self.layers):
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, past_key_value, output_attentions,
+                                  padding_mask=padding_mask)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
+            )
+        else:
+            # bigdl-llm changes:
+            #
+            # Avoid moving `attention_mask` and `position_ids` to other devices multiple times.
+            #
+            # When the model is partitioned on two different devices using
+            # `accelerate`'s `dispatch_model`, a hook to move inputs to the correct device is
+            # added to each layer's `forward`, which will result in moving `attention_mask`
+            # and `position_ids`, which are allocated on device:0, to other devices for each
+            # decoder layer not on device:0.
+            #
+            # To avoid this, we move `attention_mask` and `position_ids` to the device of
+            # the current layer before the forward call, so that the moving is only done once
+            # for each device other than device:0.
+            #
+            curr_device = decoder_layer.input_layernorm.weight.device
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(curr_device)
+            if position_ids is not None:
+                position_ids = position_ids.to(curr_device)
+            # bigdl-llm changes end
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                padding_mask=padding_mask,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache,
+                                 all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
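Both new Llama forwards above follow the same pipeline-parallel pattern: iterate over the decoder layers wherever `accelerate` has placed them, and move the small shared inputs (attention mask, position ids) to each layer's device once, instead of letting the per-layer dispatch hook copy them from `xpu:0` on every call. A stripped-down sketch of that loop, with hypothetical names and a simplified layer signature, not the actual BigDL-LLM code:

```python
import torch

def forward_across_devices(layers, hidden_states, attention_mask):
    # Each decoder layer may live on a different device after dispatch_model;
    # move the mask only when the device actually changes.
    for layer in layers:
        curr_device = next(layer.parameters()).device
        if attention_mask is not None and attention_mask.device != curr_device:
            attention_mask = attention_mask.to(curr_device)
        hidden_states = layer(hidden_states.to(curr_device), attention_mask)
    return hidden_states
```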
@@ -45,6 +45,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from transformers.modeling_outputs import BaseModelOutputWithPast
 
 apply_rotary_emb_func = None
 
@@ -544,3 +545,210 @@ def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
             SILU, qtype
         ))
     return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
+
+
+def qwen_model_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    token_type_ids: Optional[torch.LongTensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+):
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(
+            False,
+            "You cannot specify both input_ids and inputs_embeds at the same time"
+        )
+    elif input_ids is not None:
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        batch_size = input_ids.shape[0]
+    elif inputs_embeds is not None:
+        input_shape = inputs_embeds.size()[:-1]
+        batch_size = inputs_embeds.shape[0]
+    else:
+        invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
+
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+    if token_type_ids is not None:
+        token_type_ids = token_type_ids.view(-1, input_shape[-1])
+    if position_ids is not None:
+        position_ids = position_ids.view(-1, input_shape[-1])
+
+    if past_key_values is None:
+        past_length = 0
+        past_key_values = tuple([None] * len(self.h))
+    else:
+        if self.use_cache_quantization:
+            past_length = past_key_values[0][0][0].size(2)
+        else:
+            past_length = past_key_values[0][0].size(-2)
+    if position_ids is None:
+        position_ids = torch.arange(
+            past_length,
+            input_shape[-1] + past_length,
+            dtype=torch.long,
+            device=device,
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+    if attention_mask is not None:
+        if batch_size <= 0:
+            invalidInputError(False, "batch_size has to be defined and > 0")
+        attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = attention_mask[:, None, None, :]
+        attention_mask = attention_mask.to(dtype=self.dtype)
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+    encoder_attention_mask = None
+    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+    if inputs_embeds is None:
+        inputs_embeds = self.wte(input_ids)
+    hidden_states = inputs_embeds
+
+    kv_seq_len = hidden_states.size()[1]
+    if past_key_values[0] is not None:
+        # past key values[0][0] shape: bs * seq_len * head_num * dim
+        if self.use_cache_quantization:
+            kv_seq_len += past_key_values[0][0][0].shape[2]
+        else:
+            kv_seq_len += past_key_values[0][0].shape[1]
+
+    if self.training or not self.use_dynamic_ntk:
+        ntk_alpha_list = [1.0]
+    elif kv_seq_len != hidden_states.size()[1]:
+        ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
+    else:
+        ntk_alpha_list = []
+        if attention_mask is not None and kv_seq_len > self.seq_length:
+            true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1,
+                                                                           dtype=torch.int32)
+            for i in range(hidden_states.size()[0]):
+                true_seq_len = true_seq_lens[i].item()
+                ntk_alpha = self.get_ntk_alpha(true_seq_len)
+                ntk_alpha_list.append(ntk_alpha)
+        else:
+            ntk_alpha = self.get_ntk_alpha(kv_seq_len)
+            ntk_alpha_list.append(ntk_alpha)
+    self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
+    rotary_pos_emb_list = [
+        self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+    ]
+
+    hidden_states = self.drop(hidden_states)
+    output_shape = input_shape + (hidden_states.size(-1),)
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. "
+                "Setting `use_cache=False`..."
+            )
+            use_cache = False
+
+    presents = () if use_cache else None
+    all_self_attentions = () if output_attentions else None
+    all_hidden_states = () if output_hidden_states else None
+    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, use_cache, output_attentions)
+
+                return custom_forward
+
+            outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(block),
+                hidden_states,
+                rotary_pos_emb_list,
+                None,
+                attention_mask,
+                head_mask[i],
+                encoder_hidden_states,
+                encoder_attention_mask,
+            )
+        else:
+            # bigdl-llm changes
+            curr_device = block.ln_1.weight.device
+            from accelerate.utils.operations import send_to_device
+            if rotary_pos_emb_list is not None:
+                rotary_pos_emb_list = send_to_device(rotary_pos_emb_list, curr_device)
+            if attention_mask is not None:
+                attention_mask = send_to_device(attention_mask, curr_device)
+            if head_mask[i] is not None:
+                head_mask[i] = send_to_device(head_mask[i], curr_device)
+            if encoder_hidden_states is not None:
+                encoder_hidden_states = send_to_device(encoder_hidden_states, curr_device)
+            if encoder_attention_mask is not None:
+                encoder_attention_mask = send_to_device(encoder_attention_mask,
+                                                        curr_device)
+            # bigdl-llm changes ends
+
+            outputs = block(
+                hidden_states,
+                layer_past=layer_past,
+                rotary_pos_emb_list=rotary_pos_emb_list,
+                attention_mask=attention_mask,
+                head_mask=head_mask[i],
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
+
+        hidden_states = outputs[0]
+        if use_cache is True:
+            presents = presents + (outputs[1],)
+
+        if output_attentions:
+            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+    hidden_states = self.ln_f(hidden_states)
+    hidden_states = hidden_states.view(output_shape)
+    # Add last hidden state
+    if output_hidden_states:
+        all_hidden_states = all_hidden_states + (hidden_states,)
+
+    if not return_dict:
+        return tuple(
+            v for v in [hidden_states, presents, all_hidden_states] if v is not None
+        )
+
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=presents,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attentions,
+    )
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
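The loop above is the heart of the pipeline-parallel change for Qwen: before each transformer block runs, its inputs are sent to whatever device that block's weights were placed on, so layers split vertically across two Intel GPUs can be driven by a single forward pass. The toy model below is a minimal sketch of the same dispatch pattern in plain PyTorch; `TinyPipelineModel` and the device list are illustrative and not part of BigDL-LLM.

```python
# Minimal sketch of per-layer device dispatch (hypothetical toy model, not BigDL-LLM API).
import torch
import torch.nn as nn

class TinyPipelineModel(nn.Module):
    def __init__(self, dim, devices):
        super().__init__()
        # Vertical partition: layer i is materialized on devices[i]
        self.layers = nn.ModuleList(
            [nn.Linear(dim, dim).to(dev) for dev in devices]
        )

    def forward(self, x):
        for layer in self.layers:
            curr_device = layer.weight.device   # same trick as block.ln_1.weight.device
            x = x.to(curr_device)               # move the activations, not the weights
            x = torch.relu(layer(x))
        return x

# Two stages; on Intel GPUs the devices would be "xpu:0" and "xpu:1".
model = TinyPipelineModel(16, devices=["cpu", "cpu"])
out = model(torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 16])
```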
@@ -54,7 +54,19 @@ from bigdl.llm.transformers.kv import DynamicFp8Cache

```python
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast

try:
    from transformers.cache_utils import Cache, DynamicCache
except ImportError:
    Cache = Tuple[torch.Tensor]
import logging
from transformers import logging


logger = logging.get_logger(__name__)

KV_CACHE_ALLOC_BLOCK_LENGTH = 256
```
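The `try/except` import above is what lets the same file run against both older and newer `transformers` releases: when `transformers.cache_utils` is available, tuple-style KV caches are wrapped in a `DynamicCache` on the way in and unwrapped on the way out, exactly as the forward below does with `from_legacy_cache` / `to_legacy_cache`. A small round-trip sketch, assuming `transformers` >= 4.36; the tensor shapes are made up for illustration:

```python
# Round-trip between the legacy tuple cache and DynamicCache (assumes transformers >= 4.36).
import torch
from transformers.cache_utils import DynamicCache

# Legacy layout: one (key, value) pair per layer, each (batch, heads, seq_len, head_dim).
legacy = ((torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)),)

cache = DynamicCache.from_legacy_cache(legacy)   # wrap for the Cache API
print(cache.get_usable_length(4))                # -> 4 positions already cached

back = cache.to_legacy_cache()                   # convert back before returning
print(back[0][0].shape)                          # -> torch.Size([1, 2, 4, 8])
```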
```diff
@@ -82,7 +94,7 @@ def qwen2_model_forward(
     if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    return Qwen2Model.forward(
+    return qwen2_model_forward_internal(
         self=self,
         input_ids=input_ids,
         attention_mask=attention_mask,
```
@@ -96,6 +108,173 @@ def qwen2_model_forward(

```python
    )


def qwen2_model_forward_internal(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else \
        self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else
        self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        invalidInputError(False,
                          "You cannot specify both decoder_input_ids and "
                          "decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        invalidInputError(False,
                          "You have to specify either decoder_input_ids or decoder_inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. "
                "Setting `use_cache=False`..."
            )
            use_cache = False

    past_key_values_length = 0

    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length,
            dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    flash_attn_2 = self._attn_implementation == "flash_attention_2"
    if attention_mask is not None and flash_attn_2 and use_cache:
        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
        if is_padding_right:
            invalidInputError(
                False,
                "You are attempting to perform batched generation with padding_side='right',"
                " this may lead to unexpected behaviour for Flash Attention version of Qwen2."
                " Make sure to call `tokenizer.padding_side = 'left'` before tokenizing "
                "the input."
            )

    if self._attn_implementation == "flash_attention_2":
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and
                                            0 in attention_mask) else None
    elif self._attn_implementation == "sdpa" and not output_attentions:
        # output_attentions=True can not be supported when using SDPA, and we fall back on
        # the manual implementation that requires a 4D causal mask in all cases.
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
            sliding_window=self.config.sliding_window,
        )

    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            # bigdl-llm changes
            curr_device = decoder_layer.input_layernorm.weight.device
            if attention_mask is not None:
                attention_mask = attention_mask.to(curr_device)
            if position_ids is not None:
                position_ids = position_ids.to(curr_device)
            # bigdl-llm changes end
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \
            next_decoder_cache

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache,
                                 all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def qwen2_attention_forward(
    self,
    hidden_states: torch.Tensor,
```
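The Qwen2 path mirrors the Qwen one: before each decoder layer runs, `attention_mask` and `position_ids` are moved to the device holding that layer's weights, read from `decoder_layer.input_layernorm.weight.device`. When checking a vertical split it helps to see where each layer actually landed; the helper below is a hypothetical convenience, assuming a HuggingFace-style causal LM whose decoder stack lives under `model.model.layers`:

```python
# Hypothetical helper: print the device that holds each decoder layer's weights.
# Assumes a HuggingFace-style causal LM where the decoder stack is model.model.layers.
def print_layer_placement(model):
    for i, layer in enumerate(model.model.layers):
        print(f"layer {i:2d} -> {layer.input_layernorm.weight.device}")

# Expected output for a two-way split (device names are illustrative):
# layer  0 -> xpu:0
# ...
# layer 23 -> xpu:1
```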