diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
index 4f84662e..8edc2fef 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
@@ -23,7 +23,7 @@ Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-w
 Then go to **Device Manager**, find **Neural Processors** -> **Intel(R) AI Boost**. Right click and select **Update Driver**. And then manually select the folder unzipped from the driver.
-## Example: Predict Tokens using `generate()` API
+## Example 1: Predict Tokens using `generate()` API
 In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
 ### 1. Install
 #### 1.1 Installation on Windows
@@ -81,3 +81,62 @@ Inference time: xxxx s
 --------------------------------------------------------------------------------
 done
 ```
+
+## Example 2: Predict Tokens using `generate()` API with multiple processes
+In the example [llama2.py](./llama2.py), we show experimental support for a Llama2 model to predict the next N tokens using `generate()` API, with IPEX-LLM INT4 optimization and fused decoderlayer optimization on Intel NPUs.
+### 1. Install
+#### 1.1 Installation on Windows
+We suggest using conda to manage environment:
+```bash
+conda create -n llm python=3.10
+conda activate llm
+
+# install ipex-llm with 'all' option
+pip install --pre --upgrade ipex-llm[all]
+pip install --pre --upgrade bigdl-core-npu
+
+pip install transformers==4.40
+```
+
+### 2. Runtime Configurations
+For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
+#### 2.1 Configurations for Windows
+
+> [!NOTE]
+> For optimal performance, we recommend running code in `conhost` rather than Windows Terminal:
+> - Press Win+R and input `conhost`, then press Enter to launch `conhost`.
+> - Run the following command to use conda in `conhost`. Replace `<your conda install location>` with your conda install location.
+> ```
+> call <your conda install location>\Scripts\activate
+> ```
+
+**The following environment variables are required**:
+
+```cmd
+set BIGDL_USE_NPU=1
+```
+
+### 3. Running examples
+
+```
+torchrun --standalone --nnodes=1 --nproc-per-node=2 llama2.py
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (i.e. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. The default value is `'meta-llama/Llama-2-7b-chat-hf'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). The default value is `'Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. The default value is `32`.
+
+#### Sample Output
+#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
+
+```log
+First token cost: xxxx s, rest tokens cost average: xxxx s
+Inference time: xxxx s
+-------------------- Prompt --------------------
+Once upon a time, there existed a little girl who liked to have adventures. 
She wanted to go to places and meet new people, and have fun +-------------------- Output -------------------- + Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun and exciting experiences. + +One day, she decided to go on a journey to find a magical land that was said to be full of wonders +``` diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py new file mode 100644 index 00000000..55749fcf --- /dev/null +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py @@ -0,0 +1,846 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +os.environ["OMP_NUM_THREADS"] = "4" +os.environ["IPEX_LLM_LAST_LM_HEAD"] = "1" +import torch +import time +import argparse + +from ipex_llm.transformers.npu_model import AutoModelForCausalLM +from transformers import AutoTokenizer +from intel_npu_acceleration_library.backend.factory import NNFactory +from typing import Optional, Sequence, List, Union, Any, Tuple +import numpy as np +import math +from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function +from intel_npu_acceleration_library.backend.runtime import adapt_output_tensor, _model_cache +from collections import deque +from transformers.cache_utils import Cache +from intel_npu_acceleration_library.backend.bindings import lib as backend_lib +import ctypes +from ipex_llm.utils.common import invalidInputError +from typing import Optional, List, Generator +import uuid +from functools import partial +import torch.nn.functional as F +import torch.nn.parallel +import torch.distributed as dist + +from transformers.utils import logging +logger = logging.get_logger(__name__) + + +@torch.no_grad() +def run_model( + x: Union[torch.Tensor, List[torch.Tensor]], + weights: List[torch.Tensor], + backend_cls: Any, + op_id: str, + replica: int = 1, +) -> torch.Tensor: + """Run a factory operation. Depending on the datatype of the weights it runs a float or quantized operation. + + Args: + x (Union[torch.Tensor, List[torch.Tensor]]): Activation tensor(s). Its dtype must be torch.float16 + weights (torch.Tensor): Weights tensor. Its dtype can be torch.float16 or torch.int8 + backend_cls (Any): Backend class to run + op_id (Optional[str], optional): Operation ID. Defaults to None. 
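+        replica (int, optional): Number of backend model instances to create and keep in the model cache for this operation (an addition documenting the existing `replica` argument). Defaults to 1.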
+ + Returns: + torch.Tensor: result + """ + global _model_cache + import time + t0 = time.perf_counter() + + # Use or not op_id depending on the class used + op_kwargs = {"op_id": op_id} if op_id else {} + + if not isinstance(x, (list, tuple)): + x = [x] + + # Reshape input + input_dtype = x[0].dtype + x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x] + op_args = [] + op_args_flatten = [] + for w in weights: + if isinstance(w, tuple): # from QuantizedLinear + op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy())) + op_args_flatten.append(op_args[-1][0]) + op_args_flatten.append(op_args[-1][1]) + else: + op_args.append(set_contiguous(w).to(torch.float16).numpy()) + op_args_flatten.append(op_args[-1]) + + shape_dtype_signature = "_".join( + ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten] + ) + key = f"{backend_cls.func.__name__}_{shape_dtype_signature}" + models = _model_cache.get(key, None) + + input_shapes = [elem.shape for elem in x_np] + if models is None: + _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)]) + elif len(models) < 1: + _model_cache[key].append(backend_cls(*input_shapes)) + else: + _model_cache[key].rotate(1) + + # Get the model + model = _model_cache[key][0] + + with record_function(f"npu_factory_mul_{key}"): + ret = model.run(x_np, *op_args, **op_kwargs) + + if isinstance(ret, list): + results = [adapt_output_tensor(r, r.shape, input_dtype) for r in ret] + else: + results = adapt_output_tensor(ret, ret.shape, input_dtype) + + return results + + +class LowBitLlamaDecoderlayer(NNFactory): + def __init__( + self, + hidden_shape: Sequence[int], + attenion_mask_shape=None, + position_id_shape=None, + past_key_shape=None, + past_value_shape=None, + input_layernorm_shape=None, + post_layernorm_shape=None, + *, + num_heads: int, + num_key_value_heads: int, + cached_cos, + cached_sin, + mode: str = "prefill", + dtype: np.dtype = np.int8, + max_seq_len: int = 128, + profile: bool = False, + device: str = "NPU", + rms_norm_eps, + intermediate_size, + **additional_args + ): + super().__init__(profile, device) + self.max_seq_len = max_seq_len + self.intermediate_size = intermediate_size + eps = self.constant(rms_norm_eps) + + self.batch_size, self.seq_len, self.hidden_size = hidden_shape + + if mode == "decode": + invalidInputError(self.seq_len == 1, "seq_len must be 1 for decode mode") + self.num_heads = num_heads + self.num_key_value_heads = num_key_value_heads + + self.head_dim = self.hidden_size // self.num_heads + + # define input, the order self.parameter matters + input = self.parameter((self.batch_size, self.seq_len, self.hidden_size)) + + # Self Attention + if mode == "decode": + attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1)) + else: + attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len)) + + position_ids = self.parameter((self.batch_size, self.seq_len)) + + input_layernorm_weight = self.parameter((1, self.hidden_size,)) + post_attention_layernorm_weight = self.parameter((1, self.hidden_size,)) + + if mode == "decode": + past_key = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) + past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) + + residual = input + + input_2d = self.reshape(input, (self.batch_size * self.seq_len, self.hidden_size)) + + # input_layernorm forward + input_2d = self.convert_to_fp32(input_2d) 
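+        # Added descriptive comment: the following ops build RMSNorm, i.e. x / sqrt(mean(x^2) + eps) scaled by the layernorm weight, computed in fp32 and cast back to fp16.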
+ variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True) + input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps))) + input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight) + input_2d = self.eltwise_mul(input_layernorm_weight, input_2d) + input_2d = self.convert_to_fp16(input_2d) + + query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype) + key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype) + value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype) + + cos = self.constant(cached_cos) + cos = self.unsqueeze(cos, axis=0) + + sin = self.constant(cached_sin) + sin = self.unsqueeze(sin, axis=0) + + query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) + key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]) + value_states = self.reshape(value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]) + + query_states = self.transpose(query_states, [0, 2, 1, 3]) + key_states = self.transpose(key_states, [0, 2, 1, 3]) + value_states = self.transpose(value_states, [0, 2, 1, 3]) + + query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + new_key_states = key_states + new_value_states = value_states + + invalidInputError(self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads") + + if mode == "decode": + key_states = self.concat(past_key, key_states, axis=-2) + value_states = self.concat(past_value, value_states, axis=-2) + + attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim)) + attn_weight = self.eltwise_add(attn_weight, attention_mask) + attn_weight = self.convert_to_fp32(attn_weight) + attn_weight = self.softmax(attn_weight, -1) + attn_weight = self.convert_to_fp16(attn_weight) + attn_output = self.matmul(attn_weight, value_states, False, False) + + attn_output = self.transpose(attn_output, [0, 2, 1, 3]) + attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size]) + + attn_output = self.linear(attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=dtype) + + hidden_states = self.eltwise_add(residual, attn_output) + + # Fully Connected + residual = hidden_states + hidden_states = self.convert_to_fp32(hidden_states) + variance = self.reduce_mean(self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), -1, keep_dims=True) + hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) + post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight) + hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states) + hidden_states = self.convert_to_fp16(hidden_states) + + # mlp + mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, + bias=False, wt_dtype=dtype) + mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, + bias=False, wt_dtype=dtype) # type: ignore[attr-defined] + mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] + + hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=dtype) + + hidden_states = 
self.eltwise_add(residual, hidden_states) + hidden_states = self.convert_to_fp16(hidden_states) + + # hacking to add key, value to outputs + new_key_states = self.convert_to_fp16(new_key_states) + new_value_states = self.convert_to_fp16(new_value_states) + + self.compile() + + def rotate_half(self, x): + x1 = self.slice(x, [0, 0, 0, 0], [self.batch_size, self.num_heads, self.seq_len, self.head_dim//2], ) + x2 = self.slice(x, [0, 0, 0, self.head_dim//2], [self.batch_size, self.num_heads, self.seq_len, self.head_dim]) + return self.concat(self.negative(x2), x1, axis=-1) + + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids): + position_ids = self.squeeze(position_ids) + cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) + sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) + cos = self.unsqueeze(cos, [1]) + sin = self.unsqueeze(sin, [1]) + + q_embed = self.eltwise_add(self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin)) + k_embed = self.eltwise_add(self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin)) + + return q_embed, k_embed + + +class LowBitLlamaMultiDecoderlayer(NNFactory): + def __init__( + self, + hidden_shape: Sequence[int], + *shapes, + num_heads: int, + num_key_value_heads: int, + num_layers: int, + cached_cos, + cached_sin, + input_layernorm_weights, + post_attn_layernorm_weights, + mode: str = "prefill", + dtype: np.dtype = np.int8, + max_seq_len: int = 128, + profile: bool = False, + device: str = "NPU", + rms_norm_eps, + intermediate_size, + **additional_args + ): + super().__init__(profile, device) + self.max_seq_len = max_seq_len + self.intermediate_size = intermediate_size + self.dtype = dtype + self.cached_cos = cached_cos + self.cached_sin = cached_sin + self.batch_size, self.seq_len, self.hidden_size = hidden_shape + self.mode = mode + self.rms_norm_eps = rms_norm_eps + + cos = self.constant(self.cached_cos) + self.cos = self.unsqueeze(cos, axis=0) + + sin = self.constant(self.cached_sin) + self.sin = self.unsqueeze(sin, axis=0) + + if mode == "decode": + invalidInputError(self.seq_len == 1, "seq_len must be 1 for decode mode") + self.num_heads = num_heads + self.num_key_value_heads = num_key_value_heads + + self.head_dim = self.hidden_size // self.num_heads + + # define input, the order self.parameter matters + input = self.parameter((self.batch_size, self.seq_len, self.hidden_size)) + + # Self Attention + if mode == "decode": + attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1)) + else: + attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len)) + + position_ids = self.parameter((self.batch_size, self.seq_len)) + past_keys = [] + past_values = [] + if mode == "decode": + for i in range(num_layers): + past_key = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) + past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) + past_keys.append(past_key) + past_values.append(past_value) + else: + past_key = None + past_value = None + + # input_layernorm_weight = self.parameter((1, self.hidden_size,)) + # post_attention_layernorm_weight = self.parameter((1, self.hidden_size,)) + hidden_states = input + + curr_key_values = [] + for i in range(num_layers): + hidden_states, new_key_states, new_value_states = self.build_decoder(hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + 
input_layernorm_weight=input_layernorm_weights[i], + post_attention_layernorm_weight=post_attn_layernorm_weights[i], + past_key=past_keys[i], + past_value=past_values[i],) + curr_key_values.append((new_key_states, new_value_states)) + + # define outputs + hidden_states = self.convert_to_fp16(hidden_states) + + for i in range(num_layers): + new_key_states = self.convert_to_fp16(curr_key_values[i][0]) + new_value_states = self.convert_to_fp16(curr_key_values[i][1]) + + self.compile() + + def build_decoder(self, hidden_states, attention_mask, position_ids, + input_layernorm_weight, post_attention_layernorm_weight, + past_key = None, + past_value = None): + + residual = hidden_states + + input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size)) + + # input layernorm + input_2d = self.convert_to_fp32(input_2d) + variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True) + eps = self.constant(self.rms_norm_eps) + input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps))) + input_layernorm_weight = self.constant(input_layernorm_weight) + input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight) + input_2d = self.eltwise_mul(input_layernorm_weight, input_2d) + input_2d = self.convert_to_fp16(input_2d) + + # attention + query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) + key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) + value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) + + query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) + key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]) + value_states = self.reshape(value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]) + + query_states = self.transpose(query_states, [0, 2, 1, 3]) + key_states = self.transpose(key_states, [0, 2, 1, 3]) + value_states = self.transpose(value_states, [0, 2, 1, 3]) + + query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, self.cos, self.sin, position_ids) + new_key_states = key_states + new_value_states = value_states + + # repeat_kv cannot be implemented because Broadcast op is needed + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + invalidInputError(self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads") + + if self.mode == "decode": + key_states = self.concat(past_key, key_states, axis=-2) + value_states = self.concat(past_value, value_states, axis=-2) + + attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim)) + attn_weight = self.eltwise_add(attn_weight, attention_mask) + attn_weight = self.convert_to_fp32(attn_weight) + attn_weight = self.softmax(attn_weight, -1) + attn_weight = self.convert_to_fp16(attn_weight) + attn_output = self.matmul(attn_weight, value_states, False, False) + + attn_output = self.transpose(attn_output, [0, 2, 1, 3]) + attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size]) + + attn_output = self.linear(attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype) + + 
hidden_states = self.eltwise_add(residual, attn_output) + + # Fully Connected + residual = hidden_states + hidden_states = self.convert_to_fp32(hidden_states) + variance = self.reduce_mean(self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), -1, keep_dims=True) + hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) + post_attention_layernorm_weight = self.constant(post_attention_layernorm_weight) + post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight) + hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states) + hidden_states = self.convert_to_fp16(hidden_states) + + # mlp + mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, + bias=False, wt_dtype=self.dtype) + mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, + bias=False, wt_dtype=self.dtype) # type: ignore[attr-defined] + mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] + + hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype) + + hidden_states = self.eltwise_add(residual, hidden_states) + hidden_states = self.convert_to_fp16(hidden_states) + + return hidden_states, new_key_states, new_value_states + + def rotate_half(self, x): + x1 = self.slice(x, [0, 0, 0, 0], [self.batch_size, self.num_heads, self.seq_len, self.head_dim//2], ) + x2 = self.slice(x, [0, 0, 0, self.head_dim//2], [self.batch_size, self.num_heads, self.seq_len, self.head_dim]) + return self.concat(self.negative(x2), x1, axis=-1) + + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids): + position_ids = self.squeeze(position_ids) + cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) + sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) + cos = self.unsqueeze(cos, [1]) + sin = self.unsqueeze(sin, [1]) + + q_embed = self.eltwise_add(self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin)) + k_embed = self.eltwise_add(self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin)) + + return q_embed, k_embed + + +class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): + + def __init__( + self, + parameters: List[Tuple[torch.Tensor]], + input_laynorm_weights: List[torch.Tensor], + post_attn_layernorm_weights: List[torch.Tensor], + layer_indexes : List[int], + cached_cos, + cached_sin, + num_heads: int, + head_dim: int, + num_key_value_heads: int, + rms_norm_eps, + intermediate_size, + max_seq_len: int = 128, + ): + super().__init__() + + op_parameters = [] + for w in parameters: + if isinstance(w, tuple): # from QuantizedLinear + op_parameters.append((w[0].numpy(), w[1].numpy())) + else: + op_parameters.append(w.to(torch.float16).numpy()) + self.op_parameters = op_parameters + self.op_id = str(uuid.uuid4()) + # self.layer_idx = layer_idx + self.max_seq_len = max_seq_len + # self.rotary_emb = rotary_emb + if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear + np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + invalidInputError(False, "Please use int4 optimization") + + self.layer_indexes = layer_indexes + + print("create dedcoder layer") + self.backend_cls_decode = LowBitLlamaMultiDecoderlayer([1, 1, num_heads*head_dim], + input_layernorm_weights=input_laynorm_weights, + post_attn_layernorm_weights=post_attn_layernorm_weights, + cached_cos=cached_cos, + cached_sin=cached_sin, + 
num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + num_layers=len(layer_indexes), + max_seq_len=max_seq_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode="decode", + dtype=np_dtype) + print("created dedcoder layer") + + self.backend_cls_decode.setWeights(3+len(layer_indexes)*2, self.op_id, *op_parameters) + print("weight setted") + backend_lib.run(self.backend_cls_decode._mm,) + print("first inference done") + self.kv_cache_c_parameter_handel = None + self.kv_cache_parameters = None + self.kv_cache_prefetched = False + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs,) -> torch.Tensor: + """Torch module forward method. + + Args: + x (torch.Tensor): Input tensor + + Returns: + torch.Tensor: result + """ + seq_len = hidden_states.shape[1] + backend_cls = self.backend_cls_decode + + pad_len = self.max_seq_len + 1 - attention_mask.size(-1) + + pad_mask = (0, pad_len) + padded_attention_mask = F.pad(attention_mask.to(torch.float16), pad_mask, + value=torch.finfo(torch.float16).min) + padded_attention_mask[:,:,:,-1] = 0.0 + inputs = (hidden_states.to(torch.float16), + padded_attention_mask, + position_ids,) + + if self.kv_cache_parameters is None: + self.kv_cache_parameters = [] + self.kv_cache_c_parameter_handel = None + self.kv_cache_prefetched = False + else: + # the case kv cache changed + cached_prt = self.kv_cache_parameters[0].storage().data_ptr() + current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr() + if cached_prt != current_ptr: + self.kv_cache_parameters = [] + self.kv_cache_c_parameter_handel = None + self.kv_cache_prefetched = False + + if len(self.kv_cache_parameters) == 0: + for idx in self.layer_indexes: + past_key = past_key_value.key_cache[idx] + past_value = past_key_value.value_cache[idx] + new_size = (past_key.size(0), + past_key.size(1), + self.max_seq_len, + past_key.size(3)) + past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0) + past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0) + + self.kv_cache_parameters.append(past_key) + self.kv_cache_parameters.append(past_value) + self.kv_cache_c_parameter_handel = self.backend_cls_decode.create_parameters([p.numpy() for p in self.kv_cache_parameters]) + + x_np = [elem.to(torch.float16).numpy() for elem in inputs] + + with record_function(f"npu_factory"): + if not self.kv_cache_prefetched: + self.backend_cls_decode.load_wt_fn(len(inputs), self.backend_cls_decode._mm, self.kv_cache_c_parameter_handel) + + for idx, elem in enumerate(x_np): + self.backend_cls_decode.set_input_tensor(elem, idx) + + backend_lib.run(self.backend_cls_decode._mm,) + ret = self.backend_cls_decode.out + results = [adapt_output_tensor(r, r.shape, torch.float16) for r in ret] + + hidden_states = results[0] + key_value_states = results[1:] + + cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len} + for i in range(len(self.layer_indexes)): + key_states, value_states = past_key_value.update(key_value_states[2*i], + key_value_states[2*i+1], + self.layer_indexes[i], cache_kwargs) + + self.backend_cls_decode.load_wt_fn(len(inputs), self.backend_cls_decode._mm, self.kv_cache_c_parameter_handel) + self.kv_cache_prefetched = True + outputs = 
(hidden_states,) + outputs += (past_key_value,) + + return outputs + + +class FusedLlamaLowBitDecoderlayer(torch.nn.Module): + def __init__( + self, + parameters: List[torch.Tensor], + cached_cos, + cached_sin, + layer_norm_0, + layer_norm_1, + num_heads: int, + num_key_value_heads: int, + layer_idx: int, + rms_norm_eps, + intermediate_size, + max_seq_len: int = 128, + ): + super().__init__() + self.op_parameters = parameters + self.op_id = str(uuid.uuid4()) + self.layer_idx = layer_idx + self.max_seq_len = max_seq_len + # self.rotary_emb = rotary_emb + if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear + np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + self.backend_cls_prefill = partial(LowBitLlamaDecoderlayer, + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + max_seq_len=max_seq_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode="prefill", + dtype=np_dtype) + self.backend_cls_decode = partial(LowBitLlamaDecoderlayer, + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + max_seq_len=max_seq_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode="decode", + dtype=np_dtype) + self.layer_norm_0 = layer_norm_0 + self.layer_norm_1 = layer_norm_1 + + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs,) -> torch.Tensor: + seq_len = hidden_states.shape[1] + # cos, sin = self.rotary_emb(hidden_states, position_ids) + if seq_len == 1: + backend_cls = self.backend_cls_decode + past_key = past_key_value.key_cache[self.layer_idx] + past_value = past_key_value.value_cache[self.layer_idx] + + new_size = (past_key.size(0), + past_key.size(1), + self.max_seq_len, + past_key.size(3)) + past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0) + past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0) + + pad_len = self.max_seq_len + 1 - attention_mask.size(-1) + + pad_mask = (0, pad_len) + padded_attention_mask = F.pad(attention_mask.to(torch.float16), pad_mask, + value=torch.finfo(torch.float16).min) + padded_attention_mask[:,:,:,-1] = 0.0 + inputs = (hidden_states.to(torch.float16), + padded_attention_mask, + position_ids,) + + inputs += (self.layer_norm_0, self.layer_norm_1) + + inputs += (past_key, past_value) + hidden_states, new_key, new_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=4) + cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len} + key_states, value_states = past_key_value.update(new_key, new_value, self.layer_idx, cache_kwargs) + else: + backend_cls = self.backend_cls_prefill + inputs = (hidden_states.to(torch.float16), attention_mask, position_ids) + inputs += (self.layer_norm_0, self.layer_norm_1) + hidden_states, past_key, past_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=1) + cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len} + key_states, value_states = past_key_value.update(past_key, past_value, self.layer_idx, cache_kwargs) + + outputs = (hidden_states,) + outputs += 
(past_key_value,) + return outputs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for npu model') + parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf", + help='The huggingface repo id for the Llama2 model to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + pipeline = True # default + max_seq_len = 1024 # default + if pipeline: + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29501' + + dist.init_process_group() + my_rank = dist.get_rank() + my_size = dist.get_world_size() + logger.info(f"rank: {my_rank}, size: {my_size}") + + model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, attn_implementation="eager", + load_in_low_bit="sym_int4", pipeline_parallel_stages=2) + + if my_rank == 0: + print(model) + dist.barrier() + + if my_rank == 1: + print(model) + else: + model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, attn_implementation="eager", + load_in_low_bit="sym_int4") + + if pipeline: + layer_start = model.layer_start + layer_end = model.layer_end + num_layers = model.num_layers + else: + layer_start = 0 + layer_end = 32 + num_layers = 32 + num_heads = model.model.layers[layer_start].self_attn.num_heads + num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads + head_dim = model.model.layers[layer_start].self_attn.head_dim + rms_norm_eps = model.config.rms_norm_eps + intermediate_size = model.config.intermediate_size + deocderlayers = [] + layer_weights = [] + input_layer_norm_weights = [] + post_attn_layernorm_weights = [] + layer_indexs = range(layer_start, layer_end) + for layer_idx in layer_indexs: + curr_layer = model.model.layers[layer_idx] + attn_layer = curr_layer.self_attn + mlp_layer = curr_layer.mlp + + weights = [ + # model.model.layers[i].input_layernorm.weight.to(torch.float16), + (attn_layer.q_proj.weight, attn_layer.q_proj.scale), + (attn_layer.k_proj.weight, attn_layer.k_proj.scale), + (attn_layer.v_proj.weight, attn_layer.v_proj.scale), + (attn_layer.o_proj.weight, attn_layer.o_proj.scale), + # model.model.layers[i].post_attention_layernorm.weight.to(torch.float16), + (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), + (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), + (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale)] + + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + + layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) + layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + + new_decoderlayer = FusedLlamaLowBitDecoderlayer(weights, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + cached_cos=cached_cos, + cached_sin=cached_sin, + # rotary_emb=model.model.layers[i].self_attn.rotary_emb, + layer_norm_0=layer_norm_0, + layer_norm_1=layer_norm_1, + layer_idx=layer_idx, + 
rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + max_seq_len=max_seq_len) + + layer_weights.extend(weights) + input_layer_norm_weights.append(layer_norm_0) + post_attn_layernorm_weights.append(layer_norm_1) + model.model.layers[layer_idx] = new_decoderlayer + + multi_decoder = FusedLlamaLowBitMultiDecoderlayer( + parameters=layer_weights, + input_laynorm_weights=input_layer_norm_weights, + post_attn_layernorm_weights=post_attn_layernorm_weights, + layer_indexes=layer_indexs, + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + head_dim=head_dim, + num_key_value_heads=num_key_value_heads, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + max_seq_len=max_seq_len, + ) + + model.model.multi_decoder = multi_decoder + print(model) + + with torch.inference_mode(): + input_ids = tokenizer.encode(args.prompt, return_tensors="pt") + print("finish to load") + print('input length:', len(input_ids[0])) + for i in range(3): + st = time.time() + output = model.generate(input_ids, num_beams=1, do_sample=False, max_new_tokens=args.n_predict) + end = time.time() + if my_rank == 0: + print(f"First token cost: {model.first_token_time} s, rest tokens cost average: {model.rest_cost_mean} s") + print(f'Inference time: {end-st} s') + output_str = tokenizer.decode(output[0], skip_special_tokens=False) + print('-'*20, 'Prompt', '-'*20) + print(args.prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 2a3ecffc..444d55ce 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig from ipex_llm.utils.common.log4Error import invalidInputError from ipex_llm.transformers.utils import logger -from ipex_llm.transformers.npu_models.convert import optimize_llm +from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post def patch_flash_attn_import(filename: str) -> List[str]: @@ -84,7 +84,7 @@ class _BaseAutoModelClass: warnings.warn("`device_map` will be ignored") kwargs['device_map'] = 'cpu' - if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float]: + if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float, torch.float16]: warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used") kwargs['torch_dtype'] = torch.float @@ -114,7 +114,7 @@ class _BaseAutoModelClass: ignore_argument(kwargs, "modules_to_not_convert") ignore_argument(kwargs, "quantization_config") ignore_argument(kwargs, "speculative") - ignore_argument(kwargs, "pipeline_parallel_stages") + pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1) _args = copy.deepcopy(args) _kwargs = copy.deepcopy(kwargs) @@ -131,12 +131,28 @@ class _BaseAutoModelClass: logger.info(f"Converting model, it may takes up to several minutes ...") + if pipeline_parallel_stages > 1: + invalidInputError(torch.distributed.get_world_size() == pipeline_parallel_stages, + "Please make sure world size is same as `pipeline_parallel_stages`") + kwargs['torch_dtype'] = torch.float16 + from .npu_models.pipeline_parallel import pipeline_parallel, pipeline_parallel_generate + model = pipeline_parallel(model, pipeline_parallel_stages, + kwargs["torch_dtype"], device="cpu") + + # add pipeline_parallel_generate to pretrained model dynamically + model.pipeline_parallel_generate = 
types.MethodType(pipeline_parallel_generate, + model) + from intel_npu_acceleration_library.compiler import create_npu_kernels with torch.no_grad(): optimize_llm(model) - cls.load_convert(qtype, model, 'cpu', *args, **kwargs) - create_npu_kernels(model) - + if pipeline_parallel_stages == 1: + cls.load_convert(qtype, model, 'cpu', *args, **kwargs) + create_npu_kernels(model) + else: + cls.load_convert(qtype, model.model, 'cpu', *args, **kwargs) + create_npu_kernels(model.model) + optimize_llm_post(model) model = model.eval() logger.info(f"Finish to convert model") diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index cd4b5fed..edc76687 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -63,7 +63,8 @@ def replace_with_QuantizedLinear(layer, qtype, device): (layer.in_features == 18944 and layer.out_features == 3584): qtype = "sym_int8_rtn" iqtype = ggml_tensor_qtype[qtype] - qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, device=device) + qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32), + iqtype, device=device) return QuantizedLinear(qweights, scale, layer.bias) @@ -79,18 +80,22 @@ def optimize_llm(model: torch.nn.Module): if model.config.model_type == "llama": from ipex_llm.transformers.npu_models.llama import merge_qkv from ipex_llm.transformers.npu_models.llama import merge_mlp - model.apply(merge_qkv) - model.apply(merge_mlp) - from ipex_llm.transformers.npu_models.llama import llama_model_forward + from ipex_llm.transformers.npu_models.llama import llama_fused_model_forward from ipex_llm.transformers.npu_models.llama import llama_attention_forward from ipex_llm.transformers.npu_models.llama import llama_mlp_forward from transformers.models.llama.modeling_llama import LlamaModel from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import LlamaMLP - convert_forward(model, LlamaModel, llama_model_forward) - convert_forward(model, LlamaAttention, llama_attention_forward) - convert_forward(model, LlamaMLP, llama_mlp_forward) + if hasattr(model, 'pipeline_parallel_stages'): + # experimental support for fused decoderlayer implementation + convert_forward(model, LlamaModel, llama_fused_model_forward) + else: + model.apply(merge_qkv) + model.apply(merge_mlp) + convert_forward(model, LlamaModel, llama_model_forward) + convert_forward(model, LlamaAttention, llama_attention_forward) + convert_forward(model, LlamaMLP, llama_mlp_forward) elif model.config.model_type == "mistral": from ipex_llm.transformers.npu_models.mistral import merge_qkv @@ -207,3 +212,28 @@ def optimize_llm(model: torch.nn.Module): from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward convert_forward(model, module.Phi3Attention, phi3_attention_forward) + + +def optimize_llm_post(model: torch.nn.Module): + # experimental support for fused decoderlayer implementation + if model.config.model_type == "llama": + model.model.embed_tokens.to(torch.float32) + model.model.norm.to(torch.float32) + model.lm_head.to(torch.float32) + + from ipex_llm.transformers.low_bit_linear import LowBitLinear, \ + ggml_tensor_qtype, FP4Params + + if isinstance(model.lm_head, torch.nn.Linear): + new_linear = LowBitLinear(model.lm_head.in_features, + model.lm_head.out_features, + ggml_tensor_qtype["sym_int4"], + False) + paramsLowBit = 
FP4Params(data=model.lm_head.weight.data, + requires_grad=False, + quantized=False, + _shape=None, + qtype=ggml_tensor_qtype["sym_int4"], + in_features=model.lm_head.in_features).to("cpu") + new_linear._parameters['weight'] = paramsLowBit + model.lm_head = new_linear diff --git a/python/llm/src/ipex_llm/transformers/npu_models/kv.py b/python/llm/src/ipex_llm/transformers/npu_models/kv.py new file mode 100644 index 00000000..ce5b29ee --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/kv.py @@ -0,0 +1,115 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import torch +from typing import Optional, Dict, Tuple, Any +from transformers.cache_utils import DynamicCache + + +def init_fused_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device): + key_cache_storage = torch.zeros(batch_size, num_heads, + max_length, head_dim, + dtype=dtype, device=device) + value_cache_storage = torch.zeros(batch_size, num_heads, + max_length, head_dim, + dtype=dtype, device=device) + + key_cache = key_cache_storage.as_strided((batch_size, num_heads, + current_length, head_dim), + key_cache_storage.stride(), + storage_offset=0) + value_cache = value_cache_storage.as_strided((batch_size, num_heads, + current_length, head_dim), + value_cache_storage.stride(), + storage_offset=0) + return key_cache, value_cache + + +def append_fused_kv_cache(cache_k, cache_v, key_states, value_states): + new_size = (cache_k.size(0), + cache_k.size(1), + cache_k.size(2) + key_states.size(2), + cache_k.size(3)) + new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0) + new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states + new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0) + new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states + return new_cache_k, new_cache_v + + +class DynamicFusedNormalCache(DynamicCache): + # Experimental support for fused decoderlayer implementation on NPU + # Currently only for llama2 + KV_ALLOC_BLOCK_LENGTH = 256 + + def __init__(self) -> None: + self.key_cache: Dict[int, torch.Tensor] = {} + self.value_cache: Dict[int, torch.Tensor] = {} + self._seen_tokens = 0 # Used in `generate` to keep how many tokens the cache has seen + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]]=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + batch_size, num_heads, seq_len, head_dim = key_states.shape + + max_seq_length = cache_kwargs.pop("max_seq_len", None) + transpose_value = cache_kwargs.pop("transpose_value", None) + + if layer_idx == 0 or layer_idx == 16: + if hasattr(self, "_seen_tokens"): + # 4.39 uses `_seen_tokens` + self._seen_tokens += seq_len + else: + # 4.37 uses `seen_tokens` + self.seen_tokens += seq_len + + # Update the cache + # if len(self.key_cache) <= layer_idx: + if layer_idx not in self.key_cache: + max_len = 
max_seq_length if max_seq_length is not None else key_states.size(2) + \ + self.KV_ALLOC_BLOCK_LENGTH + k_cache, v_cache = init_fused_kv_cache( + batch_size, num_heads, head_dim, + 0, max_len, + key_states.dtype, key_states.device, + ) + k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states) + + self.key_cache[layer_idx] = k_cache + self.value_cache[layer_idx] = v_cache + else: + k_cache = self.key_cache[layer_idx] + v_cache = self.value_cache[layer_idx] + + kv_seq_len = k_cache.size(2) + key_states.size(2) + k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states) + self.key_cache[layer_idx] = k_cache + self.value_cache[layer_idx] = v_cache + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. + A layer index can be optionally passed.""" + + for idx, layer in self.key_cache.items(): + return layer.shape[-2] diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama.py b/python/llm/src/ipex_llm/transformers/npu_models/llama.py index ab4c2025..a322d731 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/llama.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/llama.py @@ -182,6 +182,137 @@ def llama_model_forward( ) +def llama_fused_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + invalidInputError(False, + ("You cannot specify both input_ids and inputs_embeds at the same time, " + "and must specify either one")) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + + # ipex-llm changes start + from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache + if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache): + past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device) + # ipex-llm changes end + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, + cache_position, past_seen_tokens) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions 
else None + next_decoder_cache = None + + seq_len = hidden_states.size(1) + + if seq_len == 1: + # multi_decoder = self.layers[(self.layer_end + 1) % num_layers] + layer_outputs = self.multi_decoder(hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position,) + hidden_states = layer_outputs[0] + + next_decoder_cache = layer_outputs[1] + else: + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # ipex-llm changes start + next_cache = next_decoder_cache if use_cache else None + # ipex-llm changes end + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, + all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def llama_attention_forward( self, hidden_states: torch.Tensor, diff --git a/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py new file mode 100644 index 00000000..69614439 --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py @@ -0,0 +1,639 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Some parts of this file is adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py +# + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F +import torch.distributed as dist +import os +import time +import numpy as np +from typing import Callable, List, Optional, Union, Tuple +from types import SimpleNamespace +import transformers +from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ipex_llm.utils.common import invalidInputError +from ipex_llm.ggml.quantize import ggml_tensor_qtype +import logging +logger = logging.getLogger(__name__) + +# patch GenerationMixin.generate +from transformers import GenerationMixin +original_generate = GenerationMixin.generate + + +class DummyLayer(nn.Module): + def __init__(self, *args): + super().__init__() + # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/ + # python/llm/src/ipex_llm/transformers/models/llama.py#L2076 + self.weight = nn.Parameter(torch.empty(0,), requires_grad=False) + + def forward(self, x): + return x + + +class Dummy_MLPLayer(nn.Module): + def __init__(self, *args): + super().__init__() + # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/ + # python/llm/src/ipex_llm/transformers/models/llama.py#L119 + self.up_proj = DummyLayer() + self.down_proj = DummyLayer() + self.shared_expert = SimpleNamespace() + self.shared_expert.up_proj = DummyLayer() + + def forward(self, x): + return x + + +class Dummy_DecoderLayer(nn.Module): + def __init__(self, *args): + super().__init__() + # to avoid AttributeError + self.input_layernorm = DummyLayer() + self.mlp = Dummy_MLPLayer() + + def forward(self, hidden_states, *args, **kwargs): + past_key_value = kwargs.get('past_key_value', None) + use_cache = kwargs.get('use_cache', False) + outputs = (hidden_states,) + if use_cache: + outputs += (past_key_value,) + return outputs + + +class Dummy_GLMBlock(nn.Module): + def __init__(self, *args): + super().__init__() + # to avoid AttributeError + self.input_layernorm = DummyLayer() + self.mlp = Dummy_MLPLayer() + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + ): + if kv_cache is None: + return hidden_states, () + return hidden_states, kv_cache + + +def init_pipeline_parallel(): + import oneccl_bindings_for_pytorch + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1") + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") + dist.init_process_group('ccl') + + +def low_mem_convert(model): + from ipex_llm.transformers.convert import convert_forward + import importlib + if 'llama' in model.config.model_type: + convert_forward( + model, + transformers.models.llama.modeling_llama.LlamaForCausalLM, + llama_causallm_forward_4_37_lowmem) + elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"): + if model.config.num_layers == 40: + # for glm4-9b + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + convert_forward( + model, + module.ChatGLMForConditionalGeneration, + glm4_conditional_generation_forward_lowmem) + else: + # for chatglm3-6b + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + convert_forward( + model, + 
module.ChatGLMForConditionalGeneration, + chatglm3_conditional_generation_forward_lowmem) + return model + + +def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32, device=None): + global num_layers + if hasattr(model.config, 'num_hidden_layers'): + num_layers = model.config.num_hidden_layers + elif hasattr(model.config, 'num_layers'): + # for chatglm3-6b + num_layers = model.config.num_layers + + slice_size = (num_layers + pipeline_parallel_stages - 1) // pipeline_parallel_stages + + local_rank = dist.get_rank() + + global layer_start + global layer_end + layer_start = slice_size * local_rank + layer_end = layer_start + min(slice_size, num_layers - layer_start) + + if model.config.model_type == "qwen" and hasattr(model.config, "visual"): + # for Qwen-VL-Chat + for i in range(num_layers): + if i < layer_start or i >= layer_end: + model._modules['transformer'].h[i] = Dummy_DecoderLayer() + if local_rank != 0: + model._modules['transformer'].wte = DummyLayer() + model._modules['transformer'].drop = DummyLayer() + if local_rank != pipeline_parallel_stages - 1: + model._modules['transformer'].ln_f = DummyLayer() + model._modules['ln_f'] = DummyLayer() + model._modules['lm_head'] = DummyLayer() + elif model.config.model_type == "chatglm": + # for chatglm3-6b, glm-4-9b-chat + for i in range(num_layers): + if i < layer_start or i >= layer_end: + model._modules['transformer'].encoder.layers[i] = Dummy_GLMBlock() + else: + model._modules['transformer'].encoder.layers[i].self_attention.num_layers = \ + i - layer_start + + if local_rank != 0: + model._modules['transformer'].embedding = DummyLayer() + if local_rank != pipeline_parallel_stages - 1: + model._modules['transformer'].encoder.final_layernorm = DummyLayer() + model._modules['transformer'].output_layer = DummyLayer() + else: + for i in range(num_layers): + if i < layer_start or i >= layer_end: + model._modules['model'].layers[i] = Dummy_DecoderLayer() + else: + model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start + + if local_rank != 0: + model._modules['model'].embed_tokens = DummyLayer() + if local_rank != pipeline_parallel_stages - 1: + model._modules['model'].norm = DummyLayer() + model._modules['lm_head'] = DummyLayer() + + _enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM') + _enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1") + if _enable_lowmem: + model = low_mem_convert(model) + + model.pipeline_parallel_stages = pipeline_parallel_stages + model.layer_start = layer_start + model.layer_end = layer_end + model.num_layers = num_layers + if torch_dtype == torch.float16: + model = model.half() + if device is None: + model = model.to(f'xpu:{local_rank}') + else: + model.to(device) + return model + + +@torch.no_grad() +def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, +): + if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1: + # priority: `generation_config` argument > `model.generation_config` + if generation_config is None: + if ( + self.generation_config._from_model_config + and self.generation_config._original_object_hash 
== hash(self.generation_config) + and self.config._has_non_default_generation_parameters() + ): + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + self.generation_config = new_generation_config + generation_config = self.generation_config + + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning("Setting `pad_token_id` to `eos_token_id`: " + f"{eos_token_id} for open-end generation.") + generation_config.pad_token_id = eos_token_id + + if generation_config is not None and generation_config.max_new_tokens is not None: + max_new_tokens = generation_config.pop("max_new_tokens") + else: + max_new_tokens = kwargs.pop("max_new_tokens", None) + + return self.pipeline_parallel_generate(inputs=inputs, + max_new_tokens=max_new_tokens, + generation_config=generation_config, + **kwargs) + + return original_generate(self, + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + **kwargs) + +GenerationMixin.generate = generate + + +@torch.no_grad() +def pipeline_parallel_generate(self, + inputs: Optional[torch.Tensor] = None, + max_new_tokens: int = 32, + generation_config: Optional[GenerationConfig] = None, + **kwargs): + model_kwargs = generation_config.update(**kwargs) + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + bs = inputs_tensor.shape[0] + if model_kwargs.get("attention_mask", None) is None: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id) + if self.config.is_encoder_decoder: + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=bs, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + device=inputs_tensor.device, + ) + else: + input_ids = inputs_tensor if model_input_name == "input_ids" \ + else model_kwargs.pop("input_ids") + + local_rank = dist.get_rank() + pre_rank = (local_rank - 1) % self.pipeline_parallel_stages + next_rank = (local_rank + 1) % self.pipeline_parallel_stages + + global layer_start + global layer_end + global num_layers + + self.first_token_time = 0 + self.next_token_time = [] + + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) \ + if eos_token_id is not None else None + + _input_ids = None + _past_key_values = None + + bs = input_ids.shape[0] + output_ids = input_ids.clone() + os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0" + + step = 0 + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + this_peer_finished = False + while True: + if step >= max_new_tokens: + break + + if _input_ids is None: + _input_ids = input_ids + + model_inputs = 
self.prepare_inputs_for_generation(output_ids, **model_kwargs) + + tic = time.time() + if local_rank == 0: + outputs = self(**model_inputs) + else: + _inputs_shape = _input_ids.shape + (self.config.hidden_size,) + if step == 0 and self.config.model_type == "chatglm" \ + and hasattr(self.config, "vision_config"): + # for glm-4v, image features are mapped during 1st token + # 1597 are computed according to computation process of conv + _images_feature = 1597 + _input_ids.shape[0] * 2 + _input_ids.shape[1] + _inputs_shape = (_input_ids.shape[0], _images_feature, self.config.hidden_size,) + inputs_embeds = torch.empty(_inputs_shape, + device=input_ids.device, dtype=torch.float16) + dist.recv(inputs_embeds, src=pre_rank) + model_inputs.pop("input_ids") + model_inputs["inputs_embeds"] = inputs_embeds + outputs = self(**model_inputs) + + if local_rank == self.pipeline_parallel_stages - 1: + logits = outputs.logits + next_ids = torch.argmax(logits[:, -1:, :], dim=-1) + dist.broadcast(next_ids, src=local_rank) + else: + send_data = outputs[0].to(torch.float16) + dist.send(send_data, dst=next_rank) + next_ids = torch.empty((bs, 1), device=input_ids.device, dtype=torch.int64) + dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1) + + _input_ids = next_ids + output_ids = torch.cat([output_ids, next_ids], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # finished sentences should have their next token be a padding token + next_ids = next_ids.squeeze() + if eos_token_id is not None: + if pad_token_id is None: + invalidInputError(False, "If `eos_token_id` is defined, " + "make sure that `pad_token_id` is defined.") + next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + if self.config.model_type == "chatglm" and self.config.num_layers == 40 \ + and not hasattr(self.config, "vision_config"): + # for glm-4-9b-chat + if step == 0: + value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0]) + past_key_values_placeholder = tuple( + (value_placeholder, value_placeholder) for _ in range(layer_start) + ) + (outputs.past_key_values)[: layer_end - layer_start] + tuple( + (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) + ) + _past_key_values = past_key_values_placeholder + else: + _past_key_values = outputs.past_key_values + elif self.config.model_type in ["baichuan", "chatglm"] or \ + (self.config.model_type == "qwen" and hasattr(self.config, "visual")): + # for baichuan2, chatglm3, Qwen-VL-Chat, glm-4v-9b + if local_rank != 0: + value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0]) + past_key_values_placeholder = tuple( + (value_placeholder, value_placeholder) for _ in range(layer_start) + ) + (outputs.past_key_values)[layer_start:] + _past_key_values = past_key_values_placeholder + else: + _past_key_values = outputs.past_key_values + else: + _past_key_values = outputs.past_key_values + + toc = time.time() + if step == 0: + self.first_token_time = toc - tic + else: + self.next_token_time.append(toc - tic) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_ids.tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + if this_peer_finished: + break + + 
step += 1 + if self.device.type == 'xpu': + torch.xpu.synchronize() + self.rest_cost_mean = np.mean(self.next_token_time) + return output_ids + + +def llama_causallm_forward_4_37_lowmem( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # noqa + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + # ipex-llm change starts + + device = hidden_states.device + + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) # noqa + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] # noqa + logits = torch.cat(logits, dim=-1) + else: + if device.type == "xpu": + torch.xpu.empty_cache() + logits = self.lm_head(hidden_states) + if device.type == "xpu": + torch.xpu.empty_cache() + # logits = logits.float() + + # ipex-llm change ends + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def chatglm3_conditional_generation_forward_lowmem( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, +): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = 
self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + + device = hidden_states.device + # ipex-llm change starts + if device.type == "xpu": + torch.xpu.empty_cache() + lm_logits = self.transformer.output_layer(hidden_states) + if device.type == "xpu": + torch.xpu.empty_cache() + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + # lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + # ipex-llm change ends + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +def glm4_conditional_generation_forward_lowmem( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, +): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[:, -1:] + + device = hidden_states.device + # ipex-llm change starts + if device.type == "xpu": + torch.xpu.empty_cache() + lm_logits = self.transformer.output_layer(hidden_states) + if device.type == "xpu": + torch.xpu.empty_cache() + + loss = None + if labels is not None: + # lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + # ipex-llm change ends + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + 
logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + )
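Editorial note (not part of the patch above): `pipeline_parallel` splits the decoder stack across ranks with a ceiling division, keeps only the layers inside each rank's `[layer_start, layer_end)` window, and swaps everything else for dummy modules; the patched `GenerationMixin.generate` then routes to `pipeline_parallel_generate` whenever `pipeline_parallel_stages > 1`. Below is a minimal, dependency-free sketch of that slicing arithmetic, assuming a hypothetical 32-layer model (e.g. Llama-2-7B) split over 2 stages; the layer and stage counts are illustrative assumptions, not values taken from the patch.

```python
# Sketch of the layer-slicing arithmetic used in pipeline_parallel(); the layer
# and stage counts below are illustrative assumptions, not part of the patch.
num_layers = 32                # e.g. Llama-2-7B has 32 decoder layers
pipeline_parallel_stages = 2   # number of ranks / processes

slice_size = (num_layers + pipeline_parallel_stages - 1) // pipeline_parallel_stages

for local_rank in range(pipeline_parallel_stages):
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, num_layers - layer_start)
    print(f"rank {local_rank}: keeps layers [{layer_start}, {layer_end})")
# rank 0: keeps layers [0, 16)
# rank 1: keeps layers [16, 32)
```

During decoding, rank 0 consumes the token ids, each later stage receives the previous stage's hidden states as `float16` `inputs_embeds` via `dist.recv`, and the last stage computes the logits, greedily picks the next token, and broadcasts it to every rank.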