Support running pipeline parallel inference by vertically partitioning model to different devices (#10392)
* support pipeline parallel inference * fix logging * remove benchmark file * fic * need to warmup twice * support qwen and qwen2 * fix lint * remove genxir * refine
This commit is contained in:
parent
66b4bb5c5d
commit
9e763b049c
6 changed files with 907 additions and 3 deletions
78
python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
Normal file
78
python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
# Run BigDL-LLM on Multiple Intel GPUs in pipeline parallel fashion
|
||||||
|
|
||||||
|
This example demonstrates how to run BigDL-LLM optimized low-bit model vertically partitioned on two [Intel GPUs](../README.md).
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
|
||||||
|
|
||||||
|
## Example:
|
||||||
|
|
||||||
|
### 1.1 Install BigDL-LLM
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -n llm python=3.9
|
||||||
|
conda activate llm
|
||||||
|
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||||
|
# you can install specific ipex/torch version for your need
|
||||||
|
pip install --pre --upgrade bigdl-llm[xpu_2.1] -f https://developer.intel.com/ipex-whl-stable-xpu
|
||||||
|
# configures OneAPI environment variables
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
|
||||||
|
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.2 Build and install patched version of Intel Extension for PyTorch (IPEX)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda activate llm
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
git clone https://github.com/intel/intel-extension-for-pytorch.git
|
||||||
|
cd intel-extension-for-pytorch
|
||||||
|
git checkout v2.1.10+xpu
|
||||||
|
git submodule update --init --recursive
|
||||||
|
git cherry-pick be8ea24078d8a271e53d2946ac533383f7a2aa78
|
||||||
|
export USE_AOT_DEVLIST='ats-m150,pvc'
|
||||||
|
python setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.
|
||||||
|
|
||||||
|
### 2. Run tensor parallel inference on multiple GPUs
|
||||||
|
Here, we provide example usages on different models and different hardwares. Please refer to the appropriate script based on your model and device:
|
||||||
|
|
||||||
|
### 3. Run
|
||||||
|
|
||||||
|
For optimal performance on Arc, it is recommended to set several environment variables.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export USE_XETLA=OFF
|
||||||
|
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
|
||||||
|
```
|
||||||
|
|
||||||
|
Arguments info:
|
||||||
|
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`.
|
||||||
|
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
|
||||||
|
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
|
||||||
|
|
||||||
|
#### Sample Output
|
||||||
|
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||||
|
```log
|
||||||
|
Inference time: xxxx s
|
||||||
|
-------------------- Prompt --------------------
|
||||||
|
<s>[INST] <<SYS>>
|
||||||
|
|
||||||
|
<</SYS>>
|
||||||
|
|
||||||
|
What is AI? [/INST]
|
||||||
|
-------------------- Output --------------------
|
||||||
|
[INST] <<SYS>>
|
||||||
|
|
||||||
|
<</SYS>>
|
||||||
|
|
||||||
|
What is AI? [/INST] Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence,
|
||||||
|
```
|
||||||
116
python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
Normal file
116
python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||||
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
|
# you could tune the prompt based on your own model,
|
||||||
|
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
|
||||||
|
DEFAULT_SYSTEM_PROMPT = """\
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_prompt(message: str, chat_history: list[tuple[str, str]],
|
||||||
|
system_prompt: str) -> str:
|
||||||
|
texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
|
||||||
|
# The first user input is _not_ stripped
|
||||||
|
do_strip = False
|
||||||
|
for user_input, response in chat_history:
|
||||||
|
user_input = user_input.strip() if do_strip else user_input
|
||||||
|
do_strip = True
|
||||||
|
texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
|
||||||
|
message = message.strip() if do_strip else message
|
||||||
|
texts.append(f'{message} [/INST]')
|
||||||
|
return ''.join(texts)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
|
||||||
|
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
|
||||||
|
', or the path to the huggingface checkpoint folder')
|
||||||
|
parser.add_argument('--prompt', type=str, default="What is AI?",
|
||||||
|
help='Prompt to infer')
|
||||||
|
parser.add_argument('--n-predict', type=int, default=32,
|
||||||
|
help='Max tokens to predict')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
model_path = args.repo_id_or_model_path
|
||||||
|
|
||||||
|
# Load model in 4 bit,
|
||||||
|
# which convert the relevant layers in the model into INT4 format
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||||
|
load_in_4bit=True,
|
||||||
|
optimize_model=True,
|
||||||
|
trust_remote_code=True,
|
||||||
|
use_cache=True)
|
||||||
|
first_half = ['model.embed_tokens', 'model.layers.0', 'model.layers.1', 'model.layers.2',
|
||||||
|
'model.layers.3', 'model.layers.4', 'model.layers.5', 'model.layers.6',
|
||||||
|
'model.layers.7', 'model.layers.8', 'model.layers.9', 'model.layers.10',
|
||||||
|
'model.layers.11', 'model.layers.12', 'model.layers.13', 'model.layers.14',
|
||||||
|
'model.layers.15']
|
||||||
|
second_half = ['model.layers.16', 'model.layers.17', 'model.layers.18', 'model.layers.19',
|
||||||
|
'model.layers.20', 'model.layers.21', 'model.layers.22', 'model.layers.23',
|
||||||
|
'model.layers.24', 'model.layers.25', 'model.layers.26', 'model.layers.27',
|
||||||
|
'model.layers.28', 'model.layers.29', 'model.layers.30', 'model.layers.31',
|
||||||
|
'model.norm', 'lm_head']
|
||||||
|
|
||||||
|
device_map=({key: 'xpu:0' for key in first_half})
|
||||||
|
device_map.update({key: 'xpu:1' for key in second_half})
|
||||||
|
from accelerate import dispatch_model
|
||||||
|
model = dispatch_model(
|
||||||
|
model,
|
||||||
|
device_map=device_map,
|
||||||
|
offload_dir=None,
|
||||||
|
skip_keys=["past_key_value", "past_key_values"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load tokenizer
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
# Generate predicted tokens
|
||||||
|
with torch.inference_mode():
|
||||||
|
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
|
||||||
|
input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0')
|
||||||
|
# ipex model needs a warmup, then inference time can be accurate
|
||||||
|
output = model.generate(input_ids,
|
||||||
|
max_new_tokens=args.n_predict)
|
||||||
|
output = model.generate(input_ids,
|
||||||
|
max_new_tokens=args.n_predict)
|
||||||
|
|
||||||
|
# start inference
|
||||||
|
st = time.time()
|
||||||
|
# if your selected model is capable of utilizing previous key/value attentions
|
||||||
|
# to enhance decoding speed, but has `"use_cache": false` in its model config,
|
||||||
|
# it is important to set `use_cache=True` explicitly in the `generate` function
|
||||||
|
# to obtain optimal performance with BigDL-LLM INT4 optimizations
|
||||||
|
output = model.generate(input_ids,
|
||||||
|
max_new_tokens=args.n_predict)
|
||||||
|
torch.xpu.synchronize()
|
||||||
|
end = time.time()
|
||||||
|
output = output.cpu()
|
||||||
|
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
|
print(f'Inference time: {end-st} s')
|
||||||
|
print('-'*20, 'Prompt', '-'*20)
|
||||||
|
print(prompt)
|
||||||
|
print('-'*20, 'Output', '-'*20)
|
||||||
|
print(output_str)
|
||||||
|
|
||||||
|
|
@ -770,6 +770,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
|
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
|
||||||
from bigdl.llm.transformers.models.llama import llama_mlp_forward
|
from bigdl.llm.transformers.models.llama import llama_mlp_forward
|
||||||
from bigdl.llm.transformers.models.llama import llama_decoder_forward
|
from bigdl.llm.transformers.models.llama import llama_decoder_forward
|
||||||
|
from bigdl.llm.transformers.models.llama import llama_model_forward
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
# All huggingface format models are inherited from `PreTrainedModel`
|
# All huggingface format models are inherited from `PreTrainedModel`
|
||||||
|
|
@ -823,6 +824,11 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention,
|
transformers.models.llama.modeling_llama.LlamaAttention,
|
||||||
llama_attention_selective_batching_forward_4_31,
|
llama_attention_selective_batching_forward_4_31,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
convert_forward(
|
||||||
|
model,
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel,
|
||||||
|
llama_model_forward)
|
||||||
else:
|
else:
|
||||||
# todo implement 4.28.0 ~ 4.30.2
|
# todo implement 4.28.0 ~ 4.30.2
|
||||||
pass
|
pass
|
||||||
|
|
@ -1058,6 +1064,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from bigdl.llm.transformers.models.qwen import qwen_attention_forward
|
from bigdl.llm.transformers.models.qwen import qwen_attention_forward
|
||||||
from bigdl.llm.transformers.models.qwen import qwen_mlp_forward
|
from bigdl.llm.transformers.models.qwen import qwen_mlp_forward
|
||||||
from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
|
from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
|
||||||
|
from bigdl.llm.transformers.models.qwen import qwen_model_forward
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.QWenAttention,
|
module.QWenAttention,
|
||||||
qwen_attention_forward
|
qwen_attention_forward
|
||||||
|
|
@ -1068,6 +1075,9 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.QWenMLP,
|
module.QWenMLP,
|
||||||
qwen_mlp_forward)
|
qwen_mlp_forward)
|
||||||
|
convert_forward(model,
|
||||||
|
module.QWenModel,
|
||||||
|
qwen_model_forward)
|
||||||
elif model.config.model_type == "qwen2":
|
elif model.config.model_type == "qwen2":
|
||||||
# for Qwen1.5-7B
|
# for Qwen1.5-7B
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
|
|
|
||||||
|
|
@ -53,10 +53,15 @@ from transformers.models.llama.modeling_llama import LlamaModel
|
||||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
|
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
except ImportError:
|
except ImportError:
|
||||||
Cache = Tuple[torch.Tensor]
|
Cache = Tuple[torch.Tensor]
|
||||||
|
from transformers import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
|
@ -106,7 +111,7 @@ def llama_model_forward_4_36(
|
||||||
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
|
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
|
||||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
return LlamaModel.forward(
|
return llama_model_forward_4_36_internal(
|
||||||
self=self,
|
self=self,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
|
@ -1605,3 +1610,311 @@ def llama_attention_fast_forward(
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def llama_model_forward_4_36_internal(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else \
|
||||||
|
self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else
|
||||||
|
self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
invalidInputError(False,
|
||||||
|
"You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape[:2]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||||
|
else:
|
||||||
|
invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
past_key_values_length = 0
|
||||||
|
if use_cache:
|
||||||
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if self._use_flash_attention_2:
|
||||||
|
# 2d mask is passed through the layers
|
||||||
|
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \
|
||||||
|
else None
|
||||||
|
elif self._use_sdpa and not output_attentions:
|
||||||
|
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
from transformers.models.llama.modeling_llama import \
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 4d mask is passed through the layers
|
||||||
|
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing."
|
||||||
|
" Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# bigdl-llm changes:
|
||||||
|
curr_device = decoder_layer.input_layernorm.weight.device
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(curr_device)
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = position_ids.to(curr_device)
|
||||||
|
# bigdl-llm changes end
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache \
|
||||||
|
else next_decoder_cache
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache,
|
||||||
|
all_hidden_states, all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def llama_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None \
|
||||||
|
else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else
|
||||||
|
self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
invalidInputError(False,
|
||||||
|
"You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
# embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
padding_mask = None
|
||||||
|
else:
|
||||||
|
if 0 in attention_mask:
|
||||||
|
padding_mask = attention_mask
|
||||||
|
else:
|
||||||
|
padding_mask = None
|
||||||
|
|
||||||
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
|
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing."
|
||||||
|
" Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, past_key_value, output_attentions,
|
||||||
|
padding_mask=padding_mask)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# bigdl-llm changes:
|
||||||
|
#
|
||||||
|
# Avoid moving `attention_mask`` and `position_ids`` to other devices multiple times.
|
||||||
|
#
|
||||||
|
# When the model is partitioned on two different devices using
|
||||||
|
# `accelerate`'s `dispatch``, a hook to move inputs to the correct device is
|
||||||
|
# added to each layer's `forward``, which will result in moving `attention_mask`
|
||||||
|
# and `position_ids`, which allocated on device:0, to other devices for each
|
||||||
|
# decoder layer not in device:0.
|
||||||
|
#
|
||||||
|
# To avoid this, we move `attention_mask` and `position_ids` to the device of
|
||||||
|
# the current layer before the forward call, so that the moving is only done once
|
||||||
|
# for each devices other than devie:0.
|
||||||
|
#
|
||||||
|
curr_device = decoder_layer.input_layernorm.weight.device
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(curr_device)
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = position_ids.to(curr_device)
|
||||||
|
# bigdl-llm changes end
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache,
|
||||||
|
all_hidden_states, all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_
|
||||||
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
||||||
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
|
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
|
||||||
apply_rotary_emb_func = None
|
apply_rotary_emb_func = None
|
||||||
|
|
||||||
|
|
@ -544,3 +545,210 @@ def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
SILU, qtype
|
SILU, qtype
|
||||||
))
|
))
|
||||||
return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
|
return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
|
||||||
|
|
||||||
|
|
||||||
|
def qwen_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
"You cannot specify both input_ids and inputs_embeds at the same time"
|
||||||
|
)
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
else:
|
||||||
|
invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = position_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
if past_key_values is None:
|
||||||
|
past_length = 0
|
||||||
|
past_key_values = tuple([None] * len(self.h))
|
||||||
|
else:
|
||||||
|
if self.use_cache_quantization:
|
||||||
|
past_length = past_key_values[0][0][0].size(2)
|
||||||
|
else:
|
||||||
|
past_length = past_key_values[0][0].size(-2)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_length,
|
||||||
|
input_shape[-1] + past_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if batch_size <= 0:
|
||||||
|
invalidInputError(False, "batch_size has to be defined and > 0")
|
||||||
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
attention_mask = attention_mask.to(dtype=self.dtype)
|
||||||
|
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||||
|
|
||||||
|
encoder_attention_mask = None
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
kv_seq_len = hidden_states.size()[1]
|
||||||
|
if past_key_values[0] is not None:
|
||||||
|
# past key values[0][0] shape: bs * seq_len * head_num * dim
|
||||||
|
if self.use_cache_quantization:
|
||||||
|
kv_seq_len += past_key_values[0][0][0].shape[2]
|
||||||
|
else:
|
||||||
|
kv_seq_len += past_key_values[0][0].shape[1]
|
||||||
|
|
||||||
|
if self.training or not self.use_dynamic_ntk:
|
||||||
|
ntk_alpha_list = [1.0]
|
||||||
|
elif kv_seq_len != hidden_states.size()[1]:
|
||||||
|
ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
|
||||||
|
else:
|
||||||
|
ntk_alpha_list = []
|
||||||
|
if attention_mask is not None and kv_seq_len > self.seq_length:
|
||||||
|
true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1,
|
||||||
|
dtype=torch.int32)
|
||||||
|
for i in range(hidden_states.size()[0]):
|
||||||
|
true_seq_len = true_seq_lens[i].item()
|
||||||
|
ntk_alpha = self.get_ntk_alpha(true_seq_len)
|
||||||
|
ntk_alpha_list.append(ntk_alpha)
|
||||||
|
else:
|
||||||
|
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
|
||||||
|
ntk_alpha_list.append(ntk_alpha)
|
||||||
|
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
||||||
|
rotary_pos_emb_list = [
|
||||||
|
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
|
||||||
|
]
|
||||||
|
|
||||||
|
hidden_states = self.drop(hidden_states)
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. "
|
||||||
|
"Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
presents = () if use_cache else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, use_cache, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states,
|
||||||
|
rotary_pos_emb_list,
|
||||||
|
None,
|
||||||
|
attention_mask,
|
||||||
|
head_mask[i],
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# bigdl-llm changes
|
||||||
|
curr_device = block.ln_1.weight.device
|
||||||
|
from accelerate.utils.operations import send_to_device
|
||||||
|
if rotary_pos_emb_list is not None:
|
||||||
|
rotary_pos_emb_list = send_to_device(rotary_pos_emb_list, curr_device)
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = send_to_device(attention_mask, curr_device)
|
||||||
|
if head_mask[i] is not None:
|
||||||
|
head_mask[i] = send_to_device(head_mask[i], curr_device)
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
encoder_hidden_states = send_to_device(encoder_hidden_states, curr_device)
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = send_to_device(encoder_attention_mask,
|
||||||
|
curr_device)
|
||||||
|
# bigdl-llm changes ends
|
||||||
|
|
||||||
|
outputs = block(
|
||||||
|
hidden_states,
|
||||||
|
layer_past=layer_past,
|
||||||
|
rotary_pos_emb_list=rotary_pos_emb_list,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
if use_cache is True:
|
||||||
|
presents = presents + (outputs[1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
|
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
# Add last hidden state
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v for v in [hidden_states, presents, all_hidden_states] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=presents,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,19 @@ from bigdl.llm.transformers.kv import DynamicFp8Cache
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
||||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
|
except ImportError:
|
||||||
|
Cache = Tuple[torch.Tensor]
|
||||||
|
import logging
|
||||||
|
from transformers import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||||
|
|
||||||
|
|
@ -82,7 +94,7 @@ def qwen2_model_forward(
|
||||||
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
|
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
|
||||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
return Qwen2Model.forward(
|
return qwen2_model_forward_internal(
|
||||||
self=self,
|
self=self,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
|
@ -96,6 +108,173 @@ def qwen2_model_forward(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def qwen2_model_forward_internal(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else \
|
||||||
|
self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else
|
||||||
|
self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
invalidInputError(False,
|
||||||
|
"You cannot specify both decoder_input_ids and "
|
||||||
|
"decoder_inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
invalidInputError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. "
|
||||||
|
"Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
flash_attn_2 = self._attn_implementation == "flash_attention_2"
|
||||||
|
if attention_mask is not None and flash_attn_2 and use_cache:
|
||||||
|
|
||||||
|
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||||
|
if is_padding_right:
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
"You are attempting to perform batched generation with padding_side='right'"
|
||||||
|
" this may lead to unexpected behaviour for Flash Attention version of Qwen2."
|
||||||
|
" Make sure to call `tokenizer.padding_side = 'left'` before tokenizing "
|
||||||
|
"the input. "
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._attn_implementation == "flash_attention_2":
|
||||||
|
# 2d mask is passed through the layers
|
||||||
|
attention_mask = attention_mask if (attention_mask is not None and
|
||||||
|
0 in attention_mask) else None
|
||||||
|
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||||
|
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 4d mask is passed through the layers
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# bigdl-llm changes
|
||||||
|
curr_device = decoder_layer.input_layernorm.weight.device
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(curr_device)
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = position_ids.to(curr_device)
|
||||||
|
# bigdl-llm changes end
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \
|
||||||
|
next_decoder_cache
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache,
|
||||||
|
all_hidden_states, all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def qwen2_attention_forward(
|
def qwen2_attention_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue