Add experimental support of fused decoder layer for llama2 (#11768)

parent c28b3389e6, commit 23d3acdc77
7 changed files with 1850 additions and 14 deletions
@@ -23,7 +23,7 @@ Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-w
Then go to **Device Manager**, find **Neural Processors** -> **Intel(R) AI Boost**.
Right-click and select **Update Driver**, then manually select the folder unzipped from the driver.

-## Example: Predict Tokens using `generate()` API
+## Example 1: Predict Tokens using `generate()` API

In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using the `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
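
The core of that flow looks roughly like the sketch below (a minimal sketch only, not the full generate.py; the load and generate calls mirror the ones used in llama2.py later in this commit):

```python
# Minimal sketch of the generate() flow; see generate.py for the real script.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# load_in_low_bit="sym_int4" enables the IPEX-LLM INT4 optimization on the NPU
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
                                             load_in_low_bit="sym_int4")

with torch.inference_mode():
    input_ids = tokenizer.encode("Once upon a time,", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```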

### 1. Install
#### 1.1 Installation on Windows
@@ -81,3 +81,62 @@ Inference time: xxxx s
--------------------------------------------------------------------------------
done
```

## Example 2: Predict Tokens using `generate()` API with multiple processes

In the example [llama2.py](./llama2.py), we show experimental support for a Llama2 model to predict the next N tokens using the `generate()` API, with IPEX-LLM INT4 optimization and fused decoder layer optimization on Intel NPUs.

### 1. Install
#### 1.1 Installation on Windows
We suggest using conda to manage the environment:
```bash
conda create -n llm python=3.10
conda activate llm

# install ipex-llm with 'all' option
pip install --pre --upgrade ipex-llm[all]
pip install --pre --upgrade bigdl-core-npu

pip install transformers==4.40
```

### 2. Runtime Configurations
For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
#### 2.1 Configurations for Windows

> [!NOTE]
> For optimal performance, we recommend running code in `conhost` rather than Windows Terminal:
> - Press <kbd>Win</kbd>+<kbd>R</kbd> and input `conhost`, then press Enter to launch `conhost`.
> - Run the following command to use conda in `conhost`. Replace `<your conda install location>` with your conda install location.
> ```
> call <your conda install location>\Scripts\activate
> ```

**The following environment variable is required**:

```cmd
set BIGDL_USE_NPU=1
```

### 3. Running examples

```
torchrun --standalone --nnodes=1 --nproc-per-node=2 llama2.py
```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (i.e. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. The default value is `'meta-llama/Llama-2-7b-chat-hf'`.
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). The default value is `'Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. The default value is `32`.
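
For example, to run with an explicit model path and a longer generation length (the argument values here are illustrative):

```
torchrun --standalone --nnodes=1 --nproc-per-node=2 llama2.py --repo-id-or-model-path "meta-llama/Llama-2-7b-chat-hf" --n-predict 64
```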

#### Sample Output
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)

```log
First token cost: xxxx s, rest tokens cost average: xxxx s
Inference time: xxxx s
-------------------- Prompt --------------------
Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
-------------------- Output --------------------
<s> Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun and exciting experiences.

One day, she decided to go on a journey to find a magical land that was said to be full of wonders
```

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py (new file, 846 lines)
@@ -0,0 +1,846 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["IPEX_LLM_LAST_LM_HEAD"] = "1"
import torch
import time
import argparse

from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from intel_npu_acceleration_library.backend.factory import NNFactory
from typing import Optional, Sequence, List, Union, Any, Tuple, Generator
import numpy as np
import math
from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function
from intel_npu_acceleration_library.backend.runtime import adapt_output_tensor, _model_cache
from collections import deque
from transformers.cache_utils import Cache
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
import ctypes
from ipex_llm.utils.common import invalidInputError
import uuid
from functools import partial
import torch.nn.functional as F
import torch.nn.parallel
import torch.distributed as dist

from transformers.utils import logging
logger = logging.get_logger(__name__)


@torch.no_grad()
def run_model(
    x: Union[torch.Tensor, List[torch.Tensor]],
    weights: List[torch.Tensor],
    backend_cls: Any,
    op_id: str,
    replica: int = 1,
) -> torch.Tensor:
    """Run a factory operation. Depending on the dtype of the weights it runs a float or quantized operation.

    Args:
        x (Union[torch.Tensor, List[torch.Tensor]]): Activation tensor(s). Its dtype must be torch.float16
        weights (List[torch.Tensor]): Weights tensors. Their dtype can be torch.float16 or torch.int8
        backend_cls (Any): Backend class to run
        op_id (Optional[str], optional): Operation ID. Defaults to None.
        replica (int, optional): Number of compiled graph copies to cache. Defaults to 1.

    Returns:
        torch.Tensor: result
    """
    global _model_cache
    t0 = time.perf_counter()

    # Use op_id or not depending on the backend class used
    op_kwargs = {"op_id": op_id} if op_id else {}

    if not isinstance(x, (list, tuple)):
        x = [x]

    # Reshape input
    input_dtype = x[0].dtype
    x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x]
    op_args = []
    op_args_flatten = []
    for w in weights:
        if isinstance(w, tuple):  # from QuantizedLinear
            op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
            op_args_flatten.append(op_args[-1][0])
            op_args_flatten.append(op_args[-1][1])
        else:
            op_args.append(set_contiguous(w).to(torch.float16).numpy())
            op_args_flatten.append(op_args[-1])
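
    # Compiled NPU graphs are cached under a key derived from the shapes and
    # dtypes of the activations and weights: the first call builds `replica`
    # copies of the backend graph, and later calls with the same signature
    # round-robin over them via deque.rotate() instead of recompiling.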
    shape_dtype_signature = "_".join(
        ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten]
    )
    key = f"{backend_cls.func.__name__}_{shape_dtype_signature}"
    models = _model_cache.get(key, None)

    input_shapes = [elem.shape for elem in x_np]
    if models is None:
        _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)])
    elif len(models) < 1:
        _model_cache[key].append(backend_cls(*input_shapes))
    else:
        _model_cache[key].rotate(1)

    # Get the model
    model = _model_cache[key][0]

    with record_function(f"npu_factory_mul_{key}"):
        ret = model.run(x_np, *op_args, **op_kwargs)

    if isinstance(ret, list):
        results = [adapt_output_tensor(r, r.shape, input_dtype) for r in ret]
    else:
        results = adapt_output_tensor(ret, ret.shape, input_dtype)

    return results
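

# Note (added for readability): the NNFactory subclass below expresses one
# Llama decoder layer as a single NPU graph: RMSNorm, int4 QKV linears,
# rotary embedding, attention (taking the pre-allocated KV cache as extra
# graph inputs in decode mode), RMSNorm, then the SwiGLU MLP. The graph is
# compiled once in __init__.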
class LowBitLlamaDecoderlayer(NNFactory):
    def __init__(
        self,
        hidden_shape: Sequence[int],
        attention_mask_shape=None,
        position_id_shape=None,
        past_key_shape=None,
        past_value_shape=None,
        input_layernorm_shape=None,
        post_layernorm_shape=None,
        *,
        num_heads: int,
        num_key_value_heads: int,
        cached_cos,
        cached_sin,
        mode: str = "prefill",
        dtype: np.dtype = np.int8,
        max_seq_len: int = 128,
        profile: bool = False,
        device: str = "NPU",
        rms_norm_eps,
        intermediate_size,
        **additional_args
    ):
        super().__init__(profile, device)
        self.max_seq_len = max_seq_len
        self.intermediate_size = intermediate_size
        eps = self.constant(rms_norm_eps)

        self.batch_size, self.seq_len, self.hidden_size = hidden_shape

        if mode == "decode":
            invalidInputError(self.seq_len == 1, "seq_len must be 1 for decode mode")
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads

        self.head_dim = self.hidden_size // self.num_heads

        # define inputs; the order of self.parameter calls matters
        input = self.parameter((self.batch_size, self.seq_len, self.hidden_size))

        # Self Attention
        if mode == "decode":
            attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1))
        else:
            attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len))

        position_ids = self.parameter((self.batch_size, self.seq_len))

        input_layernorm_weight = self.parameter((1, self.hidden_size,))
        post_attention_layernorm_weight = self.parameter((1, self.hidden_size,))

        if mode == "decode":
            past_key = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim))
            past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim))

        residual = input

        input_2d = self.reshape(input, (self.batch_size * self.seq_len, self.hidden_size))

        # input_layernorm forward
        input_2d = self.convert_to_fp32(input_2d)
        variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True)
        input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps)))
        input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight)
        input_2d = self.eltwise_mul(input_layernorm_weight, input_2d)
        input_2d = self.convert_to_fp16(input_2d)

        query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype)
        key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype)
        value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype)

        cos = self.constant(cached_cos)
        cos = self.unsqueeze(cos, axis=0)

        sin = self.constant(cached_sin)
        sin = self.unsqueeze(sin, axis=0)

        query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
        key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim])
        value_states = self.reshape(value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim])

        query_states = self.transpose(query_states, [0, 2, 1, 3])
        key_states = self.transpose(key_states, [0, 2, 1, 3])
        value_states = self.transpose(value_states, [0, 2, 1, 3])

        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        new_key_states = key_states
        new_value_states = value_states

        invalidInputError(self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads")

        if mode == "decode":
            key_states = self.concat(past_key, key_states, axis=-2)
            value_states = self.concat(past_value, value_states, axis=-2)

        attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
        attn_weight = self.eltwise_add(attn_weight, attention_mask)
        attn_weight = self.convert_to_fp32(attn_weight)
        attn_weight = self.softmax(attn_weight, -1)
        attn_weight = self.convert_to_fp16(attn_weight)
        attn_output = self.matmul(attn_weight, value_states, False, False)

        attn_output = self.transpose(attn_output, [0, 2, 1, 3])
        attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size])

        attn_output = self.linear(attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=dtype)

        hidden_states = self.eltwise_add(residual, attn_output)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.convert_to_fp32(hidden_states)
        variance = self.reduce_mean(self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), -1, keep_dims=True)
        hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
        post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight)
        hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states)
        hidden_states = self.convert_to_fp16(hidden_states)

        # mlp
        mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
                          bias=False, wt_dtype=dtype)
        mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
                          bias=False, wt_dtype=dtype)  # type: ignore[attr-defined]
        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]

        hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=dtype)

        hidden_states = self.eltwise_add(residual, hidden_states)
        hidden_states = self.convert_to_fp16(hidden_states)

        # hack to also expose key and value as graph outputs
        new_key_states = self.convert_to_fp16(new_key_states)
        new_value_states = self.convert_to_fp16(new_value_states)

        self.compile()

    def rotate_half(self, x):
        x1 = self.slice(x, [0, 0, 0, 0], [self.batch_size, self.num_heads, self.seq_len, self.head_dim//2],)
        x2 = self.slice(x, [0, 0, 0, self.head_dim//2], [self.batch_size, self.num_heads, self.seq_len, self.head_dim])
        return self.concat(self.negative(x2), x1, axis=-1)

    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
        position_ids = self.squeeze(position_ids)
        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
        cos = self.unsqueeze(cos, [1])
        sin = self.unsqueeze(sin, [1])

        q_embed = self.eltwise_add(self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin))
        k_embed = self.eltwise_add(self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin))

        return q_embed, k_embed
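

# Note (added for readability): same idea as above, but fusing `num_layers`
# decoder layers into one compiled NPU graph. The layernorm weights are folded
# in as graph constants, so only the hidden states, attention mask, position
# ids and the per-layer KV caches remain runtime inputs.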
class LowBitLlamaMultiDecoderlayer(NNFactory):
    def __init__(
        self,
        hidden_shape: Sequence[int],
        *shapes,
        num_heads: int,
        num_key_value_heads: int,
        num_layers: int,
        cached_cos,
        cached_sin,
        input_layernorm_weights,
        post_attn_layernorm_weights,
        mode: str = "prefill",
        dtype: np.dtype = np.int8,
        max_seq_len: int = 128,
        profile: bool = False,
        device: str = "NPU",
        rms_norm_eps,
        intermediate_size,
        **additional_args
    ):
        super().__init__(profile, device)
        self.max_seq_len = max_seq_len
        self.intermediate_size = intermediate_size
        self.dtype = dtype
        self.cached_cos = cached_cos
        self.cached_sin = cached_sin
        self.batch_size, self.seq_len, self.hidden_size = hidden_shape
        self.mode = mode
        self.rms_norm_eps = rms_norm_eps

        cos = self.constant(self.cached_cos)
        self.cos = self.unsqueeze(cos, axis=0)

        sin = self.constant(self.cached_sin)
        self.sin = self.unsqueeze(sin, axis=0)

        if mode == "decode":
            invalidInputError(self.seq_len == 1, "seq_len must be 1 for decode mode")
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads

        self.head_dim = self.hidden_size // self.num_heads

        # define inputs; the order of self.parameter calls matters
        input = self.parameter((self.batch_size, self.seq_len, self.hidden_size))

        # Self Attention
        if mode == "decode":
            attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1))
        else:
            attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len))

        position_ids = self.parameter((self.batch_size, self.seq_len))
        past_keys = []
        past_values = []
        if mode == "decode":
            for i in range(num_layers):
                past_key = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim))
                past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim))
                past_keys.append(past_key)
                past_values.append(past_value)
        else:
            past_key = None
            past_value = None

        # input_layernorm_weight = self.parameter((1, self.hidden_size,))
        # post_attention_layernorm_weight = self.parameter((1, self.hidden_size,))
        hidden_states = input

        curr_key_values = []
        for i in range(num_layers):
            hidden_states, new_key_states, new_value_states = self.build_decoder(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                input_layernorm_weight=input_layernorm_weights[i],
                post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                past_key=past_keys[i],
                past_value=past_values[i],)
            curr_key_values.append((new_key_states, new_value_states))

        # define outputs
        hidden_states = self.convert_to_fp16(hidden_states)

        for i in range(num_layers):
            new_key_states = self.convert_to_fp16(curr_key_values[i][0])
            new_value_states = self.convert_to_fp16(curr_key_values[i][1])

        self.compile()

    def build_decoder(self, hidden_states, attention_mask, position_ids,
                      input_layernorm_weight, post_attention_layernorm_weight,
                      past_key=None,
                      past_value=None):

        residual = hidden_states

        input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size))

        # input layernorm
        input_2d = self.convert_to_fp32(input_2d)
        variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True)
        eps = self.constant(self.rms_norm_eps)
        input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps)))
        input_layernorm_weight = self.constant(input_layernorm_weight)
        input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight)
        input_2d = self.eltwise_mul(input_layernorm_weight, input_2d)
        input_2d = self.convert_to_fp16(input_2d)

        # attention
        query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype)
        key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype)
        value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype)

        query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
        key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim])
        value_states = self.reshape(value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim])

        query_states = self.transpose(query_states, [0, 2, 1, 3])
        key_states = self.transpose(key_states, [0, 2, 1, 3])
        value_states = self.transpose(value_states, [0, 2, 1, 3])

        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, self.cos, self.sin, position_ids)
        new_key_states = key_states
        new_value_states = value_states

        # repeat_kv cannot be implemented because a Broadcast op would be needed
        # key_states = repeat_kv(key_states, self.num_key_value_groups)
        # value_states = repeat_kv(value_states, self.num_key_value_groups)
        invalidInputError(self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads")

        if self.mode == "decode":
            key_states = self.concat(past_key, key_states, axis=-2)
            value_states = self.concat(past_value, value_states, axis=-2)

        attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
        attn_weight = self.eltwise_add(attn_weight, attention_mask)
        attn_weight = self.convert_to_fp32(attn_weight)
        attn_weight = self.softmax(attn_weight, -1)
        attn_weight = self.convert_to_fp16(attn_weight)
        attn_output = self.matmul(attn_weight, value_states, False, False)

        attn_output = self.transpose(attn_output, [0, 2, 1, 3])
        attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size])

        attn_output = self.linear(attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype)

        hidden_states = self.eltwise_add(residual, attn_output)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.convert_to_fp32(hidden_states)
        variance = self.reduce_mean(self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), -1, keep_dims=True)
        hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
        post_attention_layernorm_weight = self.constant(post_attention_layernorm_weight)
        post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight)
        hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states)
        hidden_states = self.convert_to_fp16(hidden_states)

        # mlp
        mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
                          bias=False, wt_dtype=self.dtype)
        mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
                          bias=False, wt_dtype=self.dtype)  # type: ignore[attr-defined]
        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]

        hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype)

        hidden_states = self.eltwise_add(residual, hidden_states)
        hidden_states = self.convert_to_fp16(hidden_states)

        return hidden_states, new_key_states, new_value_states

    def rotate_half(self, x):
        x1 = self.slice(x, [0, 0, 0, 0], [self.batch_size, self.num_heads, self.seq_len, self.head_dim//2],)
        x2 = self.slice(x, [0, 0, 0, self.head_dim//2], [self.batch_size, self.num_heads, self.seq_len, self.head_dim])
        return self.concat(self.negative(x2), x1, axis=-1)

    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
        position_ids = self.squeeze(position_ids)
        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
        cos = self.unsqueeze(cos, [1])
        sin = self.unsqueeze(sin, [1])

        q_embed = self.eltwise_add(self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin))
        k_embed = self.eltwise_add(self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin))

        return q_embed, k_embed
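

# Note (added for readability): the torch.nn.Module below wraps the fused
# multi-layer decode graph. It sets the quantized weights once at construction,
# registers the pre-allocated KV cache tensors with the runtime, and re-runs
# the same compiled graph for every decoded token, prefetching the weight/KV
# handles for the next step at the end of each forward.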
class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):

    def __init__(
        self,
        parameters: List[Tuple[torch.Tensor]],
        input_layernorm_weights: List[torch.Tensor],
        post_attn_layernorm_weights: List[torch.Tensor],
        layer_indexes: List[int],
        cached_cos,
        cached_sin,
        num_heads: int,
        head_dim: int,
        num_key_value_heads: int,
        rms_norm_eps,
        intermediate_size,
        max_seq_len: int = 128,
    ):
        super().__init__()

        op_parameters = []
        for w in parameters:
            if isinstance(w, tuple):  # from QuantizedLinear
                op_parameters.append((w[0].numpy(), w[1].numpy()))
            else:
                op_parameters.append(w.to(torch.float16).numpy())
        self.op_parameters = op_parameters
        self.op_id = str(uuid.uuid4())
        # self.layer_idx = layer_idx
        self.max_seq_len = max_seq_len
        # self.rotary_emb = rotary_emb
        if isinstance(parameters[0], tuple):  # weight, scale from QuantizedLinear
            np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
        else:  # FP16 Linear
            invalidInputError(False, "Please use int4 optimization")

        self.layer_indexes = layer_indexes

        print("creating decoder layer")
        self.backend_cls_decode = LowBitLlamaMultiDecoderlayer([1, 1, num_heads*head_dim],
                                                               input_layernorm_weights=input_layernorm_weights,
                                                               post_attn_layernorm_weights=post_attn_layernorm_weights,
                                                               cached_cos=cached_cos,
                                                               cached_sin=cached_sin,
                                                               num_heads=num_heads,
                                                               num_key_value_heads=num_key_value_heads,
                                                               num_layers=len(layer_indexes),
                                                               max_seq_len=max_seq_len,
                                                               rms_norm_eps=rms_norm_eps,
                                                               intermediate_size=intermediate_size,
                                                               mode="decode",
                                                               dtype=np_dtype)
        print("created decoder layer")

        self.backend_cls_decode.setWeights(3+len(layer_indexes)*2, self.op_id, *op_parameters)
        print("weights set")
        backend_lib.run(self.backend_cls_decode._mm,)
        print("first inference done")
        self.kv_cache_c_parameter_handle = None
        self.kv_cache_parameters = None
        self.kv_cache_prefetched = False

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                cache_position: Optional[torch.LongTensor] = None,
                **kwargs,) -> torch.Tensor:
        """Torch module forward method.

        Args:
            hidden_states (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: result
        """
        seq_len = hidden_states.shape[1]
        backend_cls = self.backend_cls_decode

        pad_len = self.max_seq_len + 1 - attention_mask.size(-1)

        pad_mask = (0, pad_len)
        padded_attention_mask = F.pad(attention_mask.to(torch.float16), pad_mask,
                                      value=torch.finfo(torch.float16).min)
        padded_attention_mask[:, :, :, -1] = 0.0
        inputs = (hidden_states.to(torch.float16),
                  padded_attention_mask,
                  position_ids,)

        if self.kv_cache_parameters is None:
            self.kv_cache_parameters = []
            self.kv_cache_c_parameter_handle = None
            self.kv_cache_prefetched = False
        else:
            # the case where the kv cache storage changed
            cached_ptr = self.kv_cache_parameters[0].storage().data_ptr()
            current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr()
            if cached_ptr != current_ptr:
                self.kv_cache_parameters = []
                self.kv_cache_c_parameter_handle = None
                self.kv_cache_prefetched = False

        if len(self.kv_cache_parameters) == 0:
            for idx in self.layer_indexes:
                past_key = past_key_value.key_cache[idx]
                past_value = past_key_value.value_cache[idx]
                new_size = (past_key.size(0),
                            past_key.size(1),
                            self.max_seq_len,
                            past_key.size(3))
                past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0)
                past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0)

                self.kv_cache_parameters.append(past_key)
                self.kv_cache_parameters.append(past_value)
            self.kv_cache_c_parameter_handle = self.backend_cls_decode.create_parameters([p.numpy() for p in self.kv_cache_parameters])

        x_np = [elem.to(torch.float16).numpy() for elem in inputs]

        with record_function("npu_factory"):
            if not self.kv_cache_prefetched:
                self.backend_cls_decode.load_wt_fn(len(inputs), self.backend_cls_decode._mm, self.kv_cache_c_parameter_handle)

            for idx, elem in enumerate(x_np):
                self.backend_cls_decode.set_input_tensor(elem, idx)

            backend_lib.run(self.backend_cls_decode._mm,)
            ret = self.backend_cls_decode.out
            results = [adapt_output_tensor(r, r.shape, torch.float16) for r in ret]

        hidden_states = results[0]
        key_value_states = results[1:]

        cache_kwargs = {"cache_position": cache_position, "max_seq_len": self.max_seq_len}
        for i in range(len(self.layer_indexes)):
            key_states, value_states = past_key_value.update(key_value_states[2*i],
                                                             key_value_states[2*i+1],
                                                             self.layer_indexes[i], cache_kwargs)

        self.backend_cls_decode.load_wt_fn(len(inputs), self.backend_cls_decode._mm, self.kv_cache_c_parameter_handle)
        self.kv_cache_prefetched = True
        outputs = (hidden_states,)
        outputs += (past_key_value,)

        return outputs
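

# Note (added for readability): the per-layer wrapper below is used when the
# layers run one by one. With the fused model forward, prefill goes through
# this path; its decode branch (with `replica=4` cached graph copies) covers
# the case where token-by-token steps run per layer rather than fused.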
class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
    def __init__(
        self,
        parameters: List[torch.Tensor],
        cached_cos,
        cached_sin,
        layer_norm_0,
        layer_norm_1,
        num_heads: int,
        num_key_value_heads: int,
        layer_idx: int,
        rms_norm_eps,
        intermediate_size,
        max_seq_len: int = 128,
    ):
        super().__init__()
        self.op_parameters = parameters
        self.op_id = str(uuid.uuid4())
        self.layer_idx = layer_idx
        self.max_seq_len = max_seq_len
        # self.rotary_emb = rotary_emb
        if isinstance(parameters[0], tuple):  # weight, scale from QuantizedLinear
            np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
        else:  # FP16 Linear
            np_dtype = np.float16

        self.backend_cls_prefill = partial(LowBitLlamaDecoderlayer,
                                           cached_cos=cached_cos,
                                           cached_sin=cached_sin,
                                           num_heads=num_heads,
                                           num_key_value_heads=num_key_value_heads,
                                           max_seq_len=max_seq_len,
                                           rms_norm_eps=rms_norm_eps,
                                           intermediate_size=intermediate_size,
                                           mode="prefill",
                                           dtype=np_dtype)
        self.backend_cls_decode = partial(LowBitLlamaDecoderlayer,
                                          cached_cos=cached_cos,
                                          cached_sin=cached_sin,
                                          num_heads=num_heads,
                                          num_key_value_heads=num_key_value_heads,
                                          max_seq_len=max_seq_len,
                                          rms_norm_eps=rms_norm_eps,
                                          intermediate_size=intermediate_size,
                                          mode="decode",
                                          dtype=np_dtype)
        self.layer_norm_0 = layer_norm_0
        self.layer_norm_1 = layer_norm_1

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                cache_position: Optional[torch.LongTensor] = None,
                **kwargs,) -> torch.Tensor:
        seq_len = hidden_states.shape[1]
        # cos, sin = self.rotary_emb(hidden_states, position_ids)
        if seq_len == 1:
            backend_cls = self.backend_cls_decode
            past_key = past_key_value.key_cache[self.layer_idx]
            past_value = past_key_value.value_cache[self.layer_idx]

            new_size = (past_key.size(0),
                        past_key.size(1),
                        self.max_seq_len,
                        past_key.size(3))
            past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0)
            past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0)

            pad_len = self.max_seq_len + 1 - attention_mask.size(-1)

            pad_mask = (0, pad_len)
            padded_attention_mask = F.pad(attention_mask.to(torch.float16), pad_mask,
                                          value=torch.finfo(torch.float16).min)
            padded_attention_mask[:, :, :, -1] = 0.0
            inputs = (hidden_states.to(torch.float16),
                      padded_attention_mask,
                      position_ids,)

            inputs += (self.layer_norm_0, self.layer_norm_1)

            inputs += (past_key, past_value)
            hidden_states, new_key, new_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=4)
            cache_kwargs = {"cache_position": cache_position, "max_seq_len": self.max_seq_len}
            key_states, value_states = past_key_value.update(new_key, new_value, self.layer_idx, cache_kwargs)
        else:
            backend_cls = self.backend_cls_prefill
            inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
            inputs += (self.layer_norm_0, self.layer_norm_1)
            hidden_states, past_key, past_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=1)
            cache_kwargs = {"cache_position": cache_position, "max_seq_len": self.max_seq_len}
            key_states, value_states = past_key_value.update(past_key, past_value, self.layer_idx, cache_kwargs)

        outputs = (hidden_states,)
        outputs += (past_key_value,)
        return outputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for npu model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    pipeline = True  # default
    max_seq_len = 1024  # default
    if pipeline:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29501'

        dist.init_process_group()
        my_rank = dist.get_rank()
        my_size = dist.get_world_size()
        logger.info(f"rank: {my_rank}, size: {my_size}")

        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, attn_implementation="eager",
                                                     load_in_low_bit="sym_int4", pipeline_parallel_stages=2)

        if my_rank == 0:
            print(model)
        dist.barrier()

        if my_rank == 1:
            print(model)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, attn_implementation="eager",
                                                     load_in_low_bit="sym_int4")

    if pipeline:
        layer_start = model.layer_start
        layer_end = model.layer_end
        num_layers = model.num_layers
    else:
        layer_start = 0
        layer_end = 32
        num_layers = 32
    num_heads = model.model.layers[layer_start].self_attn.num_heads
    num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
    head_dim = model.model.layers[layer_start].self_attn.head_dim
    rms_norm_eps = model.config.rms_norm_eps
    intermediate_size = model.config.intermediate_size
    decoderlayers = []
    layer_weights = []
    input_layer_norm_weights = []
    post_attn_layernorm_weights = []
    layer_indices = range(layer_start, layer_end)
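
    # Swap each HF decoder layer for the fused NPU implementation, while also
    # collecting the per-layer quantized weights and layernorm weights so the
    # single multi-layer decode graph can be built afterwards.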
    for layer_idx in layer_indices:
        curr_layer = model.model.layers[layer_idx]
        attn_layer = curr_layer.self_attn
        mlp_layer = curr_layer.mlp

        weights = [
            # model.model.layers[i].input_layernorm.weight.to(torch.float16),
            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
            # model.model.layers[i].post_attention_layernorm.weight.to(torch.float16),
            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale)]

        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

        layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
        layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

        new_decoderlayer = FusedLlamaLowBitDecoderlayer(weights,
                                                        num_heads=num_heads,
                                                        num_key_value_heads=num_key_value_heads,
                                                        cached_cos=cached_cos,
                                                        cached_sin=cached_sin,
                                                        # rotary_emb=model.model.layers[i].self_attn.rotary_emb,
                                                        layer_norm_0=layer_norm_0,
                                                        layer_norm_1=layer_norm_1,
                                                        layer_idx=layer_idx,
                                                        rms_norm_eps=rms_norm_eps,
                                                        intermediate_size=intermediate_size,
                                                        max_seq_len=max_seq_len)

        layer_weights.extend(weights)
        input_layer_norm_weights.append(layer_norm_0)
        post_attn_layernorm_weights.append(layer_norm_1)
        model.model.layers[layer_idx] = new_decoderlayer

    multi_decoder = FusedLlamaLowBitMultiDecoderlayer(
        parameters=layer_weights,
        input_layernorm_weights=input_layer_norm_weights,
        post_attn_layernorm_weights=post_attn_layernorm_weights,
        layer_indexes=layer_indices,
        cached_cos=cached_cos,
        cached_sin=cached_sin,
        num_heads=num_heads,
        head_dim=head_dim,
        num_key_value_heads=num_key_value_heads,
        rms_norm_eps=rms_norm_eps,
        intermediate_size=intermediate_size,
        max_seq_len=max_seq_len,
    )

    model.model.multi_decoder = multi_decoder
    print(model)

    with torch.inference_mode():
        input_ids = tokenizer.encode(args.prompt, return_tensors="pt")
        print("finished loading")
        print('input length:', len(input_ids[0]))
        for i in range(3):
            st = time.time()
            output = model.generate(input_ids, num_beams=1, do_sample=False, max_new_tokens=args.n_predict)
            end = time.time()
            if my_rank == 0:
                print(f"First token cost: {model.first_token_time} s, rest tokens cost average: {model.rest_cost_mean} s")
                print(f'Inference time: {end-st} s')
                output_str = tokenizer.decode(output[0], skip_special_tokens=False)
                print('-'*20, 'Prompt', '-'*20)
                print(args.prompt)
                print('-'*20, 'Output', '-'*20)
                print(output_str)
@@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig

from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.utils import logger
-from ipex_llm.transformers.npu_models.convert import optimize_llm
+from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post


def patch_flash_attn_import(filename: str) -> List[str]:
@@ -84,7 +84,7 @@ class _BaseAutoModelClass:
            warnings.warn("`device_map` will be ignored")
        kwargs['device_map'] = 'cpu'

-        if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float]:
+        if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float, torch.float16]:
            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
            kwargs['torch_dtype'] = torch.float
@@ -114,7 +114,7 @@ class _BaseAutoModelClass:
        ignore_argument(kwargs, "modules_to_not_convert")
        ignore_argument(kwargs, "quantization_config")
        ignore_argument(kwargs, "speculative")
-        ignore_argument(kwargs, "pipeline_parallel_stages")
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)

        _args = copy.deepcopy(args)
        _kwargs = copy.deepcopy(kwargs)
@@ -131,12 +131,28 @@ class _BaseAutoModelClass:

        logger.info(f"Converting model, it may take up to several minutes ...")

+        if pipeline_parallel_stages > 1:
+            invalidInputError(torch.distributed.get_world_size() == pipeline_parallel_stages,
+                              "Please make sure world size is same as `pipeline_parallel_stages`")
+            kwargs['torch_dtype'] = torch.float16
+            from .npu_models.pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
+            model = pipeline_parallel(model, pipeline_parallel_stages,
+                                      kwargs["torch_dtype"], device="cpu")
+
+            # add pipeline_parallel_generate to pretrained model dynamically
+            model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
+                                                                model)
+
        from intel_npu_acceleration_library.compiler import create_npu_kernels
        with torch.no_grad():
            optimize_llm(model)
-            cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
-            create_npu_kernels(model)
+            if pipeline_parallel_stages == 1:
+                cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
+                create_npu_kernels(model)
+            else:
+                cls.load_convert(qtype, model.model, 'cpu', *args, **kwargs)
+                create_npu_kernels(model.model)
+                optimize_llm_post(model)
        model = model.eval()

        logger.info(f"Finished converting model")
@@ -63,7 +63,8 @@ def replace_with_QuantizedLinear(layer, qtype, device):
            (layer.in_features == 18944 and layer.out_features == 3584):
        qtype = "sym_int8_rtn"
    iqtype = ggml_tensor_qtype[qtype]
-    qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, device=device)
+    qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
+                                         iqtype, device=device)
    return QuantizedLinear(qweights, scale, layer.bias)
@@ -79,18 +80,22 @@ def optimize_llm(model: torch.nn.Module):
    if model.config.model_type == "llama":
        from ipex_llm.transformers.npu_models.llama import merge_qkv
        from ipex_llm.transformers.npu_models.llama import merge_mlp
-        model.apply(merge_qkv)
-        model.apply(merge_mlp)
-
        from ipex_llm.transformers.npu_models.llama import llama_model_forward
+        from ipex_llm.transformers.npu_models.llama import llama_fused_model_forward
        from ipex_llm.transformers.npu_models.llama import llama_attention_forward
        from ipex_llm.transformers.npu_models.llama import llama_mlp_forward
        from transformers.models.llama.modeling_llama import LlamaModel
        from transformers.models.llama.modeling_llama import LlamaAttention
        from transformers.models.llama.modeling_llama import LlamaMLP
-        convert_forward(model, LlamaModel, llama_model_forward)
-        convert_forward(model, LlamaAttention, llama_attention_forward)
-        convert_forward(model, LlamaMLP, llama_mlp_forward)
+        if hasattr(model, 'pipeline_parallel_stages'):
+            # experimental support for fused decoder layer implementation
+            convert_forward(model, LlamaModel, llama_fused_model_forward)
+        else:
+            model.apply(merge_qkv)
+            model.apply(merge_mlp)
+            convert_forward(model, LlamaModel, llama_model_forward)
+            convert_forward(model, LlamaAttention, llama_attention_forward)
+            convert_forward(model, LlamaMLP, llama_mlp_forward)

    elif model.config.model_type == "mistral":
        from ipex_llm.transformers.npu_models.mistral import merge_qkv
@@ -207,3 +212,28 @@ def optimize_llm(model: torch.nn.Module):
        from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward

        convert_forward(model, module.Phi3Attention, phi3_attention_forward)


def optimize_llm_post(model: torch.nn.Module):
    # experimental support for fused decoder layer implementation
    if model.config.model_type == "llama":
        model.model.embed_tokens.to(torch.float32)
        model.model.norm.to(torch.float32)
        model.lm_head.to(torch.float32)

        from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
            ggml_tensor_qtype, FP4Params

        if isinstance(model.lm_head, torch.nn.Linear):
            new_linear = LowBitLinear(model.lm_head.in_features,
                                      model.lm_head.out_features,
                                      ggml_tensor_qtype["sym_int4"],
                                      False)
            paramsLowBit = FP4Params(data=model.lm_head.weight.data,
                                     requires_grad=False,
                                     quantized=False,
                                     _shape=None,
                                     qtype=ggml_tensor_qtype["sym_int4"],
                                     in_features=model.lm_head.in_features).to("cpu")
            new_linear._parameters['weight'] = paramsLowBit
            model.lm_head = new_linear

python/llm/src/ipex_llm/transformers/npu_models/kv.py (new file, 115 lines)
@@ -0,0 +1,115 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import torch
from typing import Optional, Dict, Tuple, Any
from transformers.cache_utils import DynamicCache


def init_fused_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    key_cache_storage = torch.zeros(batch_size, num_heads,
                                    max_length, head_dim,
                                    dtype=dtype, device=device)
    value_cache_storage = torch.zeros(batch_size, num_heads,
                                      max_length, head_dim,
                                      dtype=dtype, device=device)

    key_cache = key_cache_storage.as_strided((batch_size, num_heads,
                                              current_length, head_dim),
                                             key_cache_storage.stride(),
                                             storage_offset=0)
    value_cache = value_cache_storage.as_strided((batch_size, num_heads,
                                                  current_length, head_dim),
                                                 value_cache_storage.stride(),
                                                 storage_offset=0)
    return key_cache, value_cache
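

# Note: the storages above are allocated once at max_length, and the returned
# tensors are as_strided views of length current_length. Appending (below)
# therefore only widens the view and copies the new tokens, instead of
# reallocating and copying the whole KV tensor on every step.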
def append_fused_kv_cache(cache_k, cache_v, key_states, value_states):
    new_size = (cache_k.size(0),
                cache_k.size(1),
                cache_k.size(2) + key_states.size(2),
                cache_k.size(3))
    new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
    new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
    new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
    new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states
    return new_cache_k, new_cache_v


class DynamicFusedNormalCache(DynamicCache):
    # Experimental support for fused decoder layer implementation on NPU
    # Currently only for llama2
    KV_ALLOC_BLOCK_LENGTH = 256

    def __init__(self) -> None:
        self.key_cache: Dict[int, torch.Tensor] = {}
        self.value_cache: Dict[int, torch.Tensor] = {}
        self._seen_tokens = 0  # Used in `generate` to keep how many tokens the cache has seen

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        batch_size, num_heads, seq_len, head_dim = key_states.shape

        max_seq_length = cache_kwargs.pop("max_seq_len", None)
        transpose_value = cache_kwargs.pop("transpose_value", None)

        if layer_idx == 0 or layer_idx == 16:
            if hasattr(self, "_seen_tokens"):
                # transformers 4.39 uses `_seen_tokens`
                self._seen_tokens += seq_len
            else:
                # transformers 4.37 uses `seen_tokens`
                self.seen_tokens += seq_len

        # Update the cache
        # if len(self.key_cache) <= layer_idx:
        if layer_idx not in self.key_cache:
            max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + \
                self.KV_ALLOC_BLOCK_LENGTH
            k_cache, v_cache = init_fused_kv_cache(
                batch_size, num_heads, head_dim,
                0, max_len,
                key_states.dtype, key_states.device,
            )
            k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states)

            self.key_cache[layer_idx] = k_cache
            self.value_cache[layer_idx] = v_cache
        else:
            k_cache = self.key_cache[layer_idx]
            v_cache = self.value_cache[layer_idx]

            kv_seq_len = k_cache.size(2) + key_states.size(2)
            k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states)
            self.key_cache[layer_idx] = k_cache
            self.value_cache[layer_idx] = v_cache

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states.
        A layer index can be optionally passed."""

        for idx, layer in self.key_cache.items():
            return layer.shape[-2]
@@ -182,6 +182,137 @@ def llama_model_forward(
    )


def llama_fused_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = (
        output_attentions if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        invalidInputError(False,
                          ("You cannot specify both input_ids and inputs_embeds at the same time, "
                           "and must specify either one"))

    if self.gradient_checkpointing and self.training and use_cache:
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0

    # ipex-llm changes start
    from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
    if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
        past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
        past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
    # ipex-llm changes end

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
                                           cache_position, past_seen_tokens)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    seq_len = hidden_states.size(1)
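
    # Decode steps (seq_len == 1) run all local layers in one fused NPU graph
    # via self.multi_decoder; prefill falls back to the per-layer decoder path.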
    if seq_len == 1:
        # multi_decoder = self.layers[(self.layer_end + 1) % num_layers]
        layer_outputs = self.multi_decoder(hidden_states,
                                           attention_mask=causal_mask,
                                           position_ids=position_ids,
                                           past_key_value=past_key_values,
                                           output_attentions=output_attentions,
                                           use_cache=use_cache,
                                           cache_position=cache_position,)
        hidden_states = layer_outputs[0]

        next_decoder_cache = layer_outputs[1]
    else:
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    # ipex-llm changes start
    next_cache = next_decoder_cache if use_cache else None
    # ipex-llm changes end
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache,
                                 all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def llama_attention_forward(
    self,
    hidden_states: torch.Tensor,
|
|||
|
|
@ -0,0 +1,639 @@
|
|||
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
#

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional, Union, Tuple
from types import SimpleNamespace
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
import logging
logger = logging.getLogger(__name__)

# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate


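# The Dummy_* modules below are identity placeholders: on each pipeline rank they
# replace the decoder layers (and embeddings / final norm / lm_head) owned by
# other ranks, so the original forward code can still traverse the full module
# tree without holding those weights in memory.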
class DummyLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
        # python/llm/src/ipex_llm/transformers/models/llama.py#L2076
        self.weight = nn.Parameter(torch.empty(0,), requires_grad=False)

    def forward(self, x):
        return x


class Dummy_MLPLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
        # python/llm/src/ipex_llm/transformers/models/llama.py#L119
        self.up_proj = DummyLayer()
        self.down_proj = DummyLayer()
        self.shared_expert = SimpleNamespace()
        self.shared_expert.up_proj = DummyLayer()

    def forward(self, x):
        return x


class Dummy_DecoderLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError
        self.input_layernorm = DummyLayer()
        self.mlp = Dummy_MLPLayer()

    def forward(self, hidden_states, *args, **kwargs):
        past_key_value = kwargs.get('past_key_value', None)
        use_cache = kwargs.get('use_cache', False)
        outputs = (hidden_states,)
        if use_cache:
            outputs += (past_key_value,)
        return outputs


class Dummy_GLMBlock(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError
        self.input_layernorm = DummyLayer()
        self.mlp = Dummy_MLPLayer()

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        if kv_cache is None:
            return hidden_states, ()
        return hidden_states, kv_cache


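# Bootstrap for the distributed stages. Rank and world size are supplied by the
# launcher (e.g. `torchrun --standalone --nnodes=1 --nproc-per-node=2 llama2.py`,
# as in the README above); MASTER_ADDR/MASTER_PORT fall back to localhost:29500.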
def init_pipeline_parallel():
    import oneccl_bindings_for_pytorch
    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
    dist.init_process_group('ccl')


def low_mem_convert(model):
    from ipex_llm.transformers.convert import convert_forward
    import importlib
    if 'llama' in model.config.model_type:
        convert_forward(
            model,
            transformers.models.llama.modeling_llama.LlamaForCausalLM,
            llama_causallm_forward_4_37_lowmem)
    elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"):
        if model.config.num_layers == 40:
            # for glm4-9b
            modeling_module_name = model.__class__.__module__
            module = importlib.import_module(modeling_module_name)
            convert_forward(
                model,
                module.ChatGLMForConditionalGeneration,
                glm4_conditional_generation_forward_lowmem)
        else:
            # for chatglm3-6b
            modeling_module_name = model.__class__.__module__
            module = importlib.import_module(modeling_module_name)
            convert_forward(
                model,
                module.ChatGLMForConditionalGeneration,
                chatglm3_conditional_generation_forward_lowmem)
    return model


def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32, device=None):
    global num_layers
    if hasattr(model.config, 'num_hidden_layers'):
        num_layers = model.config.num_hidden_layers
    elif hasattr(model.config, 'num_layers'):
        # for chatglm3-6b
        num_layers = model.config.num_layers

    slice_size = (num_layers + pipeline_parallel_stages - 1) // pipeline_parallel_stages

    local_rank = dist.get_rank()

    global layer_start
    global layer_end
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, num_layers - layer_start)
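    # Each rank owns a contiguous slice of ceil(num_layers / stages) layers; for
    # example (assuming a 32-layer Llama2-7B split over 2 stages): slice_size is
    # 16, rank 0 keeps layers [0, 16) and rank 1 keeps layers [16, 32).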

    if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
        # for Qwen-VL-Chat
        for i in range(num_layers):
            if i < layer_start or i >= layer_end:
                model._modules['transformer'].h[i] = Dummy_DecoderLayer()
        if local_rank != 0:
            model._modules['transformer'].wte = DummyLayer()
            model._modules['transformer'].drop = DummyLayer()
        if local_rank != pipeline_parallel_stages - 1:
            model._modules['transformer'].ln_f = DummyLayer()
            model._modules['ln_f'] = DummyLayer()
            model._modules['lm_head'] = DummyLayer()
    elif model.config.model_type == "chatglm":
        # for chatglm3-6b, glm-4-9b-chat
        for i in range(num_layers):
            if i < layer_start or i >= layer_end:
                model._modules['transformer'].encoder.layers[i] = Dummy_GLMBlock()
            else:
                model._modules['transformer'].encoder.layers[i].self_attention.num_layers = \
                    i - layer_start

        if local_rank != 0:
            model._modules['transformer'].embedding = DummyLayer()
        if local_rank != pipeline_parallel_stages - 1:
            model._modules['transformer'].encoder.final_layernorm = DummyLayer()
            model._modules['transformer'].output_layer = DummyLayer()
    else:
        for i in range(num_layers):
            if i < layer_start or i >= layer_end:
                model._modules['model'].layers[i] = Dummy_DecoderLayer()
            else:
                model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start

        if local_rank != 0:
            model._modules['model'].embed_tokens = DummyLayer()
        if local_rank != pipeline_parallel_stages - 1:
            model._modules['model'].norm = DummyLayer()
            model._modules['lm_head'] = DummyLayer()

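    # Optional low-memory mode: set IPEX_LLM_LOW_MEM=1 in the environment to swap
    # in the *_lowmem forwards defined at the bottom of this file.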
    _enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM')
    _enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1")
    if _enable_lowmem:
        model = low_mem_convert(model)

    model.pipeline_parallel_stages = pipeline_parallel_stages
    model.layer_start = layer_start
    model.layer_end = layer_end
    model.num_layers = num_layers
    if torch_dtype == torch.float16:
        model = model.half()
    if device is None:
        model = model.to(f'xpu:{local_rank}')
    else:
        model.to(device)
    return model
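

# A minimal usage sketch (hypothetical; not part of this commit) of how the
# llama2.py example above is expected to wire these helpers together:
#
#     init_pipeline_parallel()
#     model = pipeline_parallel(model, pipeline_parallel_stages=2,
#                               torch_dtype=torch.float16)
#     output_ids = model.generate(input_ids, max_new_tokens=32)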


@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
        # priority: `generation_config` argument > `model.generation_config`
        if generation_config is None:
            if (
                self.generation_config._from_model_config
                and self.generation_config._original_object_hash == hash(self.generation_config)
                and self.config._has_non_default_generation_parameters()
            ):
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning("Setting `pad_token_id` to `eos_token_id`: "
                           f"{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        if generation_config is not None and generation_config.max_new_tokens is not None:
            max_new_tokens = generation_config.pop("max_new_tokens")
        else:
            max_new_tokens = kwargs.pop("max_new_tokens", None)

        return self.pipeline_parallel_generate(inputs=inputs,
                                               max_new_tokens=max_new_tokens,
                                               generation_config=generation_config,
                                               **kwargs)

    return original_generate(self,
                             inputs=inputs,
                             generation_config=generation_config,
                             logits_processor=logits_processor,
                             stopping_criteria=stopping_criteria,
                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                             synced_gpus=synced_gpus,
                             assistant_model=assistant_model,
                             streamer=streamer,
                             **kwargs)

GenerationMixin.generate = generate
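
# With the patch above, a model whose `pipeline_parallel_stages > 1` is routed
# into pipeline_parallel_generate(); every other model falls through to the
# original HuggingFace GenerationMixin.generate unchanged.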


@torch.no_grad()
def pipeline_parallel_generate(self,
                               inputs: Optional[torch.Tensor] = None,
                               max_new_tokens: int = 32,
                               generation_config: Optional[GenerationConfig] = None,
                               **kwargs):
    model_kwargs = generation_config.update(**kwargs)
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    bs = inputs_tensor.shape[0]
    if model_kwargs.get("attention_mask", None) is None:
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)
    if self.config.is_encoder_decoder:
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=bs,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config.decoder_start_token_id,
            bos_token_id=generation_config.bos_token_id,
            device=inputs_tensor.device,
        )
    else:
        input_ids = inputs_tensor if model_input_name == "input_ids" \
            else model_kwargs.pop("input_ids")

    local_rank = dist.get_rank()
    pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
    next_rank = (local_rank + 1) % self.pipeline_parallel_stages
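    # The stages form a ring: each rank receives hidden states from `pre_rank`
    # and sends its own output to `next_rank`; only the last stage keeps a real
    # lm_head, so it decides the next token on behalf of every rank.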

    global layer_start
    global layer_end
    global num_layers

    self.first_token_time = 0
    self.next_token_time = []

    pad_token_id = generation_config.pad_token_id
    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) \
        if eos_token_id is not None else None

    _input_ids = None
    _past_key_values = None

    bs = input_ids.shape[0]
    output_ids = input_ids.clone()
    os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"

    step = 0
    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False
    while True:
        if step >= max_new_tokens:
            break

        if _input_ids is None:
            _input_ids = input_ids

        model_inputs = self.prepare_inputs_for_generation(output_ids, **model_kwargs)

        tic = time.time()
        if local_rank == 0:
            outputs = self(**model_inputs)
        else:
            _inputs_shape = _input_ids.shape + (self.config.hidden_size,)
            if step == 0 and self.config.model_type == "chatglm" \
                    and hasattr(self.config, "vision_config"):
                # for glm-4v, image features are mapped during the 1st token
                # 1597 is derived from the computation process of the conv layer
                _images_feature = 1597 + _input_ids.shape[0] * 2 + _input_ids.shape[1]
                _inputs_shape = (_input_ids.shape[0], _images_feature, self.config.hidden_size,)
            inputs_embeds = torch.empty(_inputs_shape,
                                        device=input_ids.device, dtype=torch.float16)
            dist.recv(inputs_embeds, src=pre_rank)
            model_inputs.pop("input_ids")
            model_inputs["inputs_embeds"] = inputs_embeds
            outputs = self(**model_inputs)
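
        # Greedy decoding across the pipeline: the last stage turns its logits
        # into the next token id and broadcasts it to every rank, while earlier
        # stages ship their fp16 hidden states onward and wait for that broadcast.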
        if local_rank == self.pipeline_parallel_stages - 1:
            logits = outputs.logits
            next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
            dist.broadcast(next_ids, src=local_rank)
        else:
            send_data = outputs[0].to(torch.float16)
            dist.send(send_data, dst=next_rank)
            next_ids = torch.empty((bs, 1), device=input_ids.device, dtype=torch.int64)
            dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1)

        _input_ids = next_ids
        output_ids = torch.cat([output_ids, next_ids], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # finished sentences should have their next token be a padding token
        next_ids = next_ids.squeeze()
        if eos_token_id is not None:
            if pad_token_id is None:
                invalidInputError(False, "If `eos_token_id` is defined, "
                                  "make sure that `pad_token_id` is defined.")
            next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
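
        # Each rank keeps a cache tuple indexed by absolute layer id: for layers
        # owned by other ranks, empty placeholder tensors pad the tuple to its
        # full length so downstream indexing stays aligned.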
        if self.config.model_type == "chatglm" and self.config.num_layers == 40 \
                and not hasattr(self.config, "vision_config"):
            # for glm-4-9b-chat
            if step == 0:
                value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
                past_key_values_placeholder = tuple(
                    (value_placeholder, value_placeholder) for _ in range(layer_start)
                ) + (outputs.past_key_values)[: layer_end - layer_start] + tuple(
                    (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
                )
                _past_key_values = past_key_values_placeholder
            else:
                _past_key_values = outputs.past_key_values
        elif self.config.model_type in ["baichuan", "chatglm"] or \
                (self.config.model_type == "qwen" and hasattr(self.config, "visual")):
            # for baichuan2, chatglm3, Qwen-VL-Chat, glm-4v-9b
            if local_rank != 0:
                value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
                past_key_values_placeholder = tuple(
                    (value_placeholder, value_placeholder) for _ in range(layer_start)
                ) + (outputs.past_key_values)[layer_start:]
                _past_key_values = past_key_values_placeholder
            else:
                _past_key_values = outputs.past_key_values
        else:
            _past_key_values = outputs.past_key_values

        toc = time.time()
        if step == 0:
            self.first_token_time = toc - tic
        else:
            self.next_token_time.append(toc - tic)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_ids.tile(eos_token_id_tensor.shape[0], 1)
                .ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True
        if this_peer_finished:
            break

        step += 1
        if self.device.type == 'xpu':
            torch.xpu.synchronize()
    self.rest_cost_mean = np.mean(self.next_token_time)
    return output_ids


def llama_causallm_forward_4_37_lowmem(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    # ipex-llm change starts

    device = hidden_states.device

    if self.config.pretraining_tp > 1:
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  # noqa
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
        logits = torch.cat(logits, dim=-1)
    else:
        if device.type == "xpu":
            torch.xpu.empty_cache()
        logits = self.lm_head(hidden_states)
        if device.type == "xpu":
            torch.xpu.empty_cache()
        # logits = logits.float()
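
    # The low-memory variant frees cached XPU blocks around the large lm_head
    # matmul and, unlike the stock HuggingFace forward, skips upcasting the
    # logits to fp32 (the upcast above is commented out), trading a little
    # precision for a smaller peak memory footprint.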
    # ipex-llm change ends

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def chatglm3_conditional_generation_forward_lowmem(
    self,
    input_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    return_last_logit: Optional[bool] = False,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = transformer_outputs[0]
    if return_last_logit:
        hidden_states = hidden_states[-1:]

    device = hidden_states.device
    # ipex-llm change starts
    if device.type == "xpu":
        torch.xpu.empty_cache()
    lm_logits = self.transformer.output_layer(hidden_states)
    if device.type == "xpu":
        torch.xpu.empty_cache()
    lm_logits = lm_logits.transpose(0, 1).contiguous()

    loss = None
    if labels is not None:
        # lm_logits = lm_logits.to(torch.float32)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        lm_logits = lm_logits.to(hidden_states.dtype)
        loss = loss.to(hidden_states.dtype)
    # ipex-llm change ends

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )


def glm4_conditional_generation_forward_lowmem(
    self,
    input_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    return_last_logit: Optional[bool] = False,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = transformer_outputs[0]
    if return_last_logit:
        hidden_states = hidden_states[:, -1:]

    device = hidden_states.device
    # ipex-llm change starts
    if device.type == "xpu":
        torch.xpu.empty_cache()
    lm_logits = self.transformer.output_layer(hidden_states)
    if device.type == "xpu":
        torch.xpu.empty_cache()

    loss = None
    if labels is not None:
        # lm_logits = lm_logits.to(torch.float32)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        lm_logits = lm_logits.to(hidden_states.dtype)
        loss = loss.to(hidden_states.dtype)
    # ipex-llm change ends

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )