diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
new file mode 100644
index 00000000..b19275cd
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -0,0 +1,78 @@
+# Run BigDL-LLM on Multiple Intel GPUs in pipeline parallel fashion
+
+This example demonstrates how to run a BigDL-LLM optimized low-bit model vertically partitioned across two [Intel GPUs](../README.md).
+
+## Requirements
+To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
+
+## Example:
+
+### 1.1 Install BigDL-LLM
+
+```bash
+conda create -n llm python=3.9
+conda activate llm
+# the command below installs intel_extension_for_pytorch==2.1.10+xpu by default
+# you can install a specific ipex/torch version for your needs
+pip install --pre --upgrade bigdl-llm[xpu_2.1] -f https://developer.intel.com/ipex-whl-stable-xpu
+# configure oneAPI environment variables
+source /opt/intel/oneapi/setvars.sh
+
+conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+```
+
+### 1.2 Build and install a patched version of Intel Extension for PyTorch (IPEX)
+
+```bash
+conda activate llm
+source /opt/intel/oneapi/setvars.sh
+git clone https://github.com/intel/intel-extension-for-pytorch.git
+cd intel-extension-for-pytorch
+git checkout v2.1.10+xpu
+git submodule update --init --recursive
+git cherry-pick be8ea24078d8a271e53d2946ac533383f7a2aa78
+export USE_AOT_DEVLIST='ats-m150,pvc'
+python setup.py install
+```
+
+
+> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit version 2024.0. Please make sure you have installed the correct version.
+
+### 2. Run pipeline parallel inference on multiple GPUs
+In this example, `generate.py` loads a low-bit Llama2 model and partitions it vertically across two Intel GPUs (`xpu:0` and `xpu:1`); a minimal sketch of the partitioning logic is shown after the argument list below.
+
+### 3. Run
+
+For optimal performance on Arc, it is recommended to set several environment variables.
+
+```bash
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+```
+
+```
+python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'meta-llama/Llama-2-7b-chat-hf'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
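+
+For reference, the sketch below condenses the partitioning logic from `generate.py` in this example: the model is loaded with BigDL-LLM INT4 optimizations, then `accelerate`'s `dispatch_model` places the embedding layer and the first 16 decoder layers on `xpu:0`, and the remaining layers together with the final norm and `lm_head` on `xpu:1`. It is illustrative only; the layer names assume the 32-layer Llama-2-7B architecture, so the split point would need to be adjusted for other model sizes (see `generate.py` for the full script).
+
+```python
+import intel_extension_for_pytorch as ipex  # required for Intel GPU (XPU) support
+from accelerate import dispatch_model
+from bigdl.llm.transformers import AutoModelForCausalLM
+
+# Load the model with BigDL-LLM INT4 optimizations
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
+                                             load_in_4bit=True,
+                                             optimize_model=True,
+                                             use_cache=True)
+
+# Vertical split: embedding + decoder layers 0-15 on xpu:0,
+# decoder layers 16-31 + final norm + lm_head on xpu:1
+device_map = {"model.embed_tokens": "xpu:0"}
+device_map.update({f"model.layers.{i}": "xpu:0" for i in range(16)})
+device_map.update({f"model.layers.{i}": "xpu:1" for i in range(16, 32)})
+device_map.update({"model.norm": "xpu:1", "lm_head": "xpu:1"})
+
+# dispatch_model adds hooks that move activations across devices at the split point
+model = dispatch_model(model,
+                       device_map=device_map,
+                       skip_keys=["past_key_value", "past_key_values"])
+```
+
+After dispatching, inputs only need to be placed on `xpu:0`; the hooks added by `accelerate` move the intermediate activations to `xpu:1` at the split point automatically.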
+
+#### Sample Output
+#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+[INST] <<SYS>>
+
+<</SYS>>
+
+What is AI? [/INST]
+-------------------- Output --------------------
+[INST] <<SYS>>
+
+<</SYS>>
+
+What is AI? [/INST] Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence,
+```
\ No newline at end of file
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
new file mode 100644
index 00000000..4247fdcc
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
@@ -0,0 +1,116 @@
+
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import intel_extension_for_pytorch as ipex
+import time
+import argparse
+
+from bigdl.llm.transformers import AutoModelForCausalLM
+from transformers import LlamaTokenizer
+
+# you could tune the prompt based on your own model,
+# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
+DEFAULT_SYSTEM_PROMPT = """\
+"""
+
+def get_prompt(message: str, chat_history: list[tuple[str, str]],
+               system_prompt: str) -> str:
+    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+    # The first user input is _not_ stripped
+    do_strip = False
+    for user_input, response in chat_history:
+        user_input = user_input.strip() if do_strip else user_input
+        do_strip = True
+        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+    message = message.strip() if do_strip else message
+    texts.append(f'{message} [/INST]')
+    return ''.join(texts)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+                        help='The huggingface repo id for the Llama2 (e.g. 
`meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + # Load model in 4 bit, + # which convert the relevant layers in the model into INT4 format + model = AutoModelForCausalLM.from_pretrained(model_path, + load_in_4bit=True, + optimize_model=True, + trust_remote_code=True, + use_cache=True) + first_half = ['model.embed_tokens', 'model.layers.0', 'model.layers.1', 'model.layers.2', + 'model.layers.3', 'model.layers.4', 'model.layers.5', 'model.layers.6', + 'model.layers.7', 'model.layers.8', 'model.layers.9', 'model.layers.10', + 'model.layers.11', 'model.layers.12', 'model.layers.13', 'model.layers.14', + 'model.layers.15'] + second_half = ['model.layers.16', 'model.layers.17', 'model.layers.18', 'model.layers.19', + 'model.layers.20', 'model.layers.21', 'model.layers.22', 'model.layers.23', + 'model.layers.24', 'model.layers.25', 'model.layers.26', 'model.layers.27', + 'model.layers.28', 'model.layers.29', 'model.layers.30', 'model.layers.31', + 'model.norm', 'lm_head'] + + device_map=({key: 'xpu:0' for key in first_half}) + device_map.update({key: 'xpu:1' for key in second_half}) + from accelerate import dispatch_model + model = dispatch_model( + model, + device_map=device_map, + offload_dir=None, + skip_keys=["past_key_value", "past_key_values"], + ) + + # Load tokenizer + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Generate predicted tokens + with torch.inference_mode(): + prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0') + # ipex model needs a warmup, then inference time can be accurate + output = model.generate(input_ids, + max_new_tokens=args.n_predict) + output = model.generate(input_ids, + max_new_tokens=args.n_predict) + + # start inference + st = time.time() + # if your selected model is capable of utilizing previous key/value attentions + # to enhance decoding speed, but has `"use_cache": false` in its model config, + # it is important to set `use_cache=True` explicitly in the `generate` function + # to obtain optimal performance with BigDL-LLM INT4 optimizations + output = model.generate(input_ids, + max_new_tokens=args.n_predict) + torch.xpu.synchronize() + end = time.time() + output = output.cpu() + output_str = tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Prompt', '-'*20) + print(prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) + diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index cf3710a2..83b96f88 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -770,6 +770,7 @@ def _optimize_post(model, lightweight_bmm=False): from bigdl.llm.transformers.models.llama import llama_rms_norm_forward from bigdl.llm.transformers.models.llama import llama_mlp_forward from bigdl.llm.transformers.models.llama import llama_decoder_forward + from bigdl.llm.transformers.models.llama import llama_model_forward from transformers.modeling_utils import PreTrainedModel # All huggingface format models are 
inherited from `PreTrainedModel` @@ -823,6 +824,11 @@ def _optimize_post(model, lightweight_bmm=False): transformers.models.llama.modeling_llama.LlamaAttention, llama_attention_selective_batching_forward_4_31, ) + else: + convert_forward( + model, + transformers.models.llama.modeling_llama.LlamaModel, + llama_model_forward) else: # todo implement 4.28.0 ~ 4.30.2 pass @@ -1058,6 +1064,7 @@ def _optimize_post(model, lightweight_bmm=False): from bigdl.llm.transformers.models.qwen import qwen_attention_forward from bigdl.llm.transformers.models.qwen import qwen_mlp_forward from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward + from bigdl.llm.transformers.models.qwen import qwen_model_forward convert_forward(model, module.QWenAttention, qwen_attention_forward @@ -1068,6 +1075,9 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.QWenMLP, qwen_mlp_forward) + convert_forward(model, + module.QWenModel, + qwen_model_forward) elif model.config.model_type == "qwen2": # for Qwen1.5-7B modeling_module_name = model.__class__.__module__ diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index b6c3e9cc..cf8c9896 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -53,10 +53,15 @@ from transformers.models.llama.modeling_llama import LlamaModel from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS from bigdl.llm.ggml.quantize import ggml_tensor_qtype from bigdl.llm.utils.common import invalidInputError + try: - from transformers.cache_utils import Cache + from transformers.cache_utils import Cache, DynamicCache except ImportError: Cache = Tuple[torch.Tensor] +from transformers import logging + + +logger = logging.get_logger(__name__) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -106,7 +111,7 @@ def llama_model_forward_4_36( if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids): if not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) - return LlamaModel.forward( + return llama_model_forward_4_36_internal( self=self, input_ids=input_ids, attention_mask=attention_mask, @@ -1605,3 +1610,311 @@ def llama_attention_fast_forward( attn_weights = None return attn_output, attn_weights, past_key_value + + +def llama_model_forward_4_36_internal( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else \ + self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + invalidInputError(False, + "You cannot specify both input_ids and 
inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + invalidInputError(False, "You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, + dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \ + else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + from transformers.models.llama.modeling_llama import \ + _prepare_4d_causal_attention_mask_for_sdpa + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing." + " Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + # bigdl-llm changes: + curr_device = decoder_layer.input_layernorm.weight.device + if attention_mask is not None: + attention_mask = attention_mask.to(curr_device) + if position_ids is not None: + position_ids = position_ids.to(curr_device) + # bigdl-llm changes end + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache \ + else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, + all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + invalidInputError(False, + "You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + invalidInputError(False, "You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + 
past_key_values_length, seq_length + past_key_values_length, + dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing." + " Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions, + padding_mask=padding_mask) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + ) + else: + # bigdl-llm changes: + # + # Avoid moving `attention_mask` and `position_ids` to other devices multiple times. + # + # When the model is partitioned on two different devices using + # `accelerate`'s `dispatch`, a hook to move inputs to the correct device is + # added to each layer's `forward`, which will result in moving `attention_mask` + # and `position_ids`, which are allocated on device:0, to other devices for each + # decoder layer not in device:0. + # + # To avoid this, we move `attention_mask` and `position_ids` to the device of + # the current layer before the forward call, so that the moving is only done once + # for each device other than device:0. 
+ # + curr_device = decoder_layer.input_layernorm.weight.device + if attention_mask is not None: + attention_mask = attention_mask.to(curr_device) + if position_ids is not None: + position_ids = position_ids.to(curr_device) + # bigdl-llm changes end + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, + all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py index a2ca1bbe..3d14048a 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py @@ -45,6 +45,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_ from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp from bigdl.llm.utils.common import invalidInputError, invalidOperationError from bigdl.llm.ggml.quantize import ggml_tensor_qtype +from transformers.modeling_outputs import BaseModelOutputWithPast apply_rotary_emb_func = None @@ -544,3 +545,210 @@ def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: SILU, qtype )) return self.c_proj(F.silu(self.w2(x)) * self.w1(x)) + + +def qwen_model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + invalidInputError( + False, + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + 
batch_size = inputs_embeds.shape[0] + else: + invalidInputError(False, "You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + if self.use_cache_quantization: + past_length = past_key_values[0][0][0].size(2) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if attention_mask is not None: + if batch_size <= 0: + invalidInputError(False, "batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + encoder_attention_mask = None + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + hidden_states = inputs_embeds + + kv_seq_len = hidden_states.size()[1] + if past_key_values[0] is not None: + # past key values[0][0] shape: bs * seq_len * head_num * dim + if self.use_cache_quantization: + kv_seq_len += past_key_values[0][0][0].shape[2] + else: + kv_seq_len += past_key_values[0][0].shape[1] + + if self.training or not self.use_dynamic_ntk: + ntk_alpha_list = [1.0] + elif kv_seq_len != hidden_states.size()[1]: + ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list + else: + ntk_alpha_list = [] + if attention_mask is not None and kv_seq_len > self.seq_length: + true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, + dtype=torch.int32) + for i in range(hidden_states.size()[0]): + true_seq_len = true_seq_lens[i].item() + ntk_alpha = self.get_ntk_alpha(true_seq_len) + ntk_alpha_list.append(ntk_alpha) + else: + ntk_alpha = self.get_ntk_alpha(kv_seq_len) + ntk_alpha_list.append(ntk_alpha) + self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list + rotary_pos_emb_list = [ + self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list + ] + + hidden_states = self.drop(hidden_states) + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. " + "Setting `use_cache=False`..." 
+ ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + rotary_pos_emb_list, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + # bigdl-llm changes + curr_device = block.ln_1.weight.device + from accelerate.utils.operations import send_to_device + if rotary_pos_emb_list is not None: + rotary_pos_emb_list = send_to_device(rotary_pos_emb_list, curr_device) + if attention_mask is not None: + attention_mask = send_to_device(attention_mask, curr_device) + if head_mask[i] is not None: + head_mask[i] = send_to_device(head_mask[i], curr_device) + if encoder_hidden_states is not None: + encoder_hidden_states = send_to_device(encoder_hidden_states, curr_device) + if encoder_attention_mask is not None: + encoder_attention_mask = send_to_device(encoder_attention_mask, + curr_device) + # bigdl-llm changes ends + + outputs = block( + hidden_states, + layer_past=layer_past, + rotary_pos_emb_list=rotary_pos_emb_list, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py index 371e5029..f9fbf0d6 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py @@ -54,7 +54,19 @@ from bigdl.llm.transformers.kv import DynamicFp8Cache from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb +from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa +from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import BaseModelOutputWithPast +try: + from transformers.cache_utils import Cache, DynamicCache +except ImportError: + Cache = Tuple[torch.Tensor] +import logging +from transformers import logging + + +logger = logging.get_logger(__name__) 
KV_CACHE_ALLOC_BLOCK_LENGTH = 256 @@ -82,7 +94,7 @@ def qwen2_model_forward( if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids): if not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) - return Qwen2Model.forward( + return qwen2_model_forward_internal( self=self, input_ids=input_ids, attention_mask=attention_mask, @@ -96,6 +108,173 @@ def qwen2_model_forward( ) + +def qwen2_model_forward_internal( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else \ + self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + invalidInputError(False, + "You cannot specify both decoder_input_ids and " + "decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. " + "Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, + dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + flash_attn_2 = self._attn_implementation == "flash_attention_2" + if attention_mask is not None and flash_attn_2 and use_cache: + + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + invalidInputError( + False, + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2." + " Make sure to call `tokenizer.padding_side = 'left'` before tokenizing " + "the input. 
" + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and + 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + # bigdl-llm changes + curr_device = decoder_layer.input_layernorm.weight.device + if attention_mask is not None: + attention_mask = attention_mask.to(curr_device) + if position_ids is not None: + position_ids = position_ids.to(curr_device) + # bigdl-llm changes end + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \ + next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, + all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def qwen2_attention_forward( self, hidden_states: torch.Tensor,