Support running pipeline parallel inference by vertically partitioning model to different devices (#10392)
* support pipeline parallel inference * fix logging * remove benchmark file * fic * need to warmup twice * support qwen and qwen2 * fix lint * remove genxir * refine
This commit is contained in:
		
							parent
							
								
									66b4bb5c5d
								
							
						
					
					
						commit
						9e763b049c
					
				
					 6 changed files with 907 additions and 3 deletions
				
			
		
							
								
								
									
										78
									
								
								python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,78 @@
 | 
			
		|||
# Run BigDL-LLM on Multiple Intel GPUs in pipeline parallel fashion
 | 
			
		||||
 | 
			
		||||
This example demonstrates how to run BigDL-LLM optimized low-bit model vertically partitioned on two [Intel GPUs](../README.md).
 | 
			
		||||
 | 
			
		||||
## Requirements
 | 
			
		||||
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
 | 
			
		||||
 | 
			
		||||
## Example:
 | 
			
		||||
 | 
			
		||||
### 1.1 Install BigDL-LLM
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
conda create -n llm python=3.9
 | 
			
		||||
conda activate llm
 | 
			
		||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
			
		||||
# you can install specific ipex/torch version for your need
 | 
			
		||||
pip install --pre --upgrade bigdl-llm[xpu_2.1] -f https://developer.intel.com/ipex-whl-stable-xpu
 | 
			
		||||
# configures OneAPI environment variables
 | 
			
		||||
source /opt/intel/oneapi/setvars.sh
 | 
			
		||||
 | 
			
		||||
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 1.2 Build and install patched version of Intel Extension for PyTorch (IPEX)
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
conda activate llm
 | 
			
		||||
source /opt/intel/oneapi/setvars.sh
 | 
			
		||||
git clone https://github.com/intel/intel-extension-for-pytorch.git
 | 
			
		||||
cd intel-extension-for-pytorch
 | 
			
		||||
git checkout v2.1.10+xpu
 | 
			
		||||
git submodule update --init --recursive
 | 
			
		||||
git cherry-pick be8ea24078d8a271e53d2946ac533383f7a2aa78
 | 
			
		||||
export USE_AOT_DEVLIST='ats-m150,pvc'
 | 
			
		||||
python setup.py install
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.
 | 
			
		||||
 | 
			
		||||
### 2. Run tensor parallel inference on multiple GPUs
 | 
			
		||||
Here, we provide example usages on different models and different hardwares. Please refer to the appropriate script based on your model and device:
 | 
			
		||||
 | 
			
		||||
### 3. Run
 | 
			
		||||
 | 
			
		||||
For optimal performance on Arc, it is recommended to set several environment variables.
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
export USE_XETLA=OFF
 | 
			
		||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Arguments info:
 | 
			
		||||
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`.
 | 
			
		||||
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
 | 
			
		||||
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
 | 
			
		||||
 | 
			
		||||
#### Sample Output
 | 
			
		||||
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
 | 
			
		||||
```log
 | 
			
		||||
Inference time: xxxx s
 | 
			
		||||
-------------------- Prompt --------------------
 | 
			
		||||
<s>[INST] <<SYS>>
 | 
			
		||||
 | 
			
		||||
<</SYS>>
 | 
			
		||||
 | 
			
		||||
What is AI? [/INST]
 | 
			
		||||
-------------------- Output --------------------
 | 
			
		||||
[INST] <<SYS>>
 | 
			
		||||
 | 
			
		||||
<</SYS>>
 | 
			
		||||
 | 
			
		||||
What is AI? [/INST]  Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence,
 | 
			
		||||
```
 | 
			
		||||
							
								
								
									
										116
									
								
								python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										116
									
								
								python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,116 @@
 | 
			
		|||
 | 
			
		||||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import intel_extension_for_pytorch as ipex
 | 
			
		||||
import time
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
from bigdl.llm.transformers import AutoModelForCausalLM
 | 
			
		||||
from transformers import LlamaTokenizer
 | 
			
		||||
 | 
			
		||||
# you could tune the prompt based on your own model,
 | 
			
		||||
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
 | 
			
		||||
DEFAULT_SYSTEM_PROMPT = """\
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
def get_prompt(message: str, chat_history: list[tuple[str, str]],
 | 
			
		||||
               system_prompt: str) -> str:
 | 
			
		||||
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
 | 
			
		||||
    # The first user input is _not_ stripped
 | 
			
		||||
    do_strip = False
 | 
			
		||||
    for user_input, response in chat_history:
 | 
			
		||||
        user_input = user_input.strip() if do_strip else user_input
 | 
			
		||||
        do_strip = True
 | 
			
		||||
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
 | 
			
		||||
    message = message.strip() if do_strip else message
 | 
			
		||||
    texts.append(f'{message} [/INST]')
 | 
			
		||||
    return ''.join(texts)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
			
		||||
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
 | 
			
		||||
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
 | 
			
		||||
                             ', or the path to the huggingface checkpoint folder')
 | 
			
		||||
    parser.add_argument('--prompt', type=str, default="What is AI?",
 | 
			
		||||
                        help='Prompt to infer')
 | 
			
		||||
    parser.add_argument('--n-predict', type=int, default=32,
 | 
			
		||||
                        help='Max tokens to predict')
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    model_path = args.repo_id_or_model_path
 | 
			
		||||
 | 
			
		||||
    # Load model in 4 bit,
 | 
			
		||||
    # which convert the relevant layers in the model into INT4 format
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(model_path,
 | 
			
		||||
                                                 load_in_4bit=True,
 | 
			
		||||
                                                 optimize_model=True,
 | 
			
		||||
                                                 trust_remote_code=True,
 | 
			
		||||
                                                 use_cache=True)
 | 
			
		||||
    first_half = ['model.embed_tokens', 'model.layers.0', 'model.layers.1', 'model.layers.2',
 | 
			
		||||
                  'model.layers.3', 'model.layers.4', 'model.layers.5', 'model.layers.6',
 | 
			
		||||
                  'model.layers.7', 'model.layers.8', 'model.layers.9', 'model.layers.10',
 | 
			
		||||
                  'model.layers.11', 'model.layers.12', 'model.layers.13', 'model.layers.14',
 | 
			
		||||
                  'model.layers.15']
 | 
			
		||||
    second_half = ['model.layers.16', 'model.layers.17', 'model.layers.18', 'model.layers.19',
 | 
			
		||||
                   'model.layers.20', 'model.layers.21', 'model.layers.22', 'model.layers.23',
 | 
			
		||||
                   'model.layers.24', 'model.layers.25', 'model.layers.26', 'model.layers.27',
 | 
			
		||||
                   'model.layers.28', 'model.layers.29', 'model.layers.30', 'model.layers.31',
 | 
			
		||||
                   'model.norm', 'lm_head']
 | 
			
		||||
 | 
			
		||||
    device_map=({key: 'xpu:0' for key in first_half})
 | 
			
		||||
    device_map.update({key: 'xpu:1' for key in second_half})
 | 
			
		||||
    from accelerate import dispatch_model
 | 
			
		||||
    model = dispatch_model(
 | 
			
		||||
        model,
 | 
			
		||||
        device_map=device_map,
 | 
			
		||||
        offload_dir=None,
 | 
			
		||||
        skip_keys=["past_key_value", "past_key_values"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Load tokenizer
 | 
			
		||||
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
    # Generate predicted tokens
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
			
		||||
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0')
 | 
			
		||||
        # ipex model needs a warmup, then inference time can be accurate
 | 
			
		||||
        output = model.generate(input_ids,
 | 
			
		||||
                                max_new_tokens=args.n_predict)
 | 
			
		||||
        output = model.generate(input_ids,
 | 
			
		||||
                                max_new_tokens=args.n_predict)
 | 
			
		||||
 | 
			
		||||
        # start inference
 | 
			
		||||
        st = time.time()
 | 
			
		||||
        # if your selected model is capable of utilizing previous key/value attentions
 | 
			
		||||
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
 | 
			
		||||
        # it is important to set `use_cache=True` explicitly in the `generate` function
 | 
			
		||||
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
 | 
			
		||||
        output = model.generate(input_ids,
 | 
			
		||||
                                max_new_tokens=args.n_predict)
 | 
			
		||||
        torch.xpu.synchronize()
 | 
			
		||||
        end = time.time()
 | 
			
		||||
        output = output.cpu()
 | 
			
		||||
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
 | 
			
		||||
        print(f'Inference time: {end-st} s')
 | 
			
		||||
        print('-'*20, 'Prompt', '-'*20)
 | 
			
		||||
        print(prompt)
 | 
			
		||||
        print('-'*20, 'Output', '-'*20)
 | 
			
		||||
        print(output_str)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -770,6 +770,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_mlp_forward
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_decoder_forward
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_model_forward
 | 
			
		||||
    from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
 | 
			
		||||
    # All huggingface format models are inherited from `PreTrainedModel`
 | 
			
		||||
| 
						 | 
				
			
			@ -823,6 +824,11 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
                    transformers.models.llama.modeling_llama.LlamaAttention,
 | 
			
		||||
                    llama_attention_selective_batching_forward_4_31,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                convert_forward(
 | 
			
		||||
                    model,
 | 
			
		||||
                    transformers.models.llama.modeling_llama.LlamaModel,
 | 
			
		||||
                    llama_model_forward)
 | 
			
		||||
    else:
 | 
			
		||||
        # todo implement 4.28.0 ~ 4.30.2
 | 
			
		||||
        pass
 | 
			
		||||
| 
						 | 
				
			
			@ -1058,6 +1064,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            from bigdl.llm.transformers.models.qwen import qwen_attention_forward
 | 
			
		||||
            from bigdl.llm.transformers.models.qwen import qwen_mlp_forward
 | 
			
		||||
            from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
 | 
			
		||||
            from bigdl.llm.transformers.models.qwen import qwen_model_forward
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.QWenAttention,
 | 
			
		||||
                            qwen_attention_forward
 | 
			
		||||
| 
						 | 
				
			
			@ -1068,6 +1075,9 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            convert_forward(model,
 | 
			
		||||
                            module.QWenMLP,
 | 
			
		||||
                            qwen_mlp_forward)
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.QWenModel,
 | 
			
		||||
                            qwen_model_forward)
 | 
			
		||||
    elif model.config.model_type == "qwen2":
 | 
			
		||||
        # for Qwen1.5-7B
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -53,10 +53,15 @@ from transformers.models.llama.modeling_llama import LlamaModel
 | 
			
		|||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from transformers.cache_utils import Cache
 | 
			
		||||
    from transformers.cache_utils import Cache, DynamicCache
 | 
			
		||||
except ImportError:
 | 
			
		||||
    Cache = Tuple[torch.Tensor]
 | 
			
		||||
from transformers import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
| 
						 | 
				
			
			@ -106,7 +111,7 @@ def llama_model_forward_4_36(
 | 
			
		|||
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
 | 
			
		||||
        if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
    return LlamaModel.forward(
 | 
			
		||||
    return llama_model_forward_4_36_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
        attention_mask=attention_mask,
 | 
			
		||||
| 
						 | 
				
			
			@ -1605,3 +1610,311 @@ def llama_attention_fast_forward(
 | 
			
		|||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output, attn_weights, past_key_value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def llama_model_forward_4_36_internal(
 | 
			
		||||
    self,
 | 
			
		||||
    input_ids: torch.LongTensor = None,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    use_cache: Optional[bool] = None,
 | 
			
		||||
    output_attentions: Optional[bool] = None,
 | 
			
		||||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    output_attentions = output_attentions if output_attentions is not None else \
 | 
			
		||||
        self.config.output_attentions
 | 
			
		||||
    output_hidden_states = (
 | 
			
		||||
        output_hidden_states if output_hidden_states is not None else
 | 
			
		||||
        self.config.output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
 | 
			
		||||
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
    # retrieve input_ids and inputs_embeds
 | 
			
		||||
    if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          "You cannot specify both input_ids and inputs_embeds at the same time")
 | 
			
		||||
    elif input_ids is not None:
 | 
			
		||||
        batch_size, seq_length = input_ids.shape[:2]
 | 
			
		||||
    elif inputs_embeds is not None:
 | 
			
		||||
        batch_size, seq_length = inputs_embeds.shape[:2]
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
 | 
			
		||||
 | 
			
		||||
    past_key_values_length = 0
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        use_legacy_cache = not isinstance(past_key_values, Cache)
 | 
			
		||||
        if use_legacy_cache:
 | 
			
		||||
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 | 
			
		||||
        past_key_values_length = past_key_values.get_usable_length(seq_length)
 | 
			
		||||
 | 
			
		||||
    if position_ids is None:
 | 
			
		||||
        device = input_ids.device if input_ids is not None else inputs_embeds.device
 | 
			
		||||
        position_ids = torch.arange(
 | 
			
		||||
            past_key_values_length, seq_length + past_key_values_length,
 | 
			
		||||
            dtype=torch.long, device=device
 | 
			
		||||
        )
 | 
			
		||||
        position_ids = position_ids.unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
    if inputs_embeds is None:
 | 
			
		||||
        inputs_embeds = self.embed_tokens(input_ids)
 | 
			
		||||
 | 
			
		||||
    if self._use_flash_attention_2:
 | 
			
		||||
        # 2d mask is passed through the layers
 | 
			
		||||
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \
 | 
			
		||||
            else None
 | 
			
		||||
    elif self._use_sdpa and not output_attentions:
 | 
			
		||||
        # output_attentions=True can not be supported when using SDPA, and we fall back on
 | 
			
		||||
        # the manual implementation that requires a 4D causal mask in all cases.
 | 
			
		||||
        from transformers.models.llama.modeling_llama import \
 | 
			
		||||
            _prepare_4d_causal_attention_mask_for_sdpa
 | 
			
		||||
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
 | 
			
		||||
            attention_mask,
 | 
			
		||||
            (batch_size, seq_length),
 | 
			
		||||
            inputs_embeds,
 | 
			
		||||
            past_key_values_length,
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        # 4d mask is passed through the layers
 | 
			
		||||
        from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
 | 
			
		||||
        attention_mask = _prepare_4d_causal_attention_mask(
 | 
			
		||||
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # embed positions
 | 
			
		||||
    hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
    if self.gradient_checkpointing and self.training:
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            logger.warning_once(
 | 
			
		||||
                "`use_cache=True` is incompatible with gradient checkpointing."
 | 
			
		||||
                " Setting `use_cache=False`..."
 | 
			
		||||
            )
 | 
			
		||||
            use_cache = False
 | 
			
		||||
 | 
			
		||||
    # decoder layers
 | 
			
		||||
    all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
    all_self_attns = () if output_attentions else None
 | 
			
		||||
    next_decoder_cache = None
 | 
			
		||||
 | 
			
		||||
    for decoder_layer in self.layers:
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
            layer_outputs = self._gradient_checkpointing_func(
 | 
			
		||||
                decoder_layer.__call__,
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                attention_mask,
 | 
			
		||||
                position_ids,
 | 
			
		||||
                past_key_values,
 | 
			
		||||
                output_attentions,
 | 
			
		||||
                use_cache,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            # bigdl-llm changes:
 | 
			
		||||
            curr_device = decoder_layer.input_layernorm.weight.device
 | 
			
		||||
            if attention_mask is not None:
 | 
			
		||||
                attention_mask = attention_mask.to(curr_device)
 | 
			
		||||
            if position_ids is not None:
 | 
			
		||||
                position_ids = position_ids.to(curr_device)
 | 
			
		||||
            # bigdl-llm changes end
 | 
			
		||||
            layer_outputs = decoder_layer(
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                attention_mask=attention_mask,
 | 
			
		||||
                position_ids=position_ids,
 | 
			
		||||
                past_key_value=past_key_values,
 | 
			
		||||
                output_attentions=output_attentions,
 | 
			
		||||
                use_cache=use_cache,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hidden_states = layer_outputs[0]
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
			
		||||
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            all_self_attns += (layer_outputs[1],)
 | 
			
		||||
 | 
			
		||||
    hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
    # add hidden states from the last decoder layer
 | 
			
		||||
    if output_hidden_states:
 | 
			
		||||
        all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
    next_cache = None
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache \
 | 
			
		||||
            else next_decoder_cache
 | 
			
		||||
    if not return_dict:
 | 
			
		||||
        return tuple(v for v in [hidden_states, next_cache,
 | 
			
		||||
                                 all_hidden_states, all_self_attns] if v is not None)
 | 
			
		||||
    return BaseModelOutputWithPast(
 | 
			
		||||
        last_hidden_state=hidden_states,
 | 
			
		||||
        past_key_values=next_cache,
 | 
			
		||||
        hidden_states=all_hidden_states,
 | 
			
		||||
        attentions=all_self_attns,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def llama_model_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    input_ids: torch.LongTensor = None,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    use_cache: Optional[bool] = None,
 | 
			
		||||
    output_attentions: Optional[bool] = None,
 | 
			
		||||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    output_attentions = output_attentions if output_attentions is not None \
 | 
			
		||||
        else self.config.output_attentions
 | 
			
		||||
    output_hidden_states = (
 | 
			
		||||
        output_hidden_states if output_hidden_states is not None else
 | 
			
		||||
        self.config.output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
 | 
			
		||||
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
    # retrieve input_ids and inputs_embeds
 | 
			
		||||
    if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          "You cannot specify both input_ids and inputs_embeds at the same time")
 | 
			
		||||
    elif input_ids is not None:
 | 
			
		||||
        batch_size, seq_length = input_ids.shape
 | 
			
		||||
    elif inputs_embeds is not None:
 | 
			
		||||
        batch_size, seq_length, _ = inputs_embeds.shape
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
 | 
			
		||||
 | 
			
		||||
    seq_length_with_past = seq_length
 | 
			
		||||
    past_key_values_length = 0
 | 
			
		||||
 | 
			
		||||
    if past_key_values is not None:
 | 
			
		||||
        past_key_values_length = past_key_values[0][0].shape[2]
 | 
			
		||||
        seq_length_with_past = seq_length_with_past + past_key_values_length
 | 
			
		||||
 | 
			
		||||
    if position_ids is None:
 | 
			
		||||
        device = input_ids.device if input_ids is not None else inputs_embeds.device
 | 
			
		||||
        position_ids = torch.arange(
 | 
			
		||||
            past_key_values_length, seq_length + past_key_values_length,
 | 
			
		||||
            dtype=torch.long, device=device
 | 
			
		||||
        )
 | 
			
		||||
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 | 
			
		||||
    else:
 | 
			
		||||
        position_ids = position_ids.view(-1, seq_length).long()
 | 
			
		||||
 | 
			
		||||
    if inputs_embeds is None:
 | 
			
		||||
        inputs_embeds = self.embed_tokens(input_ids)
 | 
			
		||||
    # embed positions
 | 
			
		||||
    if attention_mask is None:
 | 
			
		||||
        attention_mask = torch.ones(
 | 
			
		||||
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 | 
			
		||||
        )
 | 
			
		||||
        padding_mask = None
 | 
			
		||||
    else:
 | 
			
		||||
        if 0 in attention_mask:
 | 
			
		||||
            padding_mask = attention_mask
 | 
			
		||||
        else:
 | 
			
		||||
            padding_mask = None
 | 
			
		||||
 | 
			
		||||
    attention_mask = self._prepare_decoder_attention_mask(
 | 
			
		||||
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
    if self.gradient_checkpointing and self.training:
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            logger.warning_once(
 | 
			
		||||
                "`use_cache=True` is incompatible with gradient checkpointing."
 | 
			
		||||
                " Setting `use_cache=False`..."
 | 
			
		||||
            )
 | 
			
		||||
            use_cache = False
 | 
			
		||||
 | 
			
		||||
    # decoder layers
 | 
			
		||||
    all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
    all_self_attns = () if output_attentions else None
 | 
			
		||||
    next_decoder_cache = () if use_cache else None
 | 
			
		||||
 | 
			
		||||
    for idx, decoder_layer in enumerate(self.layers):
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        past_key_value = past_key_values[idx] if past_key_values is not None else None
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
 | 
			
		||||
            def create_custom_forward(module):
 | 
			
		||||
                def custom_forward(*inputs):
 | 
			
		||||
                    # None for past_key_value
 | 
			
		||||
                    return module(*inputs, past_key_value, output_attentions,
 | 
			
		||||
                                  padding_mask=padding_mask)
 | 
			
		||||
 | 
			
		||||
                return custom_forward
 | 
			
		||||
 | 
			
		||||
            layer_outputs = torch.utils.checkpoint.checkpoint(
 | 
			
		||||
                create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            # bigdl-llm changes:
 | 
			
		||||
            #
 | 
			
		||||
            # Avoid moving `attention_mask`` and `position_ids`` to other devices multiple times.
 | 
			
		||||
            #
 | 
			
		||||
            # When the model is partitioned on two different devices using
 | 
			
		||||
            # `accelerate`'s `dispatch``, a hook to move inputs to the correct device is
 | 
			
		||||
            # added to each layer's `forward``, which will result in moving `attention_mask`
 | 
			
		||||
            # and `position_ids`, which allocated on device:0, to other devices for each
 | 
			
		||||
            # decoder layer not in device:0.
 | 
			
		||||
            #
 | 
			
		||||
            # To avoid this, we move `attention_mask` and `position_ids` to the device of
 | 
			
		||||
            # the current layer before the forward call, so that the moving is only done once
 | 
			
		||||
            # for each devices other than devie:0.
 | 
			
		||||
            #
 | 
			
		||||
            curr_device = decoder_layer.input_layernorm.weight.device
 | 
			
		||||
            if attention_mask is not None:
 | 
			
		||||
                attention_mask = attention_mask.to(curr_device)
 | 
			
		||||
            if position_ids is not None:
 | 
			
		||||
                position_ids = position_ids.to(curr_device)
 | 
			
		||||
            # bigdl-llm changes end
 | 
			
		||||
            layer_outputs = decoder_layer(
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                attention_mask=attention_mask,
 | 
			
		||||
                position_ids=position_ids,
 | 
			
		||||
                past_key_value=past_key_value,
 | 
			
		||||
                output_attentions=output_attentions,
 | 
			
		||||
                use_cache=use_cache,
 | 
			
		||||
                padding_mask=padding_mask,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hidden_states = layer_outputs[0]
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
 | 
			
		||||
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            all_self_attns += (layer_outputs[1],)
 | 
			
		||||
 | 
			
		||||
    hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
    # add hidden states from the last decoder layer
 | 
			
		||||
    if output_hidden_states:
 | 
			
		||||
        all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
    next_cache = next_decoder_cache if use_cache else None
 | 
			
		||||
    if not return_dict:
 | 
			
		||||
        return tuple(v for v in [hidden_states, next_cache,
 | 
			
		||||
                                 all_hidden_states, all_self_attns] if v is not None)
 | 
			
		||||
    return BaseModelOutputWithPast(
 | 
			
		||||
        last_hidden_state=hidden_states,
 | 
			
		||||
        past_key_values=next_cache,
 | 
			
		||||
        hidden_states=all_hidden_states,
 | 
			
		||||
        attentions=all_self_attns,
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,6 +45,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_
 | 
			
		|||
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
 | 
			
		||||
apply_rotary_emb_func = None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -544,3 +545,210 @@ def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		|||
            SILU, qtype
 | 
			
		||||
        ))
 | 
			
		||||
    return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def qwen_model_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
 | 
			
		||||
    attention_mask: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    token_type_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    head_mask: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    encoder_hidden_states: Optional[torch.Tensor] = None,
 | 
			
		||||
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    use_cache: Optional[bool] = None,
 | 
			
		||||
    output_attentions: Optional[bool] = None,
 | 
			
		||||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
):
 | 
			
		||||
    output_attentions = (
 | 
			
		||||
        output_attentions
 | 
			
		||||
        if output_attentions is not None
 | 
			
		||||
        else self.config.output_attentions
 | 
			
		||||
    )
 | 
			
		||||
    output_hidden_states = (
 | 
			
		||||
        output_hidden_states
 | 
			
		||||
        if output_hidden_states is not None
 | 
			
		||||
        else self.config.output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    return_dict = (
 | 
			
		||||
        return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
        invalidInputError(
 | 
			
		||||
            False,
 | 
			
		||||
            "You cannot specify both input_ids and inputs_embeds at the same time"
 | 
			
		||||
        )
 | 
			
		||||
    elif input_ids is not None:
 | 
			
		||||
        input_shape = input_ids.size()
 | 
			
		||||
        input_ids = input_ids.view(-1, input_shape[-1])
 | 
			
		||||
        batch_size = input_ids.shape[0]
 | 
			
		||||
    elif inputs_embeds is not None:
 | 
			
		||||
        input_shape = inputs_embeds.size()[:-1]
 | 
			
		||||
        batch_size = inputs_embeds.shape[0]
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
 | 
			
		||||
 | 
			
		||||
    device = input_ids.device if input_ids is not None else inputs_embeds.device
 | 
			
		||||
 | 
			
		||||
    if token_type_ids is not None:
 | 
			
		||||
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
 | 
			
		||||
    if position_ids is not None:
 | 
			
		||||
        position_ids = position_ids.view(-1, input_shape[-1])
 | 
			
		||||
 | 
			
		||||
    if past_key_values is None:
 | 
			
		||||
        past_length = 0
 | 
			
		||||
        past_key_values = tuple([None] * len(self.h))
 | 
			
		||||
    else:
 | 
			
		||||
        if self.use_cache_quantization:
 | 
			
		||||
            past_length = past_key_values[0][0][0].size(2)
 | 
			
		||||
        else:
 | 
			
		||||
            past_length = past_key_values[0][0].size(-2)
 | 
			
		||||
    if position_ids is None:
 | 
			
		||||
        position_ids = torch.arange(
 | 
			
		||||
            past_length,
 | 
			
		||||
            input_shape[-1] + past_length,
 | 
			
		||||
            dtype=torch.long,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
 | 
			
		||||
 | 
			
		||||
    if attention_mask is not None:
 | 
			
		||||
        if batch_size <= 0:
 | 
			
		||||
            invalidInputError(False, "batch_size has to be defined and > 0")
 | 
			
		||||
        attention_mask = attention_mask.view(batch_size, -1)
 | 
			
		||||
        attention_mask = attention_mask[:, None, None, :]
 | 
			
		||||
        attention_mask = attention_mask.to(dtype=self.dtype)
 | 
			
		||||
        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
 | 
			
		||||
 | 
			
		||||
    encoder_attention_mask = None
 | 
			
		||||
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 | 
			
		||||
 | 
			
		||||
    if inputs_embeds is None:
 | 
			
		||||
        inputs_embeds = self.wte(input_ids)
 | 
			
		||||
    hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = hidden_states.size()[1]
 | 
			
		||||
    if past_key_values[0] is not None:
 | 
			
		||||
        # past key values[0][0] shape: bs * seq_len * head_num * dim
 | 
			
		||||
        if self.use_cache_quantization:
 | 
			
		||||
            kv_seq_len += past_key_values[0][0][0].shape[2]
 | 
			
		||||
        else:
 | 
			
		||||
            kv_seq_len += past_key_values[0][0].shape[1]
 | 
			
		||||
 | 
			
		||||
    if self.training or not self.use_dynamic_ntk:
 | 
			
		||||
        ntk_alpha_list = [1.0]
 | 
			
		||||
    elif kv_seq_len != hidden_states.size()[1]:
 | 
			
		||||
        ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
 | 
			
		||||
    else:
 | 
			
		||||
        ntk_alpha_list = []
 | 
			
		||||
        if attention_mask is not None and kv_seq_len > self.seq_length:
 | 
			
		||||
            true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1,
 | 
			
		||||
                                                                           dtype=torch.int32)
 | 
			
		||||
            for i in range(hidden_states.size()[0]):
 | 
			
		||||
                true_seq_len = true_seq_lens[i].item()
 | 
			
		||||
                ntk_alpha = self.get_ntk_alpha(true_seq_len)
 | 
			
		||||
                ntk_alpha_list.append(ntk_alpha)
 | 
			
		||||
        else:
 | 
			
		||||
            ntk_alpha = self.get_ntk_alpha(kv_seq_len)
 | 
			
		||||
            ntk_alpha_list.append(ntk_alpha)
 | 
			
		||||
    self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
 | 
			
		||||
    rotary_pos_emb_list = [
 | 
			
		||||
        self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    hidden_states = self.drop(hidden_states)
 | 
			
		||||
    output_shape = input_shape + (hidden_states.size(-1),)
 | 
			
		||||
 | 
			
		||||
    if self.gradient_checkpointing and self.training:
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            logger.warning_once(
 | 
			
		||||
                "`use_cache=True` is incompatible with gradient checkpointing. "
 | 
			
		||||
                "Setting `use_cache=False`..."
 | 
			
		||||
            )
 | 
			
		||||
            use_cache = False
 | 
			
		||||
 | 
			
		||||
    presents = () if use_cache else None
 | 
			
		||||
    all_self_attentions = () if output_attentions else None
 | 
			
		||||
    all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 | 
			
		||||
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states = all_hidden_states + (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
 | 
			
		||||
            def create_custom_forward(module):
 | 
			
		||||
                def custom_forward(*inputs):
 | 
			
		||||
                    # None for past_key_value
 | 
			
		||||
                    return module(*inputs, use_cache, output_attentions)
 | 
			
		||||
 | 
			
		||||
                return custom_forward
 | 
			
		||||
 | 
			
		||||
            outputs = torch.utils.checkpoint.checkpoint(
 | 
			
		||||
                create_custom_forward(block),
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                rotary_pos_emb_list,
 | 
			
		||||
                None,
 | 
			
		||||
                attention_mask,
 | 
			
		||||
                head_mask[i],
 | 
			
		||||
                encoder_hidden_states,
 | 
			
		||||
                encoder_attention_mask,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            # bigdl-llm changes
 | 
			
		||||
            curr_device = block.ln_1.weight.device
 | 
			
		||||
            from accelerate.utils.operations import send_to_device
 | 
			
		||||
            if rotary_pos_emb_list is not None:
 | 
			
		||||
                rotary_pos_emb_list = send_to_device(rotary_pos_emb_list, curr_device)
 | 
			
		||||
            if attention_mask is not None:
 | 
			
		||||
                attention_mask = send_to_device(attention_mask, curr_device)
 | 
			
		||||
            if head_mask[i] is not None:
 | 
			
		||||
                head_mask[i] = send_to_device(head_mask[i], curr_device)
 | 
			
		||||
            if encoder_hidden_states is not None:
 | 
			
		||||
                encoder_hidden_states = send_to_device(encoder_hidden_states, curr_device)
 | 
			
		||||
            if encoder_attention_mask is not None:
 | 
			
		||||
                encoder_attention_mask = send_to_device(encoder_attention_mask,
 | 
			
		||||
                                                        curr_device)
 | 
			
		||||
            # bigdl-llm changes ends
 | 
			
		||||
 | 
			
		||||
            outputs = block(
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                layer_past=layer_past,
 | 
			
		||||
                rotary_pos_emb_list=rotary_pos_emb_list,
 | 
			
		||||
                attention_mask=attention_mask,
 | 
			
		||||
                head_mask=head_mask[i],
 | 
			
		||||
                encoder_hidden_states=encoder_hidden_states,
 | 
			
		||||
                encoder_attention_mask=encoder_attention_mask,
 | 
			
		||||
                use_cache=use_cache,
 | 
			
		||||
                output_attentions=output_attentions,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hidden_states = outputs[0]
 | 
			
		||||
        if use_cache is True:
 | 
			
		||||
            presents = presents + (outputs[1],)
 | 
			
		||||
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
 | 
			
		||||
 | 
			
		||||
    hidden_states = self.ln_f(hidden_states)
 | 
			
		||||
    hidden_states = hidden_states.view(output_shape)
 | 
			
		||||
    # Add last hidden state
 | 
			
		||||
    if output_hidden_states:
 | 
			
		||||
        all_hidden_states = all_hidden_states + (hidden_states,)
 | 
			
		||||
 | 
			
		||||
    if not return_dict:
 | 
			
		||||
        return tuple(
 | 
			
		||||
            v for v in [hidden_states, presents, all_hidden_states] if v is not None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return BaseModelOutputWithPast(
 | 
			
		||||
        last_hidden_state=hidden_states,
 | 
			
		||||
        past_key_values=presents,
 | 
			
		||||
        hidden_states=all_hidden_states,
 | 
			
		||||
        attentions=all_self_attentions,
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -54,7 +54,19 @@ from bigdl.llm.transformers.kv import DynamicFp8Cache
 | 
			
		|||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 | 
			
		||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
 | 
			
		||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 | 
			
		||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from transformers.cache_utils import Cache, DynamicCache
 | 
			
		||||
except ImportError:
 | 
			
		||||
    Cache = Tuple[torch.Tensor]
 | 
			
		||||
import logging
 | 
			
		||||
from transformers import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -82,7 +94,7 @@ def qwen2_model_forward(
 | 
			
		|||
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
 | 
			
		||||
        if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
    return Qwen2Model.forward(
 | 
			
		||||
    return qwen2_model_forward_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
        attention_mask=attention_mask,
 | 
			
		||||
| 
						 | 
				
			
			@ -96,6 +108,173 @@ def qwen2_model_forward(
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def qwen2_model_forward_internal(
 | 
			
		||||
    self,
 | 
			
		||||
    input_ids: torch.LongTensor = None,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    use_cache: Optional[bool] = None,
 | 
			
		||||
    output_attentions: Optional[bool] = None,
 | 
			
		||||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    output_attentions = output_attentions if output_attentions is not None else \
 | 
			
		||||
        self.config.output_attentions
 | 
			
		||||
    output_hidden_states = (
 | 
			
		||||
        output_hidden_states if output_hidden_states is not None else
 | 
			
		||||
        self.config.output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
 | 
			
		||||
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
    # retrieve input_ids and inputs_embeds
 | 
			
		||||
    if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          "You cannot specify both decoder_input_ids and "
 | 
			
		||||
                          "decoder_inputs_embeds at the same time")
 | 
			
		||||
    elif input_ids is not None:
 | 
			
		||||
        batch_size, seq_length = input_ids.shape
 | 
			
		||||
    elif inputs_embeds is not None:
 | 
			
		||||
        batch_size, seq_length, _ = inputs_embeds.shape
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 | 
			
		||||
 | 
			
		||||
    if self.gradient_checkpointing and self.training:
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            logger.warning_once(
 | 
			
		||||
                "`use_cache=True` is incompatible with gradient checkpointing. "
 | 
			
		||||
                "Setting `use_cache=False`..."
 | 
			
		||||
            )
 | 
			
		||||
            use_cache = False
 | 
			
		||||
 | 
			
		||||
    past_key_values_length = 0
 | 
			
		||||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        use_legacy_cache = not isinstance(past_key_values, Cache)
 | 
			
		||||
        if use_legacy_cache:
 | 
			
		||||
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 | 
			
		||||
        past_key_values_length = past_key_values.get_usable_length(seq_length)
 | 
			
		||||
 | 
			
		||||
    if position_ids is None:
 | 
			
		||||
        device = input_ids.device if input_ids is not None else inputs_embeds.device
 | 
			
		||||
        position_ids = torch.arange(
 | 
			
		||||
            past_key_values_length, seq_length + past_key_values_length,
 | 
			
		||||
            dtype=torch.long, device=device
 | 
			
		||||
        )
 | 
			
		||||
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 | 
			
		||||
    else:
 | 
			
		||||
        position_ids = position_ids.view(-1, seq_length).long()
 | 
			
		||||
 | 
			
		||||
    if inputs_embeds is None:
 | 
			
		||||
        inputs_embeds = self.embed_tokens(input_ids)
 | 
			
		||||
 | 
			
		||||
    flash_attn_2 = self._attn_implementation == "flash_attention_2"
 | 
			
		||||
    if attention_mask is not None and flash_attn_2 and use_cache:
 | 
			
		||||
 | 
			
		||||
        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
 | 
			
		||||
        if is_padding_right:
 | 
			
		||||
            invalidInputError(
 | 
			
		||||
                False,
 | 
			
		||||
                "You are attempting to perform batched generation with padding_side='right'"
 | 
			
		||||
                " this may lead to unexpected behaviour for Flash Attention version of Qwen2."
 | 
			
		||||
                " Make sure to  call `tokenizer.padding_side  = 'left'` before tokenizing "
 | 
			
		||||
                "the input. "
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    if self._attn_implementation == "flash_attention_2":
 | 
			
		||||
        # 2d mask is passed through the layers
 | 
			
		||||
        attention_mask = attention_mask if (attention_mask is not None and
 | 
			
		||||
                                            0 in attention_mask) else None
 | 
			
		||||
    elif self._attn_implementation == "sdpa" and not output_attentions:
 | 
			
		||||
        # output_attentions=True can not be supported when using SDPA, and we fall back on
 | 
			
		||||
        # the manual implementation that requires a 4D causal mask in all cases.
 | 
			
		||||
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
 | 
			
		||||
            attention_mask,
 | 
			
		||||
            (batch_size, seq_length),
 | 
			
		||||
            inputs_embeds,
 | 
			
		||||
            past_key_values_length,
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        # 4d mask is passed through the layers
 | 
			
		||||
        attention_mask = _prepare_4d_causal_attention_mask(
 | 
			
		||||
            attention_mask,
 | 
			
		||||
            (batch_size, seq_length),
 | 
			
		||||
            inputs_embeds,
 | 
			
		||||
            past_key_values_length,
 | 
			
		||||
            sliding_window=self.config.sliding_window,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
    # decoder layers
 | 
			
		||||
    all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
    all_self_attns = () if output_attentions else None
 | 
			
		||||
    next_decoder_cache = None
 | 
			
		||||
 | 
			
		||||
    for decoder_layer in self.layers:
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
            layer_outputs = self._gradient_checkpointing_func(
 | 
			
		||||
                decoder_layer.__call__,
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                attention_mask,
 | 
			
		||||
                position_ids,
 | 
			
		||||
                past_key_values,
 | 
			
		||||
                output_attentions,
 | 
			
		||||
                use_cache,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            # bigdl-llm changes
 | 
			
		||||
            curr_device = decoder_layer.input_layernorm.weight.device
 | 
			
		||||
            if attention_mask is not None:
 | 
			
		||||
                attention_mask = attention_mask.to(curr_device)
 | 
			
		||||
            if position_ids is not None:
 | 
			
		||||
                position_ids = position_ids.to(curr_device)
 | 
			
		||||
            # bigdl-llm changes end
 | 
			
		||||
            layer_outputs = decoder_layer(
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                attention_mask=attention_mask,
 | 
			
		||||
                position_ids=position_ids,
 | 
			
		||||
                past_key_value=past_key_values,
 | 
			
		||||
                output_attentions=output_attentions,
 | 
			
		||||
                use_cache=use_cache,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hidden_states = layer_outputs[0]
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
			
		||||
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            all_self_attns += (layer_outputs[1],)
 | 
			
		||||
 | 
			
		||||
    hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
    # add hidden states from the last decoder layer
 | 
			
		||||
    if output_hidden_states:
 | 
			
		||||
        all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
    next_cache = None
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \
 | 
			
		||||
            next_decoder_cache
 | 
			
		||||
 | 
			
		||||
    if not return_dict:
 | 
			
		||||
        return tuple(v for v in [hidden_states, next_cache,
 | 
			
		||||
                                 all_hidden_states, all_self_attns] if v is not None)
 | 
			
		||||
    return BaseModelOutputWithPast(
 | 
			
		||||
        last_hidden_state=hidden_states,
 | 
			
		||||
        past_key_values=next_cache,
 | 
			
		||||
        hidden_states=all_hidden_states,
 | 
			
		||||
        attentions=all_self_attns,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def qwen2_attention_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue