Add experimental support of fused decoder layer for llama2 (#11768)
This commit is contained in:
		
							parent
							
								
									c28b3389e6
								
							
						
					
					
						commit
						23d3acdc77
					
				
					 7 changed files with 1850 additions and 14 deletions
				
			
		| 
						 | 
					@ -23,7 +23,7 @@ Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-w
 | 
				
			||||||
Then go to **Device Manager**, find **Neural Processors** -> **Intel(R) AI Boost**.
 | 
					Then go to **Device Manager**, find **Neural Processors** -> **Intel(R) AI Boost**.
 | 
				
			||||||
Right click and select **Update Driver**. And then manually select the folder unzipped from the driver.
 | 
					Right click and select **Update Driver**. And then manually select the folder unzipped from the driver.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Example: Predict Tokens using `generate()` API
 | 
					## Example 1: Predict Tokens using `generate()` API
 | 
				
			||||||
In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
 | 
					In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
 | 
				
			||||||
### 1. Install
 | 
					### 1. Install
 | 
				
			||||||
#### 1.1 Installation on Windows
 | 
					#### 1.1 Installation on Windows
 | 
				
			||||||
| 
						 | 
					@ -81,3 +81,62 @@ Inference time: xxxx s
 | 
				
			||||||
--------------------------------------------------------------------------------
 | 
					--------------------------------------------------------------------------------
 | 
				
			||||||
done
 | 
					done
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Example 2: Predict Tokens using `generate()` API using multi processes
 | 
				
			||||||
 | 
					In the example [llama2.py](./llama2.py), we show an experimental support for a Llama2 model to predict the next N tokens using `generate()` API, with IPEX-LLM INT4 optimization and fused decoderlayer optimization on Intel NPUs.
 | 
				
			||||||
 | 
					### 1. Install
 | 
				
			||||||
 | 
					#### 1.1 Installation on Windows
 | 
				
			||||||
 | 
					We suggest using conda to manage environment:
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					conda create -n llm python=3.10
 | 
				
			||||||
 | 
					conda activate llm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# install ipex-llm with 'all' option
 | 
				
			||||||
 | 
					pip install --pre --upgrade ipex-llm[all]
 | 
				
			||||||
 | 
					pip install --pre --upgrade bigdl-core-npu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pip install transformers==4.40
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 2. Runtime Configurations
 | 
				
			||||||
 | 
					For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
 | 
				
			||||||
 | 
					#### 2.1 Configurations for Windows
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					> [!NOTE]
 | 
				
			||||||
 | 
					> For optimal performance, we recommend running code in `conhost` rather than Windows Terminal:
 | 
				
			||||||
 | 
					> - Press <kbd>Win</kbd>+<kbd>R</kbd> and input `conhost`, then press Enter to launch `conhost`.
 | 
				
			||||||
 | 
					> - Run following command to use conda in `conhost`. Replace `<your conda install location>` with your conda install location.
 | 
				
			||||||
 | 
					> ```
 | 
				
			||||||
 | 
					> call <your conda install location>\Scripts\activate
 | 
				
			||||||
 | 
					> ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Following envrionment variables are required**:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```cmd
 | 
				
			||||||
 | 
					set BIGDL_USE_NPU=1
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 3. Running examples
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					torchrun --standalone --nnodes=1 --nproc-per-node=2  llama2.py
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Arguments info:
 | 
				
			||||||
 | 
					- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (i.e. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`.
 | 
				
			||||||
 | 
					- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun'`.
 | 
				
			||||||
 | 
					- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### Sample Output
 | 
				
			||||||
 | 
					#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```log
 | 
				
			||||||
 | 
					First token cost: xxxx s, rest tokens cost average: xxxx s
 | 
				
			||||||
 | 
					Inference time: xxxx s
 | 
				
			||||||
 | 
					-------------------- Prompt --------------------
 | 
				
			||||||
 | 
					Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
 | 
				
			||||||
 | 
					-------------------- Output --------------------
 | 
				
			||||||
 | 
					<s> Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun and exciting experiences.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					One day, she decided to go on a journey to find a magical land that was said to be full of wonders
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										846
									
								
								python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										846
									
								
								python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,846 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					os.environ["OMP_NUM_THREADS"] = "4"
 | 
				
			||||||
 | 
					os.environ["IPEX_LLM_LAST_LM_HEAD"] = "1"
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
				
			||||||
 | 
					from transformers import AutoTokenizer
 | 
				
			||||||
 | 
					from intel_npu_acceleration_library.backend.factory import NNFactory
 | 
				
			||||||
 | 
					from typing import Optional, Sequence, List, Union, Any, Tuple
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function
 | 
				
			||||||
 | 
					from intel_npu_acceleration_library.backend.runtime import adapt_output_tensor, _model_cache
 | 
				
			||||||
 | 
					from collections import deque
 | 
				
			||||||
 | 
					from transformers.cache_utils import Cache
 | 
				
			||||||
 | 
					from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
 | 
				
			||||||
 | 
					import ctypes
 | 
				
			||||||
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					from typing import Optional, List, Generator
 | 
				
			||||||
 | 
					import uuid
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					import torch.nn.parallel
 | 
				
			||||||
 | 
					import torch.distributed as dist
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from transformers.utils import logging
 | 
				
			||||||
 | 
					logger = logging.get_logger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def run_model(
 | 
				
			||||||
 | 
					    x: Union[torch.Tensor, List[torch.Tensor]],
 | 
				
			||||||
 | 
					    weights: List[torch.Tensor],
 | 
				
			||||||
 | 
					    backend_cls: Any,
 | 
				
			||||||
 | 
					    op_id: str,
 | 
				
			||||||
 | 
					    replica: int = 1,
 | 
				
			||||||
 | 
					) -> torch.Tensor:
 | 
				
			||||||
 | 
					    """Run a factory operation. Depending on the datatype of the weights it runs a float or quantized operation.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        x (Union[torch.Tensor, List[torch.Tensor]]): Activation tensor(s). Its dtype must be torch.float16
 | 
				
			||||||
 | 
					        weights (torch.Tensor): Weights tensor.  Its dtype can be torch.float16 or torch.int8
 | 
				
			||||||
 | 
					        backend_cls (Any): Backend class to run
 | 
				
			||||||
 | 
					        op_id (Optional[str], optional): Operation ID. Defaults to None.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        torch.Tensor: result
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    global _model_cache
 | 
				
			||||||
 | 
					    import time
 | 
				
			||||||
 | 
					    t0 = time.perf_counter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Use or not op_id depending on the class used
 | 
				
			||||||
 | 
					    op_kwargs = {"op_id": op_id} if op_id else {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not isinstance(x, (list, tuple)):
 | 
				
			||||||
 | 
					        x = [x]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Reshape input
 | 
				
			||||||
 | 
					    input_dtype = x[0].dtype
 | 
				
			||||||
 | 
					    x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x]
 | 
				
			||||||
 | 
					    op_args = []
 | 
				
			||||||
 | 
					    op_args_flatten = []
 | 
				
			||||||
 | 
					    for w in weights:
 | 
				
			||||||
 | 
					        if isinstance(w, tuple):  # from QuantizedLinear
 | 
				
			||||||
 | 
					            op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
 | 
				
			||||||
 | 
					            op_args_flatten.append(op_args[-1][0])
 | 
				
			||||||
 | 
					            op_args_flatten.append(op_args[-1][1])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            op_args.append(set_contiguous(w).to(torch.float16).numpy())
 | 
				
			||||||
 | 
					            op_args_flatten.append(op_args[-1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    shape_dtype_signature = "_".join(
 | 
				
			||||||
 | 
					        ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten]
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    key = f"{backend_cls.func.__name__}_{shape_dtype_signature}"
 | 
				
			||||||
 | 
					    models = _model_cache.get(key, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    input_shapes = [elem.shape for elem in x_np]
 | 
				
			||||||
 | 
					    if models is None:
 | 
				
			||||||
 | 
					        _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)])
 | 
				
			||||||
 | 
					    elif len(models) < 1:
 | 
				
			||||||
 | 
					        _model_cache[key].append(backend_cls(*input_shapes))
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        _model_cache[key].rotate(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Get the model
 | 
				
			||||||
 | 
					    model = _model_cache[key][0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with record_function(f"npu_factory_mul_{key}"):
 | 
				
			||||||
 | 
					        ret = model.run(x_np, *op_args, **op_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(ret, list):
 | 
				
			||||||
 | 
					        results = [adapt_output_tensor(r, r.shape, input_dtype) for r in ret]
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        results = adapt_output_tensor(ret, ret.shape, input_dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return results
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LowBitLlamaDecoderlayer(NNFactory):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        hidden_shape: Sequence[int],
 | 
				
			||||||
 | 
					        attenion_mask_shape=None,
 | 
				
			||||||
 | 
					        position_id_shape=None,
 | 
				
			||||||
 | 
					        past_key_shape=None,
 | 
				
			||||||
 | 
					        past_value_shape=None,
 | 
				
			||||||
 | 
					        input_layernorm_shape=None,
 | 
				
			||||||
 | 
					        post_layernorm_shape=None,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        num_heads: int,
 | 
				
			||||||
 | 
					        num_key_value_heads: int,
 | 
				
			||||||
 | 
					        cached_cos,
 | 
				
			||||||
 | 
					        cached_sin,
 | 
				
			||||||
 | 
					        mode: str = "prefill",
 | 
				
			||||||
 | 
					        dtype: np.dtype = np.int8,
 | 
				
			||||||
 | 
					        max_seq_len: int = 128,
 | 
				
			||||||
 | 
					        profile: bool = False,
 | 
				
			||||||
 | 
					        device: str = "NPU",
 | 
				
			||||||
 | 
					        rms_norm_eps,
 | 
				
			||||||
 | 
					        intermediate_size,
 | 
				
			||||||
 | 
					        **additional_args
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__(profile, device)
 | 
				
			||||||
 | 
					        self.max_seq_len = max_seq_len
 | 
				
			||||||
 | 
					        self.intermediate_size = intermediate_size
 | 
				
			||||||
 | 
					        eps = self.constant(rms_norm_eps)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.batch_size, self.seq_len, self.hidden_size = hidden_shape
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if mode == "decode":
 | 
				
			||||||
 | 
					            invalidInputError(self.seq_len == 1, "seq_len must be 1 for decode mode")
 | 
				
			||||||
 | 
					        self.num_heads = num_heads
 | 
				
			||||||
 | 
					        self.num_key_value_heads = num_key_value_heads
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.head_dim = self.hidden_size // self.num_heads
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        # define input, the order self.parameter matters
 | 
				
			||||||
 | 
					        input = self.parameter((self.batch_size, self.seq_len, self.hidden_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Self Attention
 | 
				
			||||||
 | 
					        if mode == "decode":
 | 
				
			||||||
 | 
					            attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len))
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        position_ids = self.parameter((self.batch_size, self.seq_len))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        input_layernorm_weight = self.parameter((1, self.hidden_size,))
 | 
				
			||||||
 | 
					        post_attention_layernorm_weight = self.parameter((1, self.hidden_size,))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if mode == "decode":
 | 
				
			||||||
 | 
					            past_key = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim))
 | 
				
			||||||
 | 
					            past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        residual = input
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        input_2d = self.reshape(input, (self.batch_size * self.seq_len, self.hidden_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # input_layernorm forward
 | 
				
			||||||
 | 
					        input_2d = self.convert_to_fp32(input_2d)
 | 
				
			||||||
 | 
					        variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True)
 | 
				
			||||||
 | 
					        input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps)))
 | 
				
			||||||
 | 
					        input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight)
 | 
				
			||||||
 | 
					        input_2d = self.eltwise_mul(input_layernorm_weight, input_2d)
 | 
				
			||||||
 | 
					        input_2d = self.convert_to_fp16(input_2d)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype)
 | 
				
			||||||
 | 
					        key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype)
 | 
				
			||||||
 | 
					        value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cos = self.constant(cached_cos)
 | 
				
			||||||
 | 
					        cos = self.unsqueeze(cos, axis=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sin = self.constant(cached_sin)
 | 
				
			||||||
 | 
					        sin = self.unsqueeze(sin, axis=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
 | 
				
			||||||
 | 
					        key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim])
 | 
				
			||||||
 | 
					        value_states = self.reshape(value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim])
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        query_states = self.transpose(query_states, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					        key_states = self.transpose(key_states, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					        value_states = self.transpose(value_states, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 | 
				
			||||||
 | 
					        new_key_states = key_states
 | 
				
			||||||
 | 
					        new_value_states = value_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        invalidInputError(self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads")
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if mode == "decode":
 | 
				
			||||||
 | 
					            key_states = self.concat(past_key, key_states, axis=-2)
 | 
				
			||||||
 | 
					            value_states = self.concat(past_value, value_states, axis=-2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
 | 
				
			||||||
 | 
					        attn_weight = self.eltwise_add(attn_weight, attention_mask)
 | 
				
			||||||
 | 
					        attn_weight = self.convert_to_fp32(attn_weight)
 | 
				
			||||||
 | 
					        attn_weight = self.softmax(attn_weight, -1)
 | 
				
			||||||
 | 
					        attn_weight = self.convert_to_fp16(attn_weight)
 | 
				
			||||||
 | 
					        attn_output = self.matmul(attn_weight, value_states, False, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attn_output = self.transpose(attn_output, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					        attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attn_output = self.linear(attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hidden_states = self.eltwise_add(residual, attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Fully Connected
 | 
				
			||||||
 | 
					        residual = hidden_states
 | 
				
			||||||
 | 
					        hidden_states = self.convert_to_fp32(hidden_states)
 | 
				
			||||||
 | 
					        variance = self.reduce_mean(self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), -1, keep_dims=True)
 | 
				
			||||||
 | 
					        hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
 | 
				
			||||||
 | 
					        post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight)
 | 
				
			||||||
 | 
					        hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states)
 | 
				
			||||||
 | 
					        hidden_states = self.convert_to_fp16(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # mlp
 | 
				
			||||||
 | 
					        mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
 | 
				
			||||||
 | 
					                          bias=False, wt_dtype=dtype)
 | 
				
			||||||
 | 
					        mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
 | 
				
			||||||
 | 
					                          bias=False, wt_dtype=dtype)  # type: ignore[attr-defined]
 | 
				
			||||||
 | 
					        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=dtype)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        hidden_states = self.eltwise_add(residual, hidden_states)
 | 
				
			||||||
 | 
					        hidden_states = self.convert_to_fp16(hidden_states)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        # hacking to add key, value to outputs
 | 
				
			||||||
 | 
					        new_key_states = self.convert_to_fp16(new_key_states)
 | 
				
			||||||
 | 
					        new_value_states = self.convert_to_fp16(new_value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.compile()
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def rotate_half(self, x):
 | 
				
			||||||
 | 
					        x1 = self.slice(x, [0, 0, 0, 0], [self.batch_size, self.num_heads, self.seq_len, self.head_dim//2], )
 | 
				
			||||||
 | 
					        x2 = self.slice(x, [0, 0, 0, self.head_dim//2], [self.batch_size, self.num_heads, self.seq_len, self.head_dim])
 | 
				
			||||||
 | 
					        return self.concat(self.negative(x2), x1, axis=-1)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
 | 
				
			||||||
 | 
					        position_ids = self.squeeze(position_ids)
 | 
				
			||||||
 | 
					        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
				
			||||||
 | 
					        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
				
			||||||
 | 
					        cos = self.unsqueeze(cos, [1])
 | 
				
			||||||
 | 
					        sin = self.unsqueeze(sin, [1])
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        q_embed = self.eltwise_add(self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin))
 | 
				
			||||||
 | 
					        k_embed = self.eltwise_add(self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin))
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return q_embed, k_embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LowBitLlamaMultiDecoderlayer(NNFactory):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        hidden_shape: Sequence[int],
 | 
				
			||||||
 | 
					        *shapes,
 | 
				
			||||||
 | 
					        num_heads: int,
 | 
				
			||||||
 | 
					        num_key_value_heads: int,
 | 
				
			||||||
 | 
					        num_layers: int,
 | 
				
			||||||
 | 
					        cached_cos,
 | 
				
			||||||
 | 
					        cached_sin,
 | 
				
			||||||
 | 
					        input_layernorm_weights,
 | 
				
			||||||
 | 
					        post_attn_layernorm_weights,
 | 
				
			||||||
 | 
					        mode: str = "prefill",
 | 
				
			||||||
 | 
					        dtype: np.dtype = np.int8,
 | 
				
			||||||
 | 
					        max_seq_len: int = 128,
 | 
				
			||||||
 | 
					        profile: bool = False,
 | 
				
			||||||
 | 
					        device: str = "NPU",
 | 
				
			||||||
 | 
					        rms_norm_eps,
 | 
				
			||||||
 | 
					        intermediate_size,
 | 
				
			||||||
 | 
					        **additional_args
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__(profile, device)
 | 
				
			||||||
 | 
					        self.max_seq_len = max_seq_len
 | 
				
			||||||
 | 
					        self.intermediate_size = intermediate_size
 | 
				
			||||||
 | 
					        self.dtype = dtype
 | 
				
			||||||
 | 
					        self.cached_cos = cached_cos
 | 
				
			||||||
 | 
					        self.cached_sin = cached_sin
 | 
				
			||||||
 | 
					        self.batch_size, self.seq_len, self.hidden_size = hidden_shape
 | 
				
			||||||
 | 
					        self.mode = mode
 | 
				
			||||||
 | 
					        self.rms_norm_eps = rms_norm_eps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cos = self.constant(self.cached_cos)
 | 
				
			||||||
 | 
					        self.cos = self.unsqueeze(cos, axis=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sin = self.constant(self.cached_sin)
 | 
				
			||||||
 | 
					        self.sin = self.unsqueeze(sin, axis=0)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if mode == "decode":
 | 
				
			||||||
 | 
					            invalidInputError(self.seq_len == 1, "seq_len must be 1 for decode mode")
 | 
				
			||||||
 | 
					        self.num_heads = num_heads
 | 
				
			||||||
 | 
					        self.num_key_value_heads = num_key_value_heads
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.head_dim = self.hidden_size // self.num_heads
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        # define input, the order self.parameter matters
 | 
				
			||||||
 | 
					        input = self.parameter((self.batch_size, self.seq_len, self.hidden_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Self Attention
 | 
				
			||||||
 | 
					        if mode == "decode":
 | 
				
			||||||
 | 
					            attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        position_ids = self.parameter((self.batch_size, self.seq_len))
 | 
				
			||||||
 | 
					        past_keys = []
 | 
				
			||||||
 | 
					        past_values = []
 | 
				
			||||||
 | 
					        if mode == "decode":
 | 
				
			||||||
 | 
					            for i in range(num_layers):
 | 
				
			||||||
 | 
					                past_key = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim))
 | 
				
			||||||
 | 
					                past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim))
 | 
				
			||||||
 | 
					                past_keys.append(past_key)
 | 
				
			||||||
 | 
					                past_values.append(past_value)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            past_key = None
 | 
				
			||||||
 | 
					            past_value = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # input_layernorm_weight = self.parameter((1, self.hidden_size,))
 | 
				
			||||||
 | 
					        # post_attention_layernorm_weight = self.parameter((1, self.hidden_size,))
 | 
				
			||||||
 | 
					        hidden_states = input
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        curr_key_values = []
 | 
				
			||||||
 | 
					        for i in range(num_layers):
 | 
				
			||||||
 | 
					            hidden_states, new_key_states, new_value_states = self.build_decoder(hidden_states=hidden_states,
 | 
				
			||||||
 | 
					                                                                                 attention_mask=attention_mask,
 | 
				
			||||||
 | 
					                                                                                 position_ids=position_ids,
 | 
				
			||||||
 | 
					                                                                                 input_layernorm_weight=input_layernorm_weights[i],
 | 
				
			||||||
 | 
					                                                                                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
 | 
				
			||||||
 | 
					                                                                                 past_key=past_keys[i],
 | 
				
			||||||
 | 
					                                                                                 past_value=past_values[i],)
 | 
				
			||||||
 | 
					            curr_key_values.append((new_key_states, new_value_states))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # define outputs
 | 
				
			||||||
 | 
					        hidden_states = self.convert_to_fp16(hidden_states)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        for i in range(num_layers):
 | 
				
			||||||
 | 
					            new_key_states = self.convert_to_fp16(curr_key_values[i][0])
 | 
				
			||||||
 | 
					            new_value_states = self.convert_to_fp16(curr_key_values[i][1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.compile()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def build_decoder(self, hidden_states, attention_mask, position_ids,
 | 
				
			||||||
 | 
					                      input_layernorm_weight, post_attention_layernorm_weight,
 | 
				
			||||||
 | 
					                      past_key = None,
 | 
				
			||||||
 | 
					                      past_value = None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        residual = hidden_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # input layernorm
 | 
				
			||||||
 | 
					        input_2d = self.convert_to_fp32(input_2d)
 | 
				
			||||||
 | 
					        variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True)
 | 
				
			||||||
 | 
					        eps = self.constant(self.rms_norm_eps)
 | 
				
			||||||
 | 
					        input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps)))
 | 
				
			||||||
 | 
					        input_layernorm_weight = self.constant(input_layernorm_weight)
 | 
				
			||||||
 | 
					        input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight)
 | 
				
			||||||
 | 
					        input_2d = self.eltwise_mul(input_layernorm_weight, input_2d)
 | 
				
			||||||
 | 
					        input_2d = self.convert_to_fp16(input_2d)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # attention
 | 
				
			||||||
 | 
					        query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype)
 | 
				
			||||||
 | 
					        key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype)
 | 
				
			||||||
 | 
					        value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
 | 
				
			||||||
 | 
					        key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim])
 | 
				
			||||||
 | 
					        value_states = self.reshape(value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim])
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        query_states = self.transpose(query_states, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					        key_states = self.transpose(key_states, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					        value_states = self.transpose(value_states, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, self.cos, self.sin, position_ids)
 | 
				
			||||||
 | 
					        new_key_states = key_states
 | 
				
			||||||
 | 
					        new_value_states = value_states
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        # repeat_kv cannot be implemented because Broadcast op is needed
 | 
				
			||||||
 | 
					        # key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
				
			||||||
 | 
					        # value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
				
			||||||
 | 
					        invalidInputError(self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads")
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if self.mode == "decode":
 | 
				
			||||||
 | 
					            key_states = self.concat(past_key, key_states, axis=-2)
 | 
				
			||||||
 | 
					            value_states = self.concat(past_value, value_states, axis=-2)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
 | 
				
			||||||
 | 
					        attn_weight = self.eltwise_add(attn_weight, attention_mask)
 | 
				
			||||||
 | 
					        attn_weight = self.convert_to_fp32(attn_weight)
 | 
				
			||||||
 | 
					        attn_weight = self.softmax(attn_weight, -1)
 | 
				
			||||||
 | 
					        attn_weight = self.convert_to_fp16(attn_weight)
 | 
				
			||||||
 | 
					        attn_output = self.matmul(attn_weight, value_states, False, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attn_output = self.transpose(attn_output, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					        attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size])
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        attn_output = self.linear(attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hidden_states = self.eltwise_add(residual, attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Fully Connected
 | 
				
			||||||
 | 
					        residual = hidden_states
 | 
				
			||||||
 | 
					        hidden_states = self.convert_to_fp32(hidden_states)
 | 
				
			||||||
 | 
					        variance = self.reduce_mean(self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), -1, keep_dims=True)
 | 
				
			||||||
 | 
					        hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
 | 
				
			||||||
 | 
					        post_attention_layernorm_weight = self.constant(post_attention_layernorm_weight)
 | 
				
			||||||
 | 
					        post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight)
 | 
				
			||||||
 | 
					        hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states)
 | 
				
			||||||
 | 
					        hidden_states = self.convert_to_fp16(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # mlp
 | 
				
			||||||
 | 
					        mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
 | 
				
			||||||
 | 
					                          bias=False, wt_dtype=self.dtype)
 | 
				
			||||||
 | 
					        mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
 | 
				
			||||||
 | 
					                          bias=False, wt_dtype=self.dtype)  # type: ignore[attr-defined]
 | 
				
			||||||
 | 
					        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        hidden_states = self.eltwise_add(residual, hidden_states)
 | 
				
			||||||
 | 
					        hidden_states = self.convert_to_fp16(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return hidden_states, new_key_states, new_value_states
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def rotate_half(self, x):
 | 
				
			||||||
 | 
					        x1 = self.slice(x, [0, 0, 0, 0], [self.batch_size, self.num_heads, self.seq_len, self.head_dim//2], )
 | 
				
			||||||
 | 
					        x2 = self.slice(x, [0, 0, 0, self.head_dim//2], [self.batch_size, self.num_heads, self.seq_len, self.head_dim])
 | 
				
			||||||
 | 
					        return self.concat(self.negative(x2), x1, axis=-1)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
 | 
				
			||||||
 | 
					        position_ids = self.squeeze(position_ids)
 | 
				
			||||||
 | 
					        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
				
			||||||
 | 
					        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
				
			||||||
 | 
					        cos = self.unsqueeze(cos, [1])
 | 
				
			||||||
 | 
					        sin = self.unsqueeze(sin, [1])
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        q_embed = self.eltwise_add(self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin))
 | 
				
			||||||
 | 
					        k_embed = self.eltwise_add(self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin))
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return q_embed, k_embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        parameters: List[Tuple[torch.Tensor]],
 | 
				
			||||||
 | 
					        input_laynorm_weights: List[torch.Tensor],
 | 
				
			||||||
 | 
					        post_attn_layernorm_weights: List[torch.Tensor],
 | 
				
			||||||
 | 
					        layer_indexes : List[int],
 | 
				
			||||||
 | 
					        cached_cos,
 | 
				
			||||||
 | 
					        cached_sin,
 | 
				
			||||||
 | 
					        num_heads: int,
 | 
				
			||||||
 | 
					        head_dim: int,
 | 
				
			||||||
 | 
					        num_key_value_heads: int,
 | 
				
			||||||
 | 
					        rms_norm_eps,
 | 
				
			||||||
 | 
					        intermediate_size,
 | 
				
			||||||
 | 
					        max_seq_len: int = 128,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        op_parameters = []
 | 
				
			||||||
 | 
					        for w in parameters:
 | 
				
			||||||
 | 
					            if isinstance(w, tuple):  # from QuantizedLinear
 | 
				
			||||||
 | 
					                op_parameters.append((w[0].numpy(), w[1].numpy()))
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                op_parameters.append(w.to(torch.float16).numpy())
 | 
				
			||||||
 | 
					        self.op_parameters = op_parameters
 | 
				
			||||||
 | 
					        self.op_id = str(uuid.uuid4())
 | 
				
			||||||
 | 
					        # self.layer_idx = layer_idx
 | 
				
			||||||
 | 
					        self.max_seq_len = max_seq_len
 | 
				
			||||||
 | 
					        # self.rotary_emb = rotary_emb
 | 
				
			||||||
 | 
					        if isinstance(parameters[0], tuple):  # weight, scale from QuantizedLinear
 | 
				
			||||||
 | 
					            np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
 | 
				
			||||||
 | 
					        else:  # FP16 Linear
 | 
				
			||||||
 | 
					            invalidInputError(False, "Please use int4 optimization")
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.layer_indexes = layer_indexes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        print("create dedcoder layer")
 | 
				
			||||||
 | 
					        self.backend_cls_decode = LowBitLlamaMultiDecoderlayer([1, 1, num_heads*head_dim],
 | 
				
			||||||
 | 
					                                          input_layernorm_weights=input_laynorm_weights,
 | 
				
			||||||
 | 
					                                          post_attn_layernorm_weights=post_attn_layernorm_weights,
 | 
				
			||||||
 | 
					                                          cached_cos=cached_cos,
 | 
				
			||||||
 | 
					                                          cached_sin=cached_sin,
 | 
				
			||||||
 | 
					                                          num_heads=num_heads,
 | 
				
			||||||
 | 
					                                          num_key_value_heads=num_key_value_heads,
 | 
				
			||||||
 | 
					                                          num_layers=len(layer_indexes),
 | 
				
			||||||
 | 
					                                          max_seq_len=max_seq_len,
 | 
				
			||||||
 | 
					                                          rms_norm_eps=rms_norm_eps,
 | 
				
			||||||
 | 
					                                          intermediate_size=intermediate_size,
 | 
				
			||||||
 | 
					                                          mode="decode",
 | 
				
			||||||
 | 
					                                          dtype=np_dtype)
 | 
				
			||||||
 | 
					        print("created dedcoder layer")
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.backend_cls_decode.setWeights(3+len(layer_indexes)*2, self.op_id, *op_parameters)
 | 
				
			||||||
 | 
					        print("weight setted")
 | 
				
			||||||
 | 
					        backend_lib.run(self.backend_cls_decode._mm,)
 | 
				
			||||||
 | 
					        print("first inference done")
 | 
				
			||||||
 | 
					        self.kv_cache_c_parameter_handel = None
 | 
				
			||||||
 | 
					        self.kv_cache_parameters = None
 | 
				
			||||||
 | 
					        self.kv_cache_prefetched = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					                attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					                position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					                past_key_value: Optional[Cache] = None,
 | 
				
			||||||
 | 
					                output_attentions: bool = False,
 | 
				
			||||||
 | 
					                use_cache: bool = False,
 | 
				
			||||||
 | 
					                cache_position: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					                **kwargs,) -> torch.Tensor:
 | 
				
			||||||
 | 
					        """Torch module forward method.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            x (torch.Tensor): Input tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            torch.Tensor: result
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        seq_len = hidden_states.shape[1]
 | 
				
			||||||
 | 
					        backend_cls = self.backend_cls_decode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pad_len = self.max_seq_len + 1 - attention_mask.size(-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pad_mask = (0, pad_len)
 | 
				
			||||||
 | 
					        padded_attention_mask = F.pad(attention_mask.to(torch.float16), pad_mask,
 | 
				
			||||||
 | 
					                                value=torch.finfo(torch.float16).min)
 | 
				
			||||||
 | 
					        padded_attention_mask[:,:,:,-1] = 0.0
 | 
				
			||||||
 | 
					        inputs = (hidden_states.to(torch.float16),
 | 
				
			||||||
 | 
					                  padded_attention_mask,
 | 
				
			||||||
 | 
					                  position_ids,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.kv_cache_parameters is None:
 | 
				
			||||||
 | 
					            self.kv_cache_parameters = []
 | 
				
			||||||
 | 
					            self.kv_cache_c_parameter_handel = None
 | 
				
			||||||
 | 
					            self.kv_cache_prefetched = False
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # the case kv cache changed
 | 
				
			||||||
 | 
					            cached_prt = self.kv_cache_parameters[0].storage().data_ptr()
 | 
				
			||||||
 | 
					            current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr()
 | 
				
			||||||
 | 
					            if cached_prt != current_ptr:
 | 
				
			||||||
 | 
					                self.kv_cache_parameters = []
 | 
				
			||||||
 | 
					                self.kv_cache_c_parameter_handel = None
 | 
				
			||||||
 | 
					                self.kv_cache_prefetched = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(self.kv_cache_parameters) == 0:
 | 
				
			||||||
 | 
					            for idx in self.layer_indexes:
 | 
				
			||||||
 | 
					                past_key = past_key_value.key_cache[idx]
 | 
				
			||||||
 | 
					                past_value = past_key_value.value_cache[idx]
 | 
				
			||||||
 | 
					                new_size = (past_key.size(0),
 | 
				
			||||||
 | 
					                            past_key.size(1),
 | 
				
			||||||
 | 
					                            self.max_seq_len,
 | 
				
			||||||
 | 
					                            past_key.size(3))
 | 
				
			||||||
 | 
					                past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					                past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                self.kv_cache_parameters.append(past_key)
 | 
				
			||||||
 | 
					                self.kv_cache_parameters.append(past_value)
 | 
				
			||||||
 | 
					            self.kv_cache_c_parameter_handel = self.backend_cls_decode.create_parameters([p.numpy() for p in self.kv_cache_parameters])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x_np = [elem.to(torch.float16).numpy() for elem in inputs]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with record_function(f"npu_factory"):
 | 
				
			||||||
 | 
					            if not self.kv_cache_prefetched:
 | 
				
			||||||
 | 
					                self.backend_cls_decode.load_wt_fn(len(inputs), self.backend_cls_decode._mm, self.kv_cache_c_parameter_handel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for idx, elem in enumerate(x_np):
 | 
				
			||||||
 | 
					                self.backend_cls_decode.set_input_tensor(elem, idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            backend_lib.run(self.backend_cls_decode._mm,)
 | 
				
			||||||
 | 
					            ret = self.backend_cls_decode.out
 | 
				
			||||||
 | 
					            results = [adapt_output_tensor(r, r.shape, torch.float16) for r in ret]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hidden_states = results[0]
 | 
				
			||||||
 | 
					        key_value_states = results[1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len}
 | 
				
			||||||
 | 
					        for i in range(len(self.layer_indexes)):
 | 
				
			||||||
 | 
					            key_states, value_states = past_key_value.update(key_value_states[2*i],
 | 
				
			||||||
 | 
					                                                             key_value_states[2*i+1],
 | 
				
			||||||
 | 
					                                                             self.layer_indexes[i], cache_kwargs)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.backend_cls_decode.load_wt_fn(len(inputs), self.backend_cls_decode._mm, self.kv_cache_c_parameter_handel)
 | 
				
			||||||
 | 
					        self.kv_cache_prefetched = True
 | 
				
			||||||
 | 
					        outputs = (hidden_states,)
 | 
				
			||||||
 | 
					        outputs += (past_key_value,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        parameters: List[torch.Tensor],
 | 
				
			||||||
 | 
					        cached_cos,
 | 
				
			||||||
 | 
					        cached_sin,
 | 
				
			||||||
 | 
					        layer_norm_0,
 | 
				
			||||||
 | 
					        layer_norm_1,
 | 
				
			||||||
 | 
					        num_heads: int,
 | 
				
			||||||
 | 
					        num_key_value_heads: int,
 | 
				
			||||||
 | 
					        layer_idx: int,
 | 
				
			||||||
 | 
					        rms_norm_eps,
 | 
				
			||||||
 | 
					        intermediate_size,
 | 
				
			||||||
 | 
					        max_seq_len: int = 128,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.op_parameters = parameters
 | 
				
			||||||
 | 
					        self.op_id = str(uuid.uuid4())
 | 
				
			||||||
 | 
					        self.layer_idx = layer_idx
 | 
				
			||||||
 | 
					        self.max_seq_len = max_seq_len
 | 
				
			||||||
 | 
					        # self.rotary_emb = rotary_emb
 | 
				
			||||||
 | 
					        if isinstance(parameters[0], tuple):  # weight, scale from QuantizedLinear
 | 
				
			||||||
 | 
					            np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
 | 
				
			||||||
 | 
					        else:  # FP16 Linear
 | 
				
			||||||
 | 
					            np_dtype = np.float16
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.backend_cls_prefill = partial(LowBitLlamaDecoderlayer,
 | 
				
			||||||
 | 
					                                           cached_cos=cached_cos,
 | 
				
			||||||
 | 
					                                           cached_sin=cached_sin,
 | 
				
			||||||
 | 
					                                           num_heads=num_heads,
 | 
				
			||||||
 | 
					                                           num_key_value_heads=num_key_value_heads,
 | 
				
			||||||
 | 
					                                           max_seq_len=max_seq_len,
 | 
				
			||||||
 | 
					                                           rms_norm_eps=rms_norm_eps,
 | 
				
			||||||
 | 
					                                           intermediate_size=intermediate_size,
 | 
				
			||||||
 | 
					                                           mode="prefill",
 | 
				
			||||||
 | 
					                                           dtype=np_dtype)
 | 
				
			||||||
 | 
					        self.backend_cls_decode = partial(LowBitLlamaDecoderlayer,
 | 
				
			||||||
 | 
					                                          cached_cos=cached_cos,
 | 
				
			||||||
 | 
					                                          cached_sin=cached_sin,
 | 
				
			||||||
 | 
					                                          num_heads=num_heads,
 | 
				
			||||||
 | 
					                                          num_key_value_heads=num_key_value_heads,
 | 
				
			||||||
 | 
					                                          max_seq_len=max_seq_len,
 | 
				
			||||||
 | 
					                                          rms_norm_eps=rms_norm_eps,
 | 
				
			||||||
 | 
					                                          intermediate_size=intermediate_size,
 | 
				
			||||||
 | 
					                                          mode="decode",
 | 
				
			||||||
 | 
					                                          dtype=np_dtype)
 | 
				
			||||||
 | 
					        self.layer_norm_0 = layer_norm_0
 | 
				
			||||||
 | 
					        self.layer_norm_1 = layer_norm_1
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					                attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					                position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					                past_key_value: Optional[Cache] = None,
 | 
				
			||||||
 | 
					                output_attentions: bool = False,
 | 
				
			||||||
 | 
					                use_cache: bool = False,
 | 
				
			||||||
 | 
					                cache_position: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					                **kwargs,) -> torch.Tensor:
 | 
				
			||||||
 | 
					        seq_len = hidden_states.shape[1]
 | 
				
			||||||
 | 
					        # cos, sin = self.rotary_emb(hidden_states, position_ids)
 | 
				
			||||||
 | 
					        if seq_len == 1:
 | 
				
			||||||
 | 
					            backend_cls = self.backend_cls_decode
 | 
				
			||||||
 | 
					            past_key = past_key_value.key_cache[self.layer_idx]
 | 
				
			||||||
 | 
					            past_value = past_key_value.value_cache[self.layer_idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            new_size = (past_key.size(0),
 | 
				
			||||||
 | 
					                        past_key.size(1),
 | 
				
			||||||
 | 
					                        self.max_seq_len,
 | 
				
			||||||
 | 
					                        past_key.size(3))
 | 
				
			||||||
 | 
					            past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					            past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            pad_len = self.max_seq_len + 1 - attention_mask.size(-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            pad_mask = (0, pad_len)
 | 
				
			||||||
 | 
					            padded_attention_mask = F.pad(attention_mask.to(torch.float16), pad_mask,
 | 
				
			||||||
 | 
					                                    value=torch.finfo(torch.float16).min)
 | 
				
			||||||
 | 
					            padded_attention_mask[:,:,:,-1] = 0.0
 | 
				
			||||||
 | 
					            inputs = (hidden_states.to(torch.float16),
 | 
				
			||||||
 | 
					                      padded_attention_mask,
 | 
				
			||||||
 | 
					                      position_ids,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            inputs += (self.layer_norm_0, self.layer_norm_1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            inputs += (past_key, past_value)
 | 
				
			||||||
 | 
					            hidden_states, new_key, new_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=4)
 | 
				
			||||||
 | 
					            cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len}
 | 
				
			||||||
 | 
					            key_states, value_states = past_key_value.update(new_key, new_value, self.layer_idx, cache_kwargs)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            backend_cls = self.backend_cls_prefill
 | 
				
			||||||
 | 
					            inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
 | 
				
			||||||
 | 
					            inputs += (self.layer_norm_0, self.layer_norm_1)
 | 
				
			||||||
 | 
					            hidden_states, past_key, past_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=1)
 | 
				
			||||||
 | 
					            cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len}
 | 
				
			||||||
 | 
					            key_states, value_states = past_key_value.update(past_key, past_value, self.layer_idx, cache_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = (hidden_states,)
 | 
				
			||||||
 | 
					        outputs += (past_key_value,)
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for npu model')
 | 
				
			||||||
 | 
					    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
 | 
				
			||||||
 | 
					                        help='The huggingface repo id for the Llama2 model to be downloaded'
 | 
				
			||||||
 | 
					                             ', or the path to the huggingface checkpoint folder')
 | 
				
			||||||
 | 
					    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
 | 
				
			||||||
 | 
					                        help='Prompt to infer')
 | 
				
			||||||
 | 
					    parser.add_argument('--n-predict', type=int, default=32,
 | 
				
			||||||
 | 
					                        help='Max tokens to predict')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					    model_path = args.repo_id_or_model_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipeline = True # default
 | 
				
			||||||
 | 
					    max_seq_len = 1024 # default
 | 
				
			||||||
 | 
					    if pipeline:
 | 
				
			||||||
 | 
					        os.environ['MASTER_ADDR'] = '127.0.0.1'
 | 
				
			||||||
 | 
					        os.environ['MASTER_PORT'] = '29501'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        dist.init_process_group()
 | 
				
			||||||
 | 
					        my_rank = dist.get_rank()
 | 
				
			||||||
 | 
					        my_size = dist.get_world_size()
 | 
				
			||||||
 | 
					        logger.info(f"rank: {my_rank}, size: {my_size}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, attn_implementation="eager",
 | 
				
			||||||
 | 
					                                                     load_in_low_bit="sym_int4", pipeline_parallel_stages=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if my_rank == 0:
 | 
				
			||||||
 | 
					            print(model)
 | 
				
			||||||
 | 
					        dist.barrier()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if my_rank == 1:
 | 
				
			||||||
 | 
					            print(model)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, attn_implementation="eager",
 | 
				
			||||||
 | 
					                                                     load_in_low_bit="sym_int4")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if pipeline:
 | 
				
			||||||
 | 
					        layer_start = model.layer_start
 | 
				
			||||||
 | 
					        layer_end = model.layer_end
 | 
				
			||||||
 | 
					        num_layers = model.num_layers
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        layer_start = 0
 | 
				
			||||||
 | 
					        layer_end = 32
 | 
				
			||||||
 | 
					        num_layers = 32
 | 
				
			||||||
 | 
					    num_heads = model.model.layers[layer_start].self_attn.num_heads
 | 
				
			||||||
 | 
					    num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
 | 
				
			||||||
 | 
					    head_dim = model.model.layers[layer_start].self_attn.head_dim
 | 
				
			||||||
 | 
					    rms_norm_eps = model.config.rms_norm_eps
 | 
				
			||||||
 | 
					    intermediate_size = model.config.intermediate_size
 | 
				
			||||||
 | 
					    deocderlayers = []
 | 
				
			||||||
 | 
					    layer_weights = []
 | 
				
			||||||
 | 
					    input_layer_norm_weights = []
 | 
				
			||||||
 | 
					    post_attn_layernorm_weights = []
 | 
				
			||||||
 | 
					    layer_indexs = range(layer_start, layer_end)
 | 
				
			||||||
 | 
					    for layer_idx in layer_indexs:
 | 
				
			||||||
 | 
					        curr_layer = model.model.layers[layer_idx]
 | 
				
			||||||
 | 
					        attn_layer = curr_layer.self_attn
 | 
				
			||||||
 | 
					        mlp_layer = curr_layer.mlp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        weights = [
 | 
				
			||||||
 | 
					            # model.model.layers[i].input_layernorm.weight.to(torch.float16),
 | 
				
			||||||
 | 
					            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
 | 
				
			||||||
 | 
					            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
 | 
				
			||||||
 | 
					            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
 | 
				
			||||||
 | 
					            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
 | 
				
			||||||
 | 
					            # model.model.layers[i].post_attention_layernorm.weight.to(torch.float16),
 | 
				
			||||||
 | 
					            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
 | 
				
			||||||
 | 
					            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
 | 
				
			||||||
 | 
					            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
				
			||||||
 | 
					        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
 | 
				
			||||||
 | 
					        layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_decoderlayer = FusedLlamaLowBitDecoderlayer(weights,
 | 
				
			||||||
 | 
					                                            num_heads=num_heads,
 | 
				
			||||||
 | 
					                                            num_key_value_heads=num_key_value_heads,
 | 
				
			||||||
 | 
					                                            cached_cos=cached_cos,
 | 
				
			||||||
 | 
					                                            cached_sin=cached_sin,
 | 
				
			||||||
 | 
					                                            # rotary_emb=model.model.layers[i].self_attn.rotary_emb,
 | 
				
			||||||
 | 
					                                            layer_norm_0=layer_norm_0,
 | 
				
			||||||
 | 
					                                            layer_norm_1=layer_norm_1,
 | 
				
			||||||
 | 
					                                            layer_idx=layer_idx,
 | 
				
			||||||
 | 
					                                            rms_norm_eps=rms_norm_eps,
 | 
				
			||||||
 | 
					                                            intermediate_size=intermediate_size,
 | 
				
			||||||
 | 
					                                            max_seq_len=max_seq_len)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        layer_weights.extend(weights)
 | 
				
			||||||
 | 
					        input_layer_norm_weights.append(layer_norm_0)
 | 
				
			||||||
 | 
					        post_attn_layernorm_weights.append(layer_norm_1)
 | 
				
			||||||
 | 
					        model.model.layers[layer_idx] = new_decoderlayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    multi_decoder = FusedLlamaLowBitMultiDecoderlayer(
 | 
				
			||||||
 | 
					        parameters=layer_weights,
 | 
				
			||||||
 | 
					        input_laynorm_weights=input_layer_norm_weights,
 | 
				
			||||||
 | 
					        post_attn_layernorm_weights=post_attn_layernorm_weights,
 | 
				
			||||||
 | 
					        layer_indexes=layer_indexs,
 | 
				
			||||||
 | 
					        cached_cos=cached_cos,
 | 
				
			||||||
 | 
					        cached_sin=cached_sin,
 | 
				
			||||||
 | 
					        num_heads=num_heads,
 | 
				
			||||||
 | 
					        head_dim=head_dim,
 | 
				
			||||||
 | 
					        num_key_value_heads=num_key_value_heads,
 | 
				
			||||||
 | 
					        rms_norm_eps=rms_norm_eps,
 | 
				
			||||||
 | 
					        intermediate_size=intermediate_size,
 | 
				
			||||||
 | 
					        max_seq_len=max_seq_len,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model.model.multi_decoder = multi_decoder
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with torch.inference_mode():
 | 
				
			||||||
 | 
					        input_ids = tokenizer.encode(args.prompt, return_tensors="pt")
 | 
				
			||||||
 | 
					        print("finish to load")
 | 
				
			||||||
 | 
					        print('input length:', len(input_ids[0]))
 | 
				
			||||||
 | 
					        for i in range(3):
 | 
				
			||||||
 | 
					            st = time.time()
 | 
				
			||||||
 | 
					            output = model.generate(input_ids, num_beams=1, do_sample=False, max_new_tokens=args.n_predict)
 | 
				
			||||||
 | 
					            end = time.time()
 | 
				
			||||||
 | 
					            if my_rank == 0:
 | 
				
			||||||
 | 
					                print(f"First token cost: {model.first_token_time} s, rest tokens cost average: {model.rest_cost_mean} s")
 | 
				
			||||||
 | 
					                print(f'Inference time: {end-st} s')
 | 
				
			||||||
 | 
					                output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
				
			||||||
 | 
					                print('-'*20, 'Prompt', '-'*20)
 | 
				
			||||||
 | 
					                print(args.prompt)
 | 
				
			||||||
 | 
					                print('-'*20, 'Output', '-'*20)
 | 
				
			||||||
 | 
					                print(output_str)
 | 
				
			||||||
| 
						 | 
					@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
					from ipex_llm.utils.common.log4Error import invalidInputError
 | 
				
			||||||
from ipex_llm.transformers.utils import logger
 | 
					from ipex_llm.transformers.utils import logger
 | 
				
			||||||
from ipex_llm.transformers.npu_models.convert import optimize_llm
 | 
					from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def patch_flash_attn_import(filename: str) -> List[str]:
 | 
					def patch_flash_attn_import(filename: str) -> List[str]:
 | 
				
			||||||
| 
						 | 
					@ -84,7 +84,7 @@ class _BaseAutoModelClass:
 | 
				
			||||||
            warnings.warn("`device_map` will be ignored")
 | 
					            warnings.warn("`device_map` will be ignored")
 | 
				
			||||||
        kwargs['device_map'] = 'cpu'
 | 
					        kwargs['device_map'] = 'cpu'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float]:
 | 
					        if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float, torch.float16]:
 | 
				
			||||||
            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
 | 
					            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
 | 
				
			||||||
        kwargs['torch_dtype'] = torch.float
 | 
					        kwargs['torch_dtype'] = torch.float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -114,7 +114,7 @@ class _BaseAutoModelClass:
 | 
				
			||||||
        ignore_argument(kwargs, "modules_to_not_convert")
 | 
					        ignore_argument(kwargs, "modules_to_not_convert")
 | 
				
			||||||
        ignore_argument(kwargs, "quantization_config")
 | 
					        ignore_argument(kwargs, "quantization_config")
 | 
				
			||||||
        ignore_argument(kwargs, "speculative")
 | 
					        ignore_argument(kwargs, "speculative")
 | 
				
			||||||
        ignore_argument(kwargs, "pipeline_parallel_stages")
 | 
					        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        _args = copy.deepcopy(args)
 | 
					        _args = copy.deepcopy(args)
 | 
				
			||||||
        _kwargs = copy.deepcopy(kwargs)
 | 
					        _kwargs = copy.deepcopy(kwargs)
 | 
				
			||||||
| 
						 | 
					@ -131,12 +131,28 @@ class _BaseAutoModelClass:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logger.info(f"Converting model, it may takes up to several minutes ...")
 | 
					        logger.info(f"Converting model, it may takes up to several minutes ...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if pipeline_parallel_stages > 1:
 | 
				
			||||||
 | 
					            invalidInputError(torch.distributed.get_world_size() == pipeline_parallel_stages,
 | 
				
			||||||
 | 
					                              "Please make sure world size is same as `pipeline_parallel_stages`")
 | 
				
			||||||
 | 
					            kwargs['torch_dtype'] = torch.float16
 | 
				
			||||||
 | 
					            from .npu_models.pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
 | 
				
			||||||
 | 
					            model = pipeline_parallel(model, pipeline_parallel_stages,
 | 
				
			||||||
 | 
					                                      kwargs["torch_dtype"], device="cpu")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # add pipeline_parallel_generate to pretrained model dynamically
 | 
				
			||||||
 | 
					            model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
 | 
				
			||||||
 | 
					                                                                model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        from intel_npu_acceleration_library.compiler import create_npu_kernels
 | 
					        from intel_npu_acceleration_library.compiler import create_npu_kernels
 | 
				
			||||||
        with torch.no_grad():
 | 
					        with torch.no_grad():
 | 
				
			||||||
            optimize_llm(model)
 | 
					            optimize_llm(model)
 | 
				
			||||||
 | 
					            if pipeline_parallel_stages == 1:
 | 
				
			||||||
                cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
 | 
					                cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
 | 
				
			||||||
                create_npu_kernels(model)
 | 
					                create_npu_kernels(model)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                cls.load_convert(qtype, model.model, 'cpu', *args, **kwargs)
 | 
				
			||||||
 | 
					                create_npu_kernels(model.model)
 | 
				
			||||||
 | 
					                optimize_llm_post(model)
 | 
				
			||||||
        model = model.eval()
 | 
					        model = model.eval()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logger.info(f"Finish to convert model")
 | 
					        logger.info(f"Finish to convert model")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -63,7 +63,8 @@ def replace_with_QuantizedLinear(layer, qtype, device):
 | 
				
			||||||
               (layer.in_features == 18944 and layer.out_features == 3584):
 | 
					               (layer.in_features == 18944 and layer.out_features == 3584):
 | 
				
			||||||
                qtype = "sym_int8_rtn"
 | 
					                qtype = "sym_int8_rtn"
 | 
				
			||||||
                iqtype = ggml_tensor_qtype[qtype]
 | 
					                iqtype = ggml_tensor_qtype[qtype]
 | 
				
			||||||
        qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, device=device)
 | 
					        qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
 | 
				
			||||||
 | 
					                                             iqtype, device=device)
 | 
				
			||||||
        return QuantizedLinear(qweights, scale, layer.bias)
 | 
					        return QuantizedLinear(qweights, scale, layer.bias)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -79,15 +80,19 @@ def optimize_llm(model: torch.nn.Module):
 | 
				
			||||||
    if model.config.model_type == "llama":
 | 
					    if model.config.model_type == "llama":
 | 
				
			||||||
        from ipex_llm.transformers.npu_models.llama import merge_qkv
 | 
					        from ipex_llm.transformers.npu_models.llama import merge_qkv
 | 
				
			||||||
        from ipex_llm.transformers.npu_models.llama import merge_mlp
 | 
					        from ipex_llm.transformers.npu_models.llama import merge_mlp
 | 
				
			||||||
        model.apply(merge_qkv)
 | 
					 | 
				
			||||||
        model.apply(merge_mlp)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from ipex_llm.transformers.npu_models.llama import llama_model_forward
 | 
					        from ipex_llm.transformers.npu_models.llama import llama_model_forward
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.npu_models.llama import llama_fused_model_forward
 | 
				
			||||||
        from ipex_llm.transformers.npu_models.llama import llama_attention_forward
 | 
					        from ipex_llm.transformers.npu_models.llama import llama_attention_forward
 | 
				
			||||||
        from ipex_llm.transformers.npu_models.llama import llama_mlp_forward
 | 
					        from ipex_llm.transformers.npu_models.llama import llama_mlp_forward
 | 
				
			||||||
        from transformers.models.llama.modeling_llama import LlamaModel
 | 
					        from transformers.models.llama.modeling_llama import LlamaModel
 | 
				
			||||||
        from transformers.models.llama.modeling_llama import LlamaAttention
 | 
					        from transformers.models.llama.modeling_llama import LlamaAttention
 | 
				
			||||||
        from transformers.models.llama.modeling_llama import LlamaMLP
 | 
					        from transformers.models.llama.modeling_llama import LlamaMLP
 | 
				
			||||||
 | 
					        if hasattr(model, 'pipeline_parallel_stages'):
 | 
				
			||||||
 | 
					            # experimental support for fused decoderlayer implementation
 | 
				
			||||||
 | 
					            convert_forward(model, LlamaModel, llama_fused_model_forward)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            model.apply(merge_qkv)
 | 
				
			||||||
 | 
					            model.apply(merge_mlp)
 | 
				
			||||||
            convert_forward(model, LlamaModel, llama_model_forward)
 | 
					            convert_forward(model, LlamaModel, llama_model_forward)
 | 
				
			||||||
            convert_forward(model, LlamaAttention, llama_attention_forward)
 | 
					            convert_forward(model, LlamaAttention, llama_attention_forward)
 | 
				
			||||||
            convert_forward(model, LlamaMLP, llama_mlp_forward)
 | 
					            convert_forward(model, LlamaMLP, llama_mlp_forward)
 | 
				
			||||||
| 
						 | 
					@ -207,3 +212,28 @@ def optimize_llm(model: torch.nn.Module):
 | 
				
			||||||
        from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward
 | 
					        from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        convert_forward(model, module.Phi3Attention, phi3_attention_forward)
 | 
					        convert_forward(model, module.Phi3Attention, phi3_attention_forward)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def optimize_llm_post(model: torch.nn.Module):
 | 
				
			||||||
 | 
					    # experimental support for fused decoderlayer implementation
 | 
				
			||||||
 | 
					    if model.config.model_type == "llama":
 | 
				
			||||||
 | 
					        model.model.embed_tokens.to(torch.float32)
 | 
				
			||||||
 | 
					        model.model.norm.to(torch.float32)
 | 
				
			||||||
 | 
					        model.lm_head.to(torch.float32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
 | 
				
			||||||
 | 
					            ggml_tensor_qtype, FP4Params
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(model.lm_head, torch.nn.Linear):
 | 
				
			||||||
 | 
					            new_linear = LowBitLinear(model.lm_head.in_features,
 | 
				
			||||||
 | 
					                                      model.lm_head.out_features,
 | 
				
			||||||
 | 
					                                      ggml_tensor_qtype["sym_int4"],
 | 
				
			||||||
 | 
					                                      False)
 | 
				
			||||||
 | 
					            paramsLowBit = FP4Params(data=model.lm_head.weight.data,
 | 
				
			||||||
 | 
					                                     requires_grad=False,
 | 
				
			||||||
 | 
					                                     quantized=False,
 | 
				
			||||||
 | 
					                                     _shape=None,
 | 
				
			||||||
 | 
					                                     qtype=ggml_tensor_qtype["sym_int4"],
 | 
				
			||||||
 | 
					                                     in_features=model.lm_head.in_features).to("cpu")
 | 
				
			||||||
 | 
					            new_linear._parameters['weight'] = paramsLowBit
 | 
				
			||||||
 | 
					            model.lm_head = new_linear
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										115
									
								
								python/llm/src/ipex_llm/transformers/npu_models/kv.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								python/llm/src/ipex_llm/transformers/npu_models/kv.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,115 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from typing import Optional, Dict, Tuple, Any
 | 
				
			||||||
 | 
					from transformers.cache_utils import DynamicCache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init_fused_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
 | 
				
			||||||
 | 
					    key_cache_storage = torch.zeros(batch_size, num_heads,
 | 
				
			||||||
 | 
					                                    max_length, head_dim,
 | 
				
			||||||
 | 
					                                    dtype=dtype, device=device)
 | 
				
			||||||
 | 
					    value_cache_storage = torch.zeros(batch_size, num_heads,
 | 
				
			||||||
 | 
					                                      max_length, head_dim,
 | 
				
			||||||
 | 
					                                      dtype=dtype, device=device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    key_cache = key_cache_storage.as_strided((batch_size, num_heads,
 | 
				
			||||||
 | 
					                                             current_length, head_dim),
 | 
				
			||||||
 | 
					                                             key_cache_storage.stride(),
 | 
				
			||||||
 | 
					                                             storage_offset=0)
 | 
				
			||||||
 | 
					    value_cache = value_cache_storage.as_strided((batch_size, num_heads,
 | 
				
			||||||
 | 
					                                                 current_length, head_dim),
 | 
				
			||||||
 | 
					                                                 value_cache_storage.stride(),
 | 
				
			||||||
 | 
					                                                 storage_offset=0)
 | 
				
			||||||
 | 
					    return key_cache, value_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def append_fused_kv_cache(cache_k, cache_v, key_states, value_states):
 | 
				
			||||||
 | 
					    new_size = (cache_k.size(0),
 | 
				
			||||||
 | 
					                cache_k.size(1),
 | 
				
			||||||
 | 
					                cache_k.size(2) + key_states.size(2),
 | 
				
			||||||
 | 
					                cache_k.size(3))
 | 
				
			||||||
 | 
					    new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					    new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
 | 
				
			||||||
 | 
					    new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					    new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states
 | 
				
			||||||
 | 
					    return new_cache_k, new_cache_v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DynamicFusedNormalCache(DynamicCache):
 | 
				
			||||||
 | 
					    # Experimental support for fused decoderlayer implementation on NPU
 | 
				
			||||||
 | 
					    # Currently only for llama2
 | 
				
			||||||
 | 
					    KV_ALLOC_BLOCK_LENGTH = 256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self) -> None:
 | 
				
			||||||
 | 
					        self.key_cache: Dict[int, torch.Tensor] = {}
 | 
				
			||||||
 | 
					        self.value_cache: Dict[int, torch.Tensor] = {}
 | 
				
			||||||
 | 
					        self._seen_tokens = 0  # Used in `generate` to keep how many tokens the cache has seen
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        key_states: torch.Tensor,
 | 
				
			||||||
 | 
					        value_states: torch.Tensor,
 | 
				
			||||||
 | 
					        layer_idx: int,
 | 
				
			||||||
 | 
					        cache_kwargs: Optional[Dict[str, Any]]=None,
 | 
				
			||||||
 | 
					    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_size, num_heads, seq_len, head_dim = key_states.shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        max_seq_length = cache_kwargs.pop("max_seq_len", None)
 | 
				
			||||||
 | 
					        transpose_value = cache_kwargs.pop("transpose_value", None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if layer_idx == 0 or layer_idx == 16:
 | 
				
			||||||
 | 
					            if hasattr(self, "_seen_tokens"):
 | 
				
			||||||
 | 
					                # 4.39 uses `_seen_tokens`
 | 
				
			||||||
 | 
					                self._seen_tokens += seq_len
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # 4.37 uses `seen_tokens`
 | 
				
			||||||
 | 
					                self.seen_tokens += seq_len
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Update the cache
 | 
				
			||||||
 | 
					        # if len(self.key_cache) <= layer_idx:
 | 
				
			||||||
 | 
					        if layer_idx not in self.key_cache:
 | 
				
			||||||
 | 
					            max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + \
 | 
				
			||||||
 | 
					                self.KV_ALLOC_BLOCK_LENGTH
 | 
				
			||||||
 | 
					            k_cache, v_cache = init_fused_kv_cache(
 | 
				
			||||||
 | 
					                batch_size, num_heads, head_dim,
 | 
				
			||||||
 | 
					                0, max_len,
 | 
				
			||||||
 | 
					                key_states.dtype, key_states.device,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.key_cache[layer_idx] = k_cache
 | 
				
			||||||
 | 
					            self.value_cache[layer_idx] = v_cache
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            k_cache = self.key_cache[layer_idx]
 | 
				
			||||||
 | 
					            v_cache = self.value_cache[layer_idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            kv_seq_len = k_cache.size(2) + key_states.size(2)
 | 
				
			||||||
 | 
					            k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states)
 | 
				
			||||||
 | 
					            self.key_cache[layer_idx] = k_cache
 | 
				
			||||||
 | 
					            self.value_cache[layer_idx] = v_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
 | 
				
			||||||
 | 
					        """Returns the sequence length of the cached states.
 | 
				
			||||||
 | 
					        A layer index can be optionally passed."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for idx, layer in self.key_cache.items():
 | 
				
			||||||
 | 
					            return layer.shape[-2]
 | 
				
			||||||
| 
						 | 
					@ -182,6 +182,137 @@ def llama_model_forward(
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def llama_fused_model_forward(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    input_ids: torch.LongTensor = None,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
				
			||||||
 | 
					    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
				
			||||||
 | 
					    use_cache: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_attentions: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_hidden_states: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_dict: Optional[bool] = None,
 | 
				
			||||||
 | 
					    cache_position: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
				
			||||||
 | 
					    output_attentions = (
 | 
				
			||||||
 | 
					        output_attentions if output_attentions is not None
 | 
				
			||||||
 | 
					        else self.config.output_attentions
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    output_hidden_states = (
 | 
				
			||||||
 | 
					        output_hidden_states if output_hidden_states is not None
 | 
				
			||||||
 | 
					        else self.config.output_hidden_states
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
				
			||||||
 | 
					    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (input_ids is None) ^ (inputs_embeds is not None):
 | 
				
			||||||
 | 
					        invalidInputError(False,
 | 
				
			||||||
 | 
					                          ("You cannot specify both input_ids and inputs_embeds at the same time, "
 | 
				
			||||||
 | 
					                           "and must specify either one"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if self.gradient_checkpointing and self.training and use_cache:
 | 
				
			||||||
 | 
					        use_cache = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if inputs_embeds is None:
 | 
				
			||||||
 | 
					        inputs_embeds = self.embed_tokens(input_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    past_seen_tokens = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ipex-llm changes start
 | 
				
			||||||
 | 
					    from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
 | 
				
			||||||
 | 
					    if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
 | 
				
			||||||
 | 
					        past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
 | 
				
			||||||
 | 
					        past_seen_tokens = past_key_values.get_seq_length()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if cache_position is None:
 | 
				
			||||||
 | 
					        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
 | 
				
			||||||
 | 
					                                      device=inputs_embeds.device)
 | 
				
			||||||
 | 
					    # ipex-llm changes end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if position_ids is None:
 | 
				
			||||||
 | 
					        position_ids = cache_position.unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
 | 
				
			||||||
 | 
					                                           cache_position, past_seen_tokens)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # embed positions
 | 
				
			||||||
 | 
					    hidden_states = inputs_embeds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # decoder layers
 | 
				
			||||||
 | 
					    all_hidden_states = () if output_hidden_states else None
 | 
				
			||||||
 | 
					    all_self_attns = () if output_attentions else None
 | 
				
			||||||
 | 
					    next_decoder_cache = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    seq_len = hidden_states.size(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if seq_len == 1:
 | 
				
			||||||
 | 
					        # multi_decoder = self.layers[(self.layer_end + 1) % num_layers]
 | 
				
			||||||
 | 
					        layer_outputs = self.multi_decoder(hidden_states,
 | 
				
			||||||
 | 
					                                           attention_mask=causal_mask,
 | 
				
			||||||
 | 
					                                           position_ids=position_ids,
 | 
				
			||||||
 | 
					                                           past_key_value=past_key_values,
 | 
				
			||||||
 | 
					                                           output_attentions=output_attentions,
 | 
				
			||||||
 | 
					                                           use_cache=use_cache,
 | 
				
			||||||
 | 
					                                           cache_position=cache_position,)
 | 
				
			||||||
 | 
					        hidden_states = layer_outputs[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        next_decoder_cache = layer_outputs[1]
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        for decoder_layer in self.layers:
 | 
				
			||||||
 | 
					            if output_hidden_states:
 | 
				
			||||||
 | 
					                all_hidden_states += (hidden_states,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.gradient_checkpointing and self.training:
 | 
				
			||||||
 | 
					                layer_outputs = self._gradient_checkpointing_func(
 | 
				
			||||||
 | 
					                    decoder_layer.__call__,
 | 
				
			||||||
 | 
					                    hidden_states,
 | 
				
			||||||
 | 
					                    causal_mask,
 | 
				
			||||||
 | 
					                    position_ids,
 | 
				
			||||||
 | 
					                    past_key_values,
 | 
				
			||||||
 | 
					                    output_attentions,
 | 
				
			||||||
 | 
					                    use_cache,
 | 
				
			||||||
 | 
					                    cache_position,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                layer_outputs = decoder_layer(
 | 
				
			||||||
 | 
					                    hidden_states,
 | 
				
			||||||
 | 
					                    attention_mask=causal_mask,
 | 
				
			||||||
 | 
					                    position_ids=position_ids,
 | 
				
			||||||
 | 
					                    past_key_value=past_key_values,
 | 
				
			||||||
 | 
					                    output_attentions=output_attentions,
 | 
				
			||||||
 | 
					                    use_cache=use_cache,
 | 
				
			||||||
 | 
					                    cache_position=cache_position,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            hidden_states = layer_outputs[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if use_cache:
 | 
				
			||||||
 | 
					                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if output_attentions:
 | 
				
			||||||
 | 
					                all_self_attns += (layer_outputs[1],)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hidden_states = self.norm(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # add hidden states from the last decoder layer
 | 
				
			||||||
 | 
					    if output_hidden_states:
 | 
				
			||||||
 | 
					        all_hidden_states += (hidden_states,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ipex-llm changes start
 | 
				
			||||||
 | 
					    next_cache = next_decoder_cache if use_cache else None
 | 
				
			||||||
 | 
					    # ipex-llm changes end
 | 
				
			||||||
 | 
					    if not return_dict:
 | 
				
			||||||
 | 
					        return tuple(v for v in [hidden_states, next_cache,
 | 
				
			||||||
 | 
					                                 all_hidden_states, all_self_attns] if v is not None)
 | 
				
			||||||
 | 
					    return BaseModelOutputWithPast(
 | 
				
			||||||
 | 
					        last_hidden_state=hidden_states,
 | 
				
			||||||
 | 
					        past_key_values=next_cache,
 | 
				
			||||||
 | 
					        hidden_states=all_hidden_states,
 | 
				
			||||||
 | 
					        attentions=all_self_attns,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def llama_attention_forward(
 | 
					def llama_attention_forward(
 | 
				
			||||||
    self,
 | 
					    self,
 | 
				
			||||||
    hidden_states: torch.Tensor,
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,639 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Some parts of this file is adapted from
 | 
				
			||||||
 | 
					# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					from torch.nn import CrossEntropyLoss
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					import torch.distributed as dist
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from typing import Callable, List, Optional, Union, Tuple
 | 
				
			||||||
 | 
					from types import SimpleNamespace
 | 
				
			||||||
 | 
					import transformers
 | 
				
			||||||
 | 
					from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 | 
				
			||||||
 | 
					from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 | 
				
			||||||
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					from ipex_llm.ggml.quantize import ggml_tensor_qtype
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# patch GenerationMixin.generate
 | 
				
			||||||
 | 
					from transformers import GenerationMixin
 | 
				
			||||||
 | 
					original_generate = GenerationMixin.generate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DummyLayer(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *args):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
 | 
				
			||||||
 | 
					        # python/llm/src/ipex_llm/transformers/models/llama.py#L2076
 | 
				
			||||||
 | 
					        self.weight = nn.Parameter(torch.empty(0,), requires_grad=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Dummy_MLPLayer(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *args):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
 | 
				
			||||||
 | 
					        # python/llm/src/ipex_llm/transformers/models/llama.py#L119
 | 
				
			||||||
 | 
					        self.up_proj = DummyLayer()
 | 
				
			||||||
 | 
					        self.down_proj = DummyLayer()
 | 
				
			||||||
 | 
					        self.shared_expert = SimpleNamespace()
 | 
				
			||||||
 | 
					        self.shared_expert.up_proj = DummyLayer()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Dummy_DecoderLayer(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *args):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        # to avoid AttributeError
 | 
				
			||||||
 | 
					        self.input_layernorm = DummyLayer()
 | 
				
			||||||
 | 
					        self.mlp = Dummy_MLPLayer()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, hidden_states, *args, **kwargs):
 | 
				
			||||||
 | 
					        past_key_value = kwargs.get('past_key_value', None)
 | 
				
			||||||
 | 
					        use_cache = kwargs.get('use_cache', False)
 | 
				
			||||||
 | 
					        outputs = (hidden_states,)
 | 
				
			||||||
 | 
					        if use_cache:
 | 
				
			||||||
 | 
					            outputs += (past_key_value,)
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Dummy_GLMBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *args):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        # to avoid AttributeError
 | 
				
			||||||
 | 
					        self.input_layernorm = DummyLayer()
 | 
				
			||||||
 | 
					        self.mlp = Dummy_MLPLayer()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(
 | 
				
			||||||
 | 
					            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        if kv_cache is None:
 | 
				
			||||||
 | 
					            return hidden_states, ()
 | 
				
			||||||
 | 
					        return hidden_states, kv_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init_pipeline_parallel():
 | 
				
			||||||
 | 
					    import oneccl_bindings_for_pytorch
 | 
				
			||||||
 | 
					    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
 | 
				
			||||||
 | 
					    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
 | 
				
			||||||
 | 
					    dist.init_process_group('ccl')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def low_mem_convert(model):
 | 
				
			||||||
 | 
					    from ipex_llm.transformers.convert import convert_forward
 | 
				
			||||||
 | 
					    import importlib
 | 
				
			||||||
 | 
					    if 'llama' in model.config.model_type:
 | 
				
			||||||
 | 
					        convert_forward(
 | 
				
			||||||
 | 
					            model,
 | 
				
			||||||
 | 
					            transformers.models.llama.modeling_llama.LlamaForCausalLM,
 | 
				
			||||||
 | 
					            llama_causallm_forward_4_37_lowmem)
 | 
				
			||||||
 | 
					    elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"):
 | 
				
			||||||
 | 
					        if model.config.num_layers == 40:
 | 
				
			||||||
 | 
					            # for glm4-9b
 | 
				
			||||||
 | 
					            modeling_module_name = model.__class__.__module__
 | 
				
			||||||
 | 
					            module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
 | 
					            convert_forward(
 | 
				
			||||||
 | 
					                model,
 | 
				
			||||||
 | 
					                module.ChatGLMForConditionalGeneration,
 | 
				
			||||||
 | 
					                glm4_conditional_generation_forward_lowmem)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # for chatglm3-6b
 | 
				
			||||||
 | 
					            modeling_module_name = model.__class__.__module__
 | 
				
			||||||
 | 
					            module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
 | 
					            convert_forward(
 | 
				
			||||||
 | 
					                model,
 | 
				
			||||||
 | 
					                module.ChatGLMForConditionalGeneration,
 | 
				
			||||||
 | 
					                chatglm3_conditional_generation_forward_lowmem)
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32, device=None):
 | 
				
			||||||
 | 
					    global num_layers
 | 
				
			||||||
 | 
					    if hasattr(model.config, 'num_hidden_layers'):
 | 
				
			||||||
 | 
					        num_layers = model.config.num_hidden_layers
 | 
				
			||||||
 | 
					    elif hasattr(model.config, 'num_layers'):
 | 
				
			||||||
 | 
					        # for chatglm3-6b
 | 
				
			||||||
 | 
					        num_layers = model.config.num_layers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    slice_size = (num_layers + pipeline_parallel_stages - 1) // pipeline_parallel_stages
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    local_rank = dist.get_rank()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    global layer_start
 | 
				
			||||||
 | 
					    global layer_end
 | 
				
			||||||
 | 
					    layer_start = slice_size * local_rank
 | 
				
			||||||
 | 
					    layer_end = layer_start + min(slice_size, num_layers - layer_start)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
 | 
				
			||||||
 | 
					        # for Qwen-VL-Chat
 | 
				
			||||||
 | 
					        for i in range(num_layers):
 | 
				
			||||||
 | 
					            if i < layer_start or i >= layer_end:
 | 
				
			||||||
 | 
					                model._modules['transformer'].h[i] = Dummy_DecoderLayer()
 | 
				
			||||||
 | 
					        if local_rank != 0:
 | 
				
			||||||
 | 
					            model._modules['transformer'].wte = DummyLayer()
 | 
				
			||||||
 | 
					            model._modules['transformer'].drop = DummyLayer()
 | 
				
			||||||
 | 
					        if local_rank != pipeline_parallel_stages - 1:
 | 
				
			||||||
 | 
					            model._modules['transformer'].ln_f = DummyLayer()
 | 
				
			||||||
 | 
					            model._modules['ln_f'] = DummyLayer()
 | 
				
			||||||
 | 
					            model._modules['lm_head'] = DummyLayer()
 | 
				
			||||||
 | 
					    elif model.config.model_type == "chatglm":
 | 
				
			||||||
 | 
					        # for chatglm3-6b, glm-4-9b-chat
 | 
				
			||||||
 | 
					        for i in range(num_layers):
 | 
				
			||||||
 | 
					            if i < layer_start or i >= layer_end:
 | 
				
			||||||
 | 
					                model._modules['transformer'].encoder.layers[i] = Dummy_GLMBlock()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                model._modules['transformer'].encoder.layers[i].self_attention.num_layers = \
 | 
				
			||||||
 | 
					                    i - layer_start
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if local_rank != 0:
 | 
				
			||||||
 | 
					            model._modules['transformer'].embedding = DummyLayer()
 | 
				
			||||||
 | 
					        if local_rank != pipeline_parallel_stages - 1:
 | 
				
			||||||
 | 
					            model._modules['transformer'].encoder.final_layernorm = DummyLayer()
 | 
				
			||||||
 | 
					            model._modules['transformer'].output_layer = DummyLayer()
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        for i in range(num_layers):
 | 
				
			||||||
 | 
					            if i < layer_start or i >= layer_end:
 | 
				
			||||||
 | 
					                model._modules['model'].layers[i] = Dummy_DecoderLayer()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if local_rank != 0:
 | 
				
			||||||
 | 
					            model._modules['model'].embed_tokens = DummyLayer()
 | 
				
			||||||
 | 
					        if local_rank != pipeline_parallel_stages - 1:
 | 
				
			||||||
 | 
					            model._modules['model'].norm = DummyLayer()
 | 
				
			||||||
 | 
					            model._modules['lm_head'] = DummyLayer()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM')
 | 
				
			||||||
 | 
					    _enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1")
 | 
				
			||||||
 | 
					    if _enable_lowmem:
 | 
				
			||||||
 | 
					        model = low_mem_convert(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model.pipeline_parallel_stages = pipeline_parallel_stages
 | 
				
			||||||
 | 
					    model.layer_start = layer_start
 | 
				
			||||||
 | 
					    model.layer_end = layer_end
 | 
				
			||||||
 | 
					    model.num_layers = num_layers
 | 
				
			||||||
 | 
					    if torch_dtype == torch.float16:
 | 
				
			||||||
 | 
					        model = model.half()
 | 
				
			||||||
 | 
					    if device is None:
 | 
				
			||||||
 | 
					        model = model.to(f'xpu:{local_rank}')
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        model.to(device)
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def generate(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    inputs: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    generation_config: Optional[GenerationConfig] = None,
 | 
				
			||||||
 | 
					    logits_processor: Optional[LogitsProcessorList] = None,
 | 
				
			||||||
 | 
					    stopping_criteria: Optional[StoppingCriteriaList] = None,
 | 
				
			||||||
 | 
					    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
 | 
				
			||||||
 | 
					    synced_gpus: Optional[bool] = None,
 | 
				
			||||||
 | 
					    assistant_model: Optional["PreTrainedModel"] = None,
 | 
				
			||||||
 | 
					    streamer: Optional["BaseStreamer"] = None,
 | 
				
			||||||
 | 
					    **kwargs,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
 | 
				
			||||||
 | 
					        # priority: `generation_config` argument > `model.generation_config`
 | 
				
			||||||
 | 
					        if generation_config is None:
 | 
				
			||||||
 | 
					            if (
 | 
				
			||||||
 | 
					                self.generation_config._from_model_config
 | 
				
			||||||
 | 
					                and self.generation_config._original_object_hash == hash(self.generation_config)
 | 
				
			||||||
 | 
					                and self.config._has_non_default_generation_parameters()
 | 
				
			||||||
 | 
					            ):
 | 
				
			||||||
 | 
					                new_generation_config = GenerationConfig.from_model_config(self.config)
 | 
				
			||||||
 | 
					                if new_generation_config != self.generation_config:
 | 
				
			||||||
 | 
					                    self.generation_config = new_generation_config
 | 
				
			||||||
 | 
					            generation_config = self.generation_config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
 | 
				
			||||||
 | 
					            eos_token_id = generation_config.eos_token_id
 | 
				
			||||||
 | 
					            if isinstance(eos_token_id, list):
 | 
				
			||||||
 | 
					                eos_token_id = eos_token_id[0]
 | 
				
			||||||
 | 
					            logger.warning("Setting `pad_token_id` to `eos_token_id`: "
 | 
				
			||||||
 | 
					                           f"{eos_token_id} for open-end generation.")
 | 
				
			||||||
 | 
					            generation_config.pad_token_id = eos_token_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if generation_config is not None and generation_config.max_new_tokens is not None:
 | 
				
			||||||
 | 
					            max_new_tokens = generation_config.pop("max_new_tokens")
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            max_new_tokens = kwargs.pop("max_new_tokens", None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.pipeline_parallel_generate(inputs=inputs,
 | 
				
			||||||
 | 
					                                               max_new_tokens=max_new_tokens,
 | 
				
			||||||
 | 
					                                               generation_config=generation_config,
 | 
				
			||||||
 | 
					                                               **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return original_generate(self,
 | 
				
			||||||
 | 
					                             inputs=inputs,
 | 
				
			||||||
 | 
					                             generation_config=generation_config,
 | 
				
			||||||
 | 
					                             logits_processor=logits_processor,
 | 
				
			||||||
 | 
					                             stopping_criteria=stopping_criteria,
 | 
				
			||||||
 | 
					                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
 | 
				
			||||||
 | 
					                             synced_gpus=synced_gpus,
 | 
				
			||||||
 | 
					                             assistant_model=assistant_model,
 | 
				
			||||||
 | 
					                             streamer=streamer,
 | 
				
			||||||
 | 
					                             **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					GenerationMixin.generate = generate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def pipeline_parallel_generate(self,
 | 
				
			||||||
 | 
					                               inputs: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					                               max_new_tokens: int = 32,
 | 
				
			||||||
 | 
					                               generation_config: Optional[GenerationConfig] = None,
 | 
				
			||||||
 | 
					                               **kwargs):
 | 
				
			||||||
 | 
					    model_kwargs = generation_config.update(**kwargs)
 | 
				
			||||||
 | 
					    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
 | 
				
			||||||
 | 
					        inputs, generation_config.bos_token_id, model_kwargs
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    bs = inputs_tensor.shape[0]
 | 
				
			||||||
 | 
					    if model_kwargs.get("attention_mask", None) is None:
 | 
				
			||||||
 | 
					        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
 | 
				
			||||||
 | 
					            inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)
 | 
				
			||||||
 | 
					    if self.config.is_encoder_decoder:
 | 
				
			||||||
 | 
					        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
 | 
				
			||||||
 | 
					            batch_size=bs,
 | 
				
			||||||
 | 
					            model_input_name=model_input_name,
 | 
				
			||||||
 | 
					            model_kwargs=model_kwargs,
 | 
				
			||||||
 | 
					            decoder_start_token_id=generation_config.decoder_start_token_id,
 | 
				
			||||||
 | 
					            bos_token_id=generation_config.bos_token_id,
 | 
				
			||||||
 | 
					            device=inputs_tensor.device,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        input_ids = inputs_tensor if model_input_name == "input_ids" \
 | 
				
			||||||
 | 
					            else model_kwargs.pop("input_ids")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    local_rank = dist.get_rank()
 | 
				
			||||||
 | 
					    pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
 | 
				
			||||||
 | 
					    next_rank = (local_rank + 1) % self.pipeline_parallel_stages
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    global layer_start
 | 
				
			||||||
 | 
					    global layer_end
 | 
				
			||||||
 | 
					    global num_layers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self.first_token_time = 0
 | 
				
			||||||
 | 
					    self.next_token_time = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pad_token_id = generation_config.pad_token_id
 | 
				
			||||||
 | 
					    eos_token_id = generation_config.eos_token_id
 | 
				
			||||||
 | 
					    if isinstance(eos_token_id, int):
 | 
				
			||||||
 | 
					        eos_token_id = [eos_token_id]
 | 
				
			||||||
 | 
					    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) \
 | 
				
			||||||
 | 
					        if eos_token_id is not None else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _input_ids = None
 | 
				
			||||||
 | 
					    _past_key_values = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bs = input_ids.shape[0]
 | 
				
			||||||
 | 
					    output_ids = input_ids.clone()
 | 
				
			||||||
 | 
					    os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    step = 0
 | 
				
			||||||
 | 
					    # keep track of which sequences are already finished
 | 
				
			||||||
 | 
					    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
 | 
				
			||||||
 | 
					    this_peer_finished = False
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        if step >= max_new_tokens:
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if _input_ids is None:
 | 
				
			||||||
 | 
					            _input_ids = input_ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        model_inputs = self.prepare_inputs_for_generation(output_ids, **model_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        tic = time.time()
 | 
				
			||||||
 | 
					        if local_rank == 0:
 | 
				
			||||||
 | 
					            outputs = self(**model_inputs)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            _inputs_shape = _input_ids.shape + (self.config.hidden_size,)
 | 
				
			||||||
 | 
					            if step == 0 and self.config.model_type == "chatglm" \
 | 
				
			||||||
 | 
					               and hasattr(self.config, "vision_config"):
 | 
				
			||||||
 | 
					                # for glm-4v, image features are mapped during 1st token
 | 
				
			||||||
 | 
					                # 1597 are computed according to computation process of conv
 | 
				
			||||||
 | 
					                _images_feature = 1597 + _input_ids.shape[0] * 2 + _input_ids.shape[1]
 | 
				
			||||||
 | 
					                _inputs_shape = (_input_ids.shape[0], _images_feature, self.config.hidden_size,)
 | 
				
			||||||
 | 
					            inputs_embeds = torch.empty(_inputs_shape,
 | 
				
			||||||
 | 
					                                        device=input_ids.device, dtype=torch.float16)
 | 
				
			||||||
 | 
					            dist.recv(inputs_embeds, src=pre_rank)
 | 
				
			||||||
 | 
					            model_inputs.pop("input_ids")
 | 
				
			||||||
 | 
					            model_inputs["inputs_embeds"] = inputs_embeds
 | 
				
			||||||
 | 
					            outputs = self(**model_inputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if local_rank == self.pipeline_parallel_stages - 1:
 | 
				
			||||||
 | 
					            logits = outputs.logits
 | 
				
			||||||
 | 
					            next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
 | 
				
			||||||
 | 
					            dist.broadcast(next_ids, src=local_rank)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            send_data = outputs[0].to(torch.float16)
 | 
				
			||||||
 | 
					            dist.send(send_data, dst=next_rank)
 | 
				
			||||||
 | 
					            next_ids = torch.empty((bs, 1), device=input_ids.device, dtype=torch.int64)
 | 
				
			||||||
 | 
					            dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        _input_ids = next_ids
 | 
				
			||||||
 | 
					        output_ids = torch.cat([output_ids, next_ids], dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        model_kwargs = self._update_model_kwargs_for_generation(
 | 
				
			||||||
 | 
					            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # finished sentences should have their next token be a padding token
 | 
				
			||||||
 | 
					        next_ids = next_ids.squeeze()
 | 
				
			||||||
 | 
					        if eos_token_id is not None:
 | 
				
			||||||
 | 
					            if pad_token_id is None:
 | 
				
			||||||
 | 
					                invalidInputError(False, "If `eos_token_id` is defined, "
 | 
				
			||||||
 | 
					                                         "make sure that `pad_token_id` is defined.")
 | 
				
			||||||
 | 
					            next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.config.model_type == "chatglm" and self.config.num_layers == 40 \
 | 
				
			||||||
 | 
					           and not hasattr(self.config, "vision_config"):
 | 
				
			||||||
 | 
					            # for glm-4-9b-chat
 | 
				
			||||||
 | 
					            if step == 0:
 | 
				
			||||||
 | 
					                value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
 | 
				
			||||||
 | 
					                past_key_values_placeholder = tuple(
 | 
				
			||||||
 | 
					                    (value_placeholder, value_placeholder) for _ in range(layer_start)
 | 
				
			||||||
 | 
					                ) + (outputs.past_key_values)[: layer_end - layer_start] + tuple(
 | 
				
			||||||
 | 
					                    (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                _past_key_values = past_key_values_placeholder
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                _past_key_values = outputs.past_key_values
 | 
				
			||||||
 | 
					        elif self.config.model_type in ["baichuan", "chatglm"] or \
 | 
				
			||||||
 | 
					                (self.config.model_type == "qwen" and hasattr(self.config, "visual")):
 | 
				
			||||||
 | 
					            # for baichuan2, chatglm3, Qwen-VL-Chat, glm-4v-9b
 | 
				
			||||||
 | 
					            if local_rank != 0:
 | 
				
			||||||
 | 
					                value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
 | 
				
			||||||
 | 
					                past_key_values_placeholder = tuple(
 | 
				
			||||||
 | 
					                    (value_placeholder, value_placeholder) for _ in range(layer_start)
 | 
				
			||||||
 | 
					                ) + (outputs.past_key_values)[layer_start:]
 | 
				
			||||||
 | 
					                _past_key_values = past_key_values_placeholder
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                _past_key_values = outputs.past_key_values
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            _past_key_values = outputs.past_key_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        toc = time.time()
 | 
				
			||||||
 | 
					        if step == 0:
 | 
				
			||||||
 | 
					            self.first_token_time = toc - tic
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.next_token_time.append(toc - tic)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # if eos_token was found in one sentence, set sentence to finished
 | 
				
			||||||
 | 
					        if eos_token_id_tensor is not None:
 | 
				
			||||||
 | 
					            unfinished_sequences = unfinished_sequences.mul(
 | 
				
			||||||
 | 
					                next_ids.tile(eos_token_id_tensor.shape[0], 1)
 | 
				
			||||||
 | 
					                .ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            # stop when each sentence is finished
 | 
				
			||||||
 | 
					            if unfinished_sequences.max() == 0:
 | 
				
			||||||
 | 
					                this_peer_finished = True
 | 
				
			||||||
 | 
					        if this_peer_finished:
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        step += 1
 | 
				
			||||||
 | 
					        if self.device.type == 'xpu':
 | 
				
			||||||
 | 
					            torch.xpu.synchronize()
 | 
				
			||||||
 | 
					    self.rest_cost_mean = np.mean(self.next_token_time)
 | 
				
			||||||
 | 
					    return output_ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def llama_causallm_forward_4_37_lowmem(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    input_ids: torch.LongTensor = None,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
				
			||||||
 | 
					    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
				
			||||||
 | 
					    labels: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    use_cache: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_attentions: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_hidden_states: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_dict: Optional[bool] = None,
 | 
				
			||||||
 | 
					) -> Union[Tuple, CausalLMOutputWithPast]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
 | 
				
			||||||
 | 
					    output_hidden_states = (
 | 
				
			||||||
 | 
					        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 | 
				
			||||||
 | 
					    outputs = self.model(
 | 
				
			||||||
 | 
					        input_ids=input_ids,
 | 
				
			||||||
 | 
					        attention_mask=attention_mask,
 | 
				
			||||||
 | 
					        position_ids=position_ids,
 | 
				
			||||||
 | 
					        past_key_values=past_key_values,
 | 
				
			||||||
 | 
					        inputs_embeds=inputs_embeds,
 | 
				
			||||||
 | 
					        use_cache=use_cache,
 | 
				
			||||||
 | 
					        output_attentions=output_attentions,
 | 
				
			||||||
 | 
					        output_hidden_states=output_hidden_states,
 | 
				
			||||||
 | 
					        return_dict=return_dict,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hidden_states = outputs[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ipex-llm change starts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if self.config.pretraining_tp > 1:
 | 
				
			||||||
 | 
					        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  # noqa
 | 
				
			||||||
 | 
					        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
 | 
				
			||||||
 | 
					        logits = torch.cat(logits, dim=-1)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        if device.type == "xpu":
 | 
				
			||||||
 | 
					            torch.xpu.empty_cache()
 | 
				
			||||||
 | 
					        logits = self.lm_head(hidden_states)
 | 
				
			||||||
 | 
					        if device.type == "xpu":
 | 
				
			||||||
 | 
					            torch.xpu.empty_cache()
 | 
				
			||||||
 | 
					    # logits = logits.float()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ipex-llm change ends
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    loss = None
 | 
				
			||||||
 | 
					    if labels is not None:
 | 
				
			||||||
 | 
					        # Shift so that tokens < n predict n
 | 
				
			||||||
 | 
					        shift_logits = logits[..., :-1, :].contiguous()
 | 
				
			||||||
 | 
					        shift_labels = labels[..., 1:].contiguous()
 | 
				
			||||||
 | 
					        # Flatten the tokens
 | 
				
			||||||
 | 
					        loss_fct = CrossEntropyLoss()
 | 
				
			||||||
 | 
					        shift_logits = shift_logits.view(-1, self.config.vocab_size)
 | 
				
			||||||
 | 
					        shift_labels = shift_labels.view(-1)
 | 
				
			||||||
 | 
					        # Enable model parallelism
 | 
				
			||||||
 | 
					        shift_labels = shift_labels.to(shift_logits.device)
 | 
				
			||||||
 | 
					        loss = loss_fct(shift_logits, shift_labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not return_dict:
 | 
				
			||||||
 | 
					        output = (logits,) + outputs[1:]
 | 
				
			||||||
 | 
					        return (loss,) + output if loss is not None else output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return CausalLMOutputWithPast(
 | 
				
			||||||
 | 
					        loss=loss,
 | 
				
			||||||
 | 
					        logits=logits,
 | 
				
			||||||
 | 
					        past_key_values=outputs.past_key_values,
 | 
				
			||||||
 | 
					        hidden_states=outputs.hidden_states,
 | 
				
			||||||
 | 
					        attentions=outputs.attentions,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def chatglm3_conditional_generation_forward_lowmem(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    input_ids: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
 | 
				
			||||||
 | 
					    inputs_embeds: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    labels: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    use_cache: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_attentions: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_hidden_states: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_dict: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_last_logit: Optional[bool] = False,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
				
			||||||
 | 
					    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    transformer_outputs = self.transformer(
 | 
				
			||||||
 | 
					        input_ids=input_ids,
 | 
				
			||||||
 | 
					        position_ids=position_ids,
 | 
				
			||||||
 | 
					        attention_mask=attention_mask,
 | 
				
			||||||
 | 
					        past_key_values=past_key_values,
 | 
				
			||||||
 | 
					        inputs_embeds=inputs_embeds,
 | 
				
			||||||
 | 
					        use_cache=use_cache,
 | 
				
			||||||
 | 
					        output_hidden_states=output_hidden_states,
 | 
				
			||||||
 | 
					        return_dict=return_dict,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hidden_states = transformer_outputs[0]
 | 
				
			||||||
 | 
					    if return_last_logit:
 | 
				
			||||||
 | 
					        hidden_states = hidden_states[-1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					    # ipex-llm change starts
 | 
				
			||||||
 | 
					    if device.type == "xpu":
 | 
				
			||||||
 | 
					        torch.xpu.empty_cache()
 | 
				
			||||||
 | 
					    lm_logits = self.transformer.output_layer(hidden_states)
 | 
				
			||||||
 | 
					    if device.type == "xpu":
 | 
				
			||||||
 | 
					        torch.xpu.empty_cache()
 | 
				
			||||||
 | 
					    lm_logits = lm_logits.transpose(0, 1).contiguous()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    loss = None
 | 
				
			||||||
 | 
					    if labels is not None:
 | 
				
			||||||
 | 
					        # lm_logits = lm_logits.to(torch.float32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Shift so that tokens < n predict n
 | 
				
			||||||
 | 
					        shift_logits = lm_logits[..., :-1, :].contiguous()
 | 
				
			||||||
 | 
					        shift_labels = labels[..., 1:].contiguous()
 | 
				
			||||||
 | 
					        # Flatten the tokens
 | 
				
			||||||
 | 
					        loss_fct = CrossEntropyLoss(ignore_index=-100)
 | 
				
			||||||
 | 
					        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        lm_logits = lm_logits.to(hidden_states.dtype)
 | 
				
			||||||
 | 
					        loss = loss.to(hidden_states.dtype)
 | 
				
			||||||
 | 
					    # ipex-llm change ends
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not return_dict:
 | 
				
			||||||
 | 
					        output = (lm_logits,) + transformer_outputs[1:]
 | 
				
			||||||
 | 
					        return ((loss,) + output) if loss is not None else output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return CausalLMOutputWithPast(
 | 
				
			||||||
 | 
					        loss=loss,
 | 
				
			||||||
 | 
					        logits=lm_logits,
 | 
				
			||||||
 | 
					        past_key_values=transformer_outputs.past_key_values,
 | 
				
			||||||
 | 
					        hidden_states=transformer_outputs.hidden_states,
 | 
				
			||||||
 | 
					        attentions=transformer_outputs.attentions,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def glm4_conditional_generation_forward_lowmem(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    input_ids: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
 | 
				
			||||||
 | 
					    inputs_embeds: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    labels: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    use_cache: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_attentions: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_hidden_states: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_dict: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_last_logit: Optional[bool] = False,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
				
			||||||
 | 
					    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    transformer_outputs = self.transformer(
 | 
				
			||||||
 | 
					        input_ids=input_ids,
 | 
				
			||||||
 | 
					        position_ids=position_ids,
 | 
				
			||||||
 | 
					        attention_mask=attention_mask,
 | 
				
			||||||
 | 
					        past_key_values=past_key_values,
 | 
				
			||||||
 | 
					        inputs_embeds=inputs_embeds,
 | 
				
			||||||
 | 
					        use_cache=use_cache,
 | 
				
			||||||
 | 
					        output_hidden_states=output_hidden_states,
 | 
				
			||||||
 | 
					        return_dict=return_dict,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hidden_states = transformer_outputs[0]
 | 
				
			||||||
 | 
					    if return_last_logit:
 | 
				
			||||||
 | 
					        hidden_states = hidden_states[:, -1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					    # ipex-llm change starts
 | 
				
			||||||
 | 
					    if device.type == "xpu":
 | 
				
			||||||
 | 
					        torch.xpu.empty_cache()
 | 
				
			||||||
 | 
					    lm_logits = self.transformer.output_layer(hidden_states)
 | 
				
			||||||
 | 
					    if device.type == "xpu":
 | 
				
			||||||
 | 
					        torch.xpu.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    loss = None
 | 
				
			||||||
 | 
					    if labels is not None:
 | 
				
			||||||
 | 
					        # lm_logits = lm_logits.to(torch.float32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Shift so that tokens < n predict n
 | 
				
			||||||
 | 
					        shift_logits = lm_logits[..., :-1, :].contiguous()
 | 
				
			||||||
 | 
					        shift_labels = labels[..., 1:].contiguous()
 | 
				
			||||||
 | 
					        # Flatten the tokens
 | 
				
			||||||
 | 
					        loss_fct = CrossEntropyLoss(ignore_index=-100)
 | 
				
			||||||
 | 
					        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        lm_logits = lm_logits.to(hidden_states.dtype)
 | 
				
			||||||
 | 
					        loss = loss.to(hidden_states.dtype)
 | 
				
			||||||
 | 
					    # ipex-llm change ends
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not return_dict:
 | 
				
			||||||
 | 
					        output = (lm_logits,) + transformer_outputs[1:]
 | 
				
			||||||
 | 
					        return ((loss,) + output) if loss is not None else output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return CausalLMOutputWithPast(
 | 
				
			||||||
 | 
					        loss=loss,
 | 
				
			||||||
 | 
					        logits=lm_logits,
 | 
				
			||||||
 | 
					        past_key_values=transformer_outputs.past_key_values,
 | 
				
			||||||
 | 
					        hidden_states=transformer_outputs.hidden_states,
 | 
				
			||||||
 | 
					        attentions=transformer_outputs.attentions,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
		Loading…
	
		Reference in a new issue