Support running pipeline parallel inference by vertically partitioning the model across different devices (#10392)

* support pipeline parallel inference
* fix logging
* remove benchmark file
* fix
* need to warmup twice
* support qwen and qwen2
* fix lint
* remove genxir
* refine
parent 66b4bb5c5d
commit 9e763b049c
6 changed files with 907 additions and 3 deletions
python/llm/example/GPU/Pipeline-Parallel-Inference/README.md (new file, 78 lines)

@@ -0,0 +1,78 @@
# Run BigDL-LLM on Multiple Intel GPUs in pipeline parallel fashion

This example demonstrates how to run a BigDL-LLM optimized low-bit model, vertically partitioned across two [Intel GPUs](../README.md), in a pipeline parallel fashion.
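At a high level, the example loads the model with BigDL-LLM low-bit optimizations and then uses `accelerate`'s `dispatch_model` to place the first half of the decoder layers on `xpu:0` and the second half on `xpu:1`. The snippet below is a condensed sketch of what `generate.py` in this folder does (the layer names assume the 32-layer `meta-llama/Llama-2-7b-chat-hf`; other models need a different split):

```python
from bigdl.llm.transformers import AutoModelForCausalLM
from accelerate import dispatch_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             use_cache=True)

# embedding + first 16 decoder layers on the first GPU,
# remaining 16 layers + final norm + lm_head on the second GPU
device_map = {"model.embed_tokens": "xpu:0"}
device_map.update({f"model.layers.{i}": "xpu:0" for i in range(0, 16)})
device_map.update({f"model.layers.{i}": "xpu:1" for i in range(16, 32)})
device_map.update({"model.norm": "xpu:1", "lm_head": "xpu:1"})

model = dispatch_model(model,
                       device_map=device_map,
                       offload_dir=None,
                       skip_keys=["past_key_value", "past_key_values"])
```

The input token ids are then placed on `xpu:0` (the device that holds the embedding layer) before calling `generate`, exactly as `generate.py` does.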
## Requirements

To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
## Example

### 1.1 Install BigDL-LLM
```bash
conda create -n llm python=3.9
conda activate llm
# the command below installs intel_extension_for_pytorch==2.1.10+xpu by default
# you can install a specific ipex/torch version for your needs
pip install --pre --upgrade bigdl-llm[xpu_2.1] -f https://developer.intel.com/ipex-whl-stable-xpu
# configure oneAPI environment variables
source /opt/intel/oneapi/setvars.sh

conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
```
### 1.2 Build and install a patched version of Intel Extension for PyTorch (IPEX)

```bash
conda activate llm
source /opt/intel/oneapi/setvars.sh
git clone https://github.com/intel/intel-extension-for-pytorch.git
cd intel-extension-for-pytorch
git checkout v2.1.10+xpu
git submodule update --init --recursive
git cherry-pick be8ea24078d8a271e53d2946ac533383f7a2aa78
export USE_AOT_DEVLIST='ats-m150,pvc'
python setup.py install
```

> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit version 2024.0. Please make sure you have installed the correct version.
### 2. Run pipeline parallel inference on multiple GPUs

Here, we provide example usages for different models and different hardware. Please refer to the appropriate script based on your model and device:
### 3. Run

For optimal performance on Arc, it is recommended to set several environment variables:

```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
```bash
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
```
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the Hugging Face repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the Hugging Face checkpoint folder. It defaults to `'meta-llama/Llama-2-7b-chat-hf'`.
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
#### Sample Output
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
```log
Inference time: xxxx s
-------------------- Prompt --------------------
<s>[INST] <<SYS>>

<</SYS>>

What is AI? [/INST]
-------------------- Output --------------------
[INST] <<SYS>>

<</SYS>>

What is AI? [/INST] Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence,
```
python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py (new file, 116 lines)

@@ -0,0 +1,116 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import intel_extension_for_pytorch as ipex
import time
import argparse

from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
DEFAULT_SYSTEM_PROMPT = """\
"""


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    # Load the model in 4 bit,
    # which converts the relevant layers in the model into INT4 format
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True)

    # Vertically partition the model: the embedding and the first 16 decoder layers
    # go to the first GPU, the rest (plus the final norm and lm_head) to the second.
    first_half = ['model.embed_tokens', 'model.layers.0', 'model.layers.1', 'model.layers.2',
                  'model.layers.3', 'model.layers.4', 'model.layers.5', 'model.layers.6',
                  'model.layers.7', 'model.layers.8', 'model.layers.9', 'model.layers.10',
                  'model.layers.11', 'model.layers.12', 'model.layers.13', 'model.layers.14',
                  'model.layers.15']
    second_half = ['model.layers.16', 'model.layers.17', 'model.layers.18', 'model.layers.19',
                   'model.layers.20', 'model.layers.21', 'model.layers.22', 'model.layers.23',
                   'model.layers.24', 'model.layers.25', 'model.layers.26', 'model.layers.27',
                   'model.layers.28', 'model.layers.29', 'model.layers.30', 'model.layers.31',
                   'model.norm', 'lm_head']

    device_map = {key: 'xpu:0' for key in first_half}
    device_map.update({key: 'xpu:1' for key in second_half})
    from accelerate import dispatch_model
    model = dispatch_model(
        model,
        device_map=device_map,
        offload_dir=None,
        skip_keys=["past_key_value", "past_key_values"],
    )

    # Load tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0')
        # the ipex model needs a warmup (done twice here) before inference time
        # can be measured accurately
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)

        # start inference
        st = time.time()
        # if your selected model is capable of utilizing previous key/value attentions
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
        # it is important to set `use_cache=True` explicitly in the `generate` function
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()
        output = output.cpu()
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Prompt', '-'*20)
        print(prompt)
        print('-'*20, 'Output', '-'*20)
        print(output_str)
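The split in `generate.py` above is hard-coded for the 32 decoder layers of Llama-2-7B. For a model with a different layer count, an equivalent device map could be built programmatically; the helper below is an illustrative sketch (not part of this commit) that assumes the standard Hugging Face `config.num_hidden_layers` attribute and Llama-style module names:

```python
def build_two_gpu_device_map(model, first_device="xpu:0", second_device="xpu:1"):
    """Illustrative helper: split a Llama-style decoder evenly across two XPUs."""
    num_layers = model.config.num_hidden_layers
    split = num_layers // 2
    device_map = {"model.embed_tokens": first_device}
    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = first_device if i < split else second_device
    # keep the final norm and the output head with the second half of the layers
    device_map.update({"model.norm": second_device, "lm_head": second_device})
    return device_map
```

Such a map would then be passed to `dispatch_model` exactly as in `generate.py`.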

@@ -770,6 +770,7 @@ def _optimize_post(model, lightweight_bmm=False):
    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
    from bigdl.llm.transformers.models.llama import llama_mlp_forward
    from bigdl.llm.transformers.models.llama import llama_decoder_forward
    from bigdl.llm.transformers.models.llama import llama_model_forward
    from transformers.modeling_utils import PreTrainedModel

    # All huggingface format models are inherited from `PreTrainedModel`

@@ -823,6 +824,11 @@ def _optimize_post(model, lightweight_bmm=False):
                    transformers.models.llama.modeling_llama.LlamaAttention,
                    llama_attention_selective_batching_forward_4_31,
                )
            else:
                convert_forward(
                    model,
                    transformers.models.llama.modeling_llama.LlamaModel,
                    llama_model_forward)
    else:
        # todo implement 4.28.0 ~ 4.30.2
        pass

@@ -1058,6 +1064,7 @@ def _optimize_post(model, lightweight_bmm=False):
        from bigdl.llm.transformers.models.qwen import qwen_attention_forward
        from bigdl.llm.transformers.models.qwen import qwen_mlp_forward
        from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
        from bigdl.llm.transformers.models.qwen import qwen_model_forward
        convert_forward(model,
                        module.QWenAttention,
                        qwen_attention_forward

@@ -1068,6 +1075,9 @@ def _optimize_post(model, lightweight_bmm=False):
        convert_forward(model,
                        module.QWenMLP,
                        qwen_mlp_forward)
        convert_forward(model,
                        module.QWenModel,
                        qwen_model_forward)
    elif model.config.model_type == "qwen2":
        # for Qwen1.5-7B
        modeling_module_name = model.__class__.__module__

@@ -53,10 +53,15 @@ from transformers.models.llama.modeling_llama import LlamaModel
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError

try:
    from transformers.cache_utils import Cache
    from transformers.cache_utils import Cache, DynamicCache
except ImportError:
    Cache = Tuple[torch.Tensor]
from transformers import logging


logger = logging.get_logger(__name__)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -106,7 +111,7 @@ def llama_model_forward_4_36(
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
        if not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    return LlamaModel.forward(
    return llama_model_forward_4_36_internal(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,
|
|||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def llama_model_forward_4_36_internal(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else \
|
||||
self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else
|
||||
self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
invalidInputError(False,
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
else:
|
||||
invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
past_key_values_length = 0
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length,
|
||||
dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \
|
||||
else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
from transformers.models.llama.modeling_llama import \
|
||||
_prepare_4d_causal_attention_mask_for_sdpa
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing."
|
||||
" Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
# bigdl-llm changes:
|
||||
curr_device = decoder_layer.input_layernorm.weight.device
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(curr_device)
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.to(curr_device)
|
||||
# bigdl-llm changes end
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache \
|
||||
else next_decoder_cache
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache,
|
||||
all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def llama_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None \
|
||||
else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else
|
||||
self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
invalidInputError(False,
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length,
|
||||
dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing."
|
||||
" Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, past_key_value, output_attentions,
|
||||
padding_mask=padding_mask)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
|
||||
)
|
||||
else:
|
||||
# bigdl-llm changes:
|
||||
#
|
||||
# Avoid moving `attention_mask`` and `position_ids`` to other devices multiple times.
|
||||
#
|
||||
# When the model is partitioned on two different devices using
|
||||
# `accelerate`'s `dispatch``, a hook to move inputs to the correct device is
|
||||
# added to each layer's `forward``, which will result in moving `attention_mask`
|
||||
# and `position_ids`, which allocated on device:0, to other devices for each
|
||||
# decoder layer not in device:0.
|
||||
#
|
||||
# To avoid this, we move `attention_mask` and `position_ids` to the device of
|
||||
# the current layer before the forward call, so that the moving is only done once
|
||||
# for each devices other than devie:0.
|
||||
#
|
||||
curr_device = decoder_layer.input_layernorm.weight.device
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(curr_device)
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.to(curr_device)
|
||||
# bigdl-llm changes end
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache,
|
||||
all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
|
|
|||
|
|

@@ -45,6 +45,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from transformers.modeling_outputs import BaseModelOutputWithPast

apply_rotary_emb_func = None

@@ -544,3 +545,210 @@ def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
            SILU, qtype
        ))
    return self.c_proj(F.silu(self.w2(x)) * self.w1(x))


def qwen_model_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if input_ids is not None and inputs_embeds is not None:
        invalidInputError(
            False,
            "You cannot specify both input_ids and inputs_embeds at the same time"
        )
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        invalidInputError(False, "You have to specify either input_ids or inputs_embeds")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
    if position_ids is not None:
        position_ids = position_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = tuple([None] * len(self.h))
    else:
        if self.use_cache_quantization:
            past_length = past_key_values[0][0][0].size(2)
        else:
            past_length = past_key_values[0][0].size(-2)
    if position_ids is None:
        position_ids = torch.arange(
            past_length,
            input_shape[-1] + past_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

    if attention_mask is not None:
        if batch_size <= 0:
            invalidInputError(False, "batch_size has to be defined and > 0")
        attention_mask = attention_mask.view(batch_size, -1)
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = attention_mask.to(dtype=self.dtype)
        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

    encoder_attention_mask = None
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)
    hidden_states = inputs_embeds

    kv_seq_len = hidden_states.size()[1]
    if past_key_values[0] is not None:
        # past key values[0][0] shape: bs * seq_len * head_num * dim
        if self.use_cache_quantization:
            kv_seq_len += past_key_values[0][0][0].shape[2]
        else:
            kv_seq_len += past_key_values[0][0].shape[1]

    if self.training or not self.use_dynamic_ntk:
        ntk_alpha_list = [1.0]
    elif kv_seq_len != hidden_states.size()[1]:
        ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
    else:
        ntk_alpha_list = []
        if attention_mask is not None and kv_seq_len > self.seq_length:
            true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1,
                                                                           dtype=torch.int32)
            for i in range(hidden_states.size()[0]):
                true_seq_len = true_seq_lens[i].item()
                ntk_alpha = self.get_ntk_alpha(true_seq_len)
                ntk_alpha_list.append(ntk_alpha)
        else:
            ntk_alpha = self.get_ntk_alpha(kv_seq_len)
            ntk_alpha_list.append(ntk_alpha)
    self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
    rotary_pos_emb_list = [
        self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
    ]

    hidden_states = self.drop(hidden_states)
    output_shape = input_shape + (hidden_states.size(-1),)

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. "
                "Setting `use_cache=False`..."
            )
            use_cache = False

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, use_cache, output_attentions)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                rotary_pos_emb_list,
                None,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
            )
        else:
            # bigdl-llm changes
            curr_device = block.ln_1.weight.device
            from accelerate.utils.operations import send_to_device
            if rotary_pos_emb_list is not None:
                rotary_pos_emb_list = send_to_device(rotary_pos_emb_list, curr_device)
            if attention_mask is not None:
                attention_mask = send_to_device(attention_mask, curr_device)
            if head_mask[i] is not None:
                head_mask[i] = send_to_device(head_mask[i], curr_device)
            if encoder_hidden_states is not None:
                encoder_hidden_states = send_to_device(encoder_hidden_states, curr_device)
            if encoder_attention_mask is not None:
                encoder_attention_mask = send_to_device(encoder_attention_mask,
                                                        curr_device)
            # bigdl-llm changes ends

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                rotary_pos_emb_list=rotary_pos_emb_list,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        hidden_states = outputs[0]
        if use_cache is True:
            presents = presents + (outputs[1],)

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

    hidden_states = self.ln_f(hidden_states)
    hidden_states = hidden_states.view(output_shape)
    # Add last hidden state
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(
            v for v in [hidden_states, presents, all_hidden_states] if v is not None
        )

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )

@@ -54,7 +54,19 @@ from bigdl.llm.transformers.kv import DynamicFp8Cache
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast

try:
    from transformers.cache_utils import Cache, DynamicCache
except ImportError:
    Cache = Tuple[torch.Tensor]
import logging
from transformers import logging


logger = logging.get_logger(__name__)

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

@@ -82,7 +94,7 @@ def qwen2_model_forward(
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
        if not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    return Qwen2Model.forward(
    return qwen2_model_forward_internal(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,

@@ -96,6 +108,173 @@ def qwen2_model_forward(
    )


def qwen2_model_forward_internal(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else \
        self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else
        self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        invalidInputError(False,
                          "You cannot specify both decoder_input_ids and "
                          "decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        invalidInputError(False,
                          "You have to specify either decoder_input_ids or decoder_inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. "
                "Setting `use_cache=False`..."
            )
            use_cache = False

    past_key_values_length = 0

    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length,
            dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    flash_attn_2 = self._attn_implementation == "flash_attention_2"
    if attention_mask is not None and flash_attn_2 and use_cache:

        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
        if is_padding_right:
            invalidInputError(
                False,
                "You are attempting to perform batched generation with padding_side='right'"
                " this may lead to unexpected behaviour for Flash Attention version of Qwen2."
                " Make sure to call `tokenizer.padding_side = 'left'` before tokenizing "
                "the input. "
            )

    if self._attn_implementation == "flash_attention_2":
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and
                                            0 in attention_mask) else None
    elif self._attn_implementation == "sdpa" and not output_attentions:
        # output_attentions=True can not be supported when using SDPA, and we fall back on
        # the manual implementation that requires a 4D causal mask in all cases.
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
            sliding_window=self.config.sliding_window,
        )

    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            # bigdl-llm changes
            curr_device = decoder_layer.input_layernorm.weight.device
            if attention_mask is not None:
                attention_mask = attention_mask.to(curr_device)
            if position_ids is not None:
                position_ids = position_ids.to(curr_device)
            # bigdl-llm changes end
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \
            next_decoder_cache

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache,
                                 all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def qwen2_attention_forward(
    self,
    hidden_states: torch.Tensor,