Support qwen2-1.5b with fused decoderlayer optimization on NPU (#11888)
parent bdbe995b01
commit 72a7bf624b
6 changed files with 1119 additions and 15 deletions
@@ -78,7 +78,7 @@ done
```

## Example 2: Predict Tokens using `generate()` API using multi processes

In the example [llama2.py](./llama2.py) and [qwen2.py](./qwen2.py), we show experimental support for a Llama2 / Qwen2 model to predict the next N tokens using `generate()` API, with IPEX-LLM INT4 optimization and fused decoderlayer optimization on Intel NPUs.

### 1. Install

#### 1.1 Installation on Windows

We suggest using conda to manage environment:
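The installation commands themselves sit outside these diff hunks and are unchanged by this commit. Purely for illustration, and assuming the package extra and Python version used elsewhere in this README (both are assumptions, not taken from this diff), a typical setup looks roughly like:
```
conda create -n llm python=3.10
conda activate llm

pip install --pre --upgrade ipex-llm[npu]
```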
@@ -111,7 +111,11 @@ set BIGDL_USE_NPU=1

### 3. Running examples

```
# to run Llama-2-7b-chat-hf
python llama2.py

# to run Qwen2-1.5B-Instruct
python qwen2.py
```

Arguments info:
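The argument table itself is unchanged by this commit and therefore not shown here. For reference, a full invocation of the new `qwen2.py` spelling out every flag it defines (the values below are simply the script defaults; add `--disable-transpose-value-cache` to turn off the transposed value cache) would be:
```
python qwen2.py --repo-id-or-model-path Qwen/Qwen2-1.5B-Instruct --prompt "What is AI?" --n-predict 32 --max-output-len 1024 --max-prompt-len 512 --intra-pp 2 --inter-pp 1
```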
@@ -0,0 +1,95 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import torch
import time
import argparse

from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer

from transformers.utils import logging

logger = logging.get_logger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Predict Tokens using `generate()` API for npu model"
    )
    parser.add_argument(
        "--repo-id-or-model-path",
        type=str,
        default="Qwen/Qwen2-1.5B-Instruct",
        help="The huggingface repo id for the Qwen2 model to be downloaded"
        ", or the path to the huggingface checkpoint folder",
    )
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
    parser.add_argument("--max-output-len", type=int, default=1024)
    parser.add_argument("--max-prompt-len", type=int, default=512)
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
    parser.add_argument("--intra-pp", type=int, default=2)
    parser.add_argument("--inter-pp", type=int, default=1)

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        attn_implementation="eager",
        load_in_low_bit="sym_int4",
        enable_mp=True,
        max_output_len=args.max_output_len,
        max_prompt_len=args.max_prompt_len,
        intra_pp=args.intra_pp,
        inter_pp=args.inter_pp,
        transpose_value_cache=not args.disable_transpose_value_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    print("-" * 80)
    print("done")
    messages = [{"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": args.prompt}]
    text = tokenizer.apply_chat_template(messages,
                                         tokenize=False,
                                         add_generation_prompt=True)
    with torch.inference_mode():
        print("finish to load")
        for i in range(3):
            _input_ids = tokenizer([text], return_tensors="pt").input_ids
            print("input length:", len(_input_ids[0]))
            st = time.time()
            output = model.generate(
                _input_ids, num_beams=1, do_sample=False, max_new_tokens=args.n_predict
            )
            end = time.time()
            print(f"Inference time: {end-st} s")
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
            print("-" * 20, "Input", "-" * 20)
            print(input_str)
            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
            print("-" * 20, "Output", "-" * 20)
            print(output_str)

    print("-" * 80)
    print("done")
    print("success shut down")
@@ -54,3 +54,26 @@ def optimize_llm(
            prefill_runner=prefill_runner, decode_runner=decode_runner
        )
        convert_forward(model, LlamaModel, llama_model_forward)
    elif model.config.model_type == "qwen2" and model.config.intermediate_size == 8960:
        # for qwen2-1.5B
        from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
        from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
        from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

        decode_runner = DecodeRunner(
            model,
            max_seq_len=max_output_len,
            inter_pp=inter_pp,
            intra_pp=intra_pp,
            transpose_value_cache=transpose_value_cache,
        )
        prefill_runner = PrefillRunner(
            model,
            max_output_len=max_output_len,
            max_prompt_len=max_prompt_len,
            transpose_value_cache=transpose_value_cache,
        )
        qwen2_model_forward = gen_qwen2_fused_model_forward(
            prefill_runner=prefill_runner, decode_runner=decode_runner
        )
        convert_forward(model, Qwen2Model, qwen2_model_forward)
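The `intermediate_size == 8960` check above is what pins this branch to the 1.5B variant, since other Qwen2 sizes use different MLP widths. Not part of the commit, but as an illustrative way to see which branch a given checkpoint would take (assuming only `transformers` is installed):
```python
from transformers import AutoConfig

# Inspect the fields optimize_llm dispatches on; values come from the checkpoint's config.json
cfg = AutoConfig.from_pretrained("Qwen/Qwen2-1.5B-Instruct", trust_remote_code=True)
print(cfg.model_type)          # "qwen2"
print(cfg.intermediate_size)   # 8960 for the 1.5B model, so the fused NPU path above is taken
```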
@@ -17,20 +17,11 @@
import os
import torch
import time
import argparse

from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from intel_npu_acceleration_library.backend.factory import NNFactory
from typing import Optional, Sequence, List, Union, Any, Tuple
import numpy as np
import math
from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function
from intel_npu_acceleration_library.backend.runtime import adapt_output_tensor, _model_cache
from collections import deque
from transformers.cache_utils import Cache
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
import ctypes
from ipex_llm.utils.common import invalidInputError
from typing import Optional, List, Generator
import uuid

@@ -38,12 +29,10 @@ from functools import partial
import torch.nn.functional as F
import torch.nn.parallel
import torch.distributed as dist
from filelock import FileLock

from transformers.utils import logging

logger = logging.get_logger(__name__)
import gc
from colorama import Fore, Back, Style
import torch.multiprocessing as mp
from transformers.cache_utils import Cache

@@ -118,7 +118,10 @@ class LLMBaseNNFactory(NNFactory):
                  num_heads,
                  num_key_value_heads,
                  head_dim,
                  seq_len,
                  q_bias=None,
                  k_bias=None,
                  v_bias=None):
        hidden_size = num_heads * head_dim
        num_key_value_groups = num_heads // num_key_value_heads
        query_states = self.linear(

@@ -128,6 +131,8 @@ class LLMBaseNNFactory(NNFactory):
            bias=False,
            wt_dtype=self.dtype,
        )
        if q_bias is not None:
            query_states = query_states + q_bias
        key_states = self.linear(
            hidden_states,
            num_key_value_heads * head_dim,

@@ -135,6 +140,8 @@ class LLMBaseNNFactory(NNFactory):
            bias=False,
            wt_dtype=self.dtype,
        )
        if k_bias is not None:
            key_states = key_states + k_bias
        value_states = self.linear(
            hidden_states,
            num_key_value_heads * head_dim,

@@ -142,6 +149,8 @@ class LLMBaseNNFactory(NNFactory):
            bias=False,
            wt_dtype=self.dtype,
        )
        if v_bias is not None:
            value_states = value_states + v_bias

        query_states = self.reshape(
            query_states, [1, seq_len, num_heads, head_dim]

@@ -192,7 +201,8 @@ class LLMBaseNNFactory(NNFactory):
            n_rep=num_key_value_groups,
            num_key_value_heads=num_key_value_heads,
            kv_seq_len=kv_seq_len,
            head_dim=head_dim,
            transpose=self.transpose_value)
        attn_weight = self.matmul(query_states, key_states, False, True) / (
            math.sqrt(head_dim)
        )
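The hunks above thread optional `q_bias` / `k_bias` / `v_bias` constants through the shared attention graph builder: Qwen2's q/k/v projections carry bias terms while Llama's do not, so the addition is applied only when a bias is supplied. Purely as an eager-mode sketch of the added pattern (the helper below is hypothetical and not part of the diff):
```python
from typing import Optional

import torch


def project_with_optional_bias(hidden_states: torch.Tensor,
                               weight: torch.Tensor,
                               bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Equivalent of self.linear(...) followed by "+ bias" only when a bias was provided.
    out = hidden_states @ weight.T
    if bias is not None:
        out = out + bias
    return out
```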
python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py (new file, 983 lines)

@@ -0,0 +1,983 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
|
||||||
|
from typing import Optional, Sequence, List, Union, Any, Tuple
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from ipex_llm.utils.common import invalidInputError
|
||||||
|
from typing import Optional, List, Generator
|
||||||
|
import uuid
|
||||||
|
from functools import partial
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn.parallel
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
from colorama import Fore, Back, Style
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from ipex_llm.transformers.npu_models.mp_models_base import run_model
|
||||||
|
from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
|
||||||
|
|
||||||
|
|
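# Builds a static NPU graph for a stack of Qwen2 decoder layers: inputs, attention mask,
# position ids and (in decode mode) per-layer KV-cache ops are declared up front, the
# layers are unrolled through build_decoder(), and the whole graph is compiled once so
# later calls only feed fresh tensors.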
||||||
|
class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# batch_size: int,
|
||||||
|
# seq_len: int,
|
||||||
|
# hidden_size: int,
|
||||||
|
hidden_shape: Sequence[int],
|
||||||
|
*shapes,
|
||||||
|
num_heads: int,
|
||||||
|
num_key_value_heads: int,
|
||||||
|
num_layers: int,
|
||||||
|
cached_cos,
|
||||||
|
cached_sin,
|
||||||
|
input_layernorm_weights=None,
|
||||||
|
post_attn_layernorm_weights=None,
|
||||||
|
q_biases=None,
|
||||||
|
k_biases=None,
|
||||||
|
v_biases=None,
|
||||||
|
mode: str = "prefill",
|
||||||
|
dtype: np.dtype = np.int8,
|
||||||
|
max_seq_len: int = 1024,
|
||||||
|
transpose_value: bool = False,
|
||||||
|
profile: bool = False,
|
||||||
|
device: str = "NPU",
|
||||||
|
rms_norm_eps,
|
||||||
|
intermediate_size,
|
||||||
|
):
|
||||||
|
super().__init__(max_seq_len=max_seq_len,
|
||||||
|
transpose_value=transpose_value,
|
||||||
|
dtype=dtype,
|
||||||
|
profile=profile,
|
||||||
|
device=device)
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.dtype = dtype
|
||||||
|
self.cached_cos = cached_cos
|
||||||
|
self.cached_sin = cached_sin
|
||||||
|
self.batch_size, self.seq_len, self.hidden_size = hidden_shape
|
||||||
|
self.mode = mode
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.transpose_value = transpose_value
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
cos = self.constant(self.cached_cos)
|
||||||
|
self.cos = self.unsqueeze(cos, axis=0)
|
||||||
|
|
||||||
|
sin = self.constant(self.cached_sin)
|
||||||
|
self.sin = self.unsqueeze(sin, axis=0)
|
||||||
|
|
||||||
|
if mode == "decode":
|
||||||
|
self.kv_seq_len = self.max_seq_len + 1
|
||||||
|
else:
|
||||||
|
self.kv_seq_len = self.seq_len
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
|
||||||
|
# define input, the order self.parameter matters
|
||||||
|
input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
if mode == "decode":
|
||||||
|
attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
|
||||||
|
else:
|
||||||
|
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
|
||||||
|
|
||||||
|
position_ids = self.create_input_op((self.batch_size, self.seq_len))
|
||||||
|
past_keys = []
|
||||||
|
past_values = []
|
||||||
|
if mode == "decode":
|
||||||
|
for i in range(num_layers):
|
||||||
|
past_key = self.create_cache_op(
|
||||||
|
(self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
|
||||||
|
)
|
||||||
|
if transpose_value:
|
||||||
|
past_value = self.create_cache_op(
|
||||||
|
(self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
past_value = self.create_cache_op(
|
||||||
|
(self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
|
||||||
|
)
|
||||||
|
past_keys.append(past_key)
|
||||||
|
past_values.append(past_value)
|
||||||
|
else:
|
||||||
|
past_keys = [None] * num_layers
|
||||||
|
past_values = [None] * num_layers
|
||||||
|
|
||||||
|
if input_layernorm_weights is None:
|
||||||
|
input_layernorm_weights = []
|
||||||
|
post_attn_layernorm_weights = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
input_layernorm_weights.append(
|
||||||
|
self.create_input_op(
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.hidden_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
post_attn_layernorm_weights.append(
|
||||||
|
self.create_input_op(
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.hidden_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights]
|
||||||
|
post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]
|
||||||
|
|
||||||
|
if q_biases is None:
|
||||||
|
q_biases = []
|
||||||
|
k_biases = []
|
||||||
|
v_biases = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
q_biases.append(self.create_input_op((self.num_heads * self.head_dim,)))
|
||||||
|
k_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
|
||||||
|
v_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
|
||||||
|
else:
|
||||||
|
q_biases = [self.constant(w) for w in q_biases]
|
||||||
|
k_biases = [self.constant(w) for w in k_biases]
|
||||||
|
v_biases = [self.constant(w) for w in v_biases]
|
||||||
|
|
||||||
|
hidden_states = input
|
||||||
|
|
||||||
|
curr_key_values = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
hidden_states, new_key_states, new_value_states = self.build_decoder(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
input_layernorm_weight=input_layernorm_weights[i],
|
||||||
|
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
|
||||||
|
q_bias=q_biases[i],
|
||||||
|
k_bias=k_biases[i],
|
||||||
|
v_bias=v_biases[i],
|
||||||
|
past_key=past_keys[i],
|
||||||
|
past_value=past_values[i],
|
||||||
|
)
|
||||||
|
curr_key_values.append((new_key_states, new_value_states))
|
||||||
|
|
||||||
|
# define outputs
|
||||||
|
hidden_states = self.convert_to_fp16(hidden_states)
|
||||||
|
|
||||||
|
for i in range(num_layers):
|
||||||
|
new_key_states = self.convert_to_fp16(curr_key_values[i][0])
|
||||||
|
new_value_states = self.convert_to_fp16(curr_key_values[i][1])
|
||||||
|
|
||||||
|
self.compile()
|
||||||
|
|
||||||
|
def build_decoder(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
input_layernorm_weight,
|
||||||
|
post_attention_layernorm_weight,
|
||||||
|
q_bias,
|
||||||
|
k_bias,
|
||||||
|
v_bias,
|
||||||
|
past_key=None,
|
||||||
|
past_value=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size))
|
||||||
|
input_2d = self.layer_norm(input_2d, input_layernorm_weight)
|
||||||
|
attn_output, new_key_states, new_value_states = self.attention(
|
||||||
|
hidden_states=input_2d,
|
||||||
|
position_ids=position_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key=past_key,
|
||||||
|
past_value=past_value,
|
||||||
|
cos=self.cos,
|
||||||
|
sin=self.sin,
|
||||||
|
mode=self.mode,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_key_value_heads=self.num_key_value_heads,
|
||||||
|
head_dim=self.head_dim,
|
||||||
|
seq_len=self.seq_len,
|
||||||
|
q_bias=q_bias,
|
||||||
|
k_bias=k_bias,
|
||||||
|
v_bias=v_bias,
|
||||||
|
)
|
||||||
|
hidden_states = self.eltwise_add(residual, attn_output)
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = self.eltwise_add(residual, hidden_states)
|
||||||
|
hidden_states = self.convert_to_fp16(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, new_key_states, new_value_states
|
||||||
|
|
||||||
|
|
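# Decode-path wrapper: holds one compiled LowBitQwenMultiDecoderlayer graph per
# intra-pipeline stage, feeds all stages in forward(), and in post_forward() writes the
# freshly produced key/value states back into the HF cache object before asynchronously
# reloading the NPU-side cache.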
||||||
|
class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parameters: List[Tuple[torch.Tensor]],
|
||||||
|
input_laynorm_weights: List[torch.Tensor],
|
||||||
|
post_attn_layernorm_weights: List[torch.Tensor],
|
||||||
|
q_biases: List[torch.Tensor],
|
||||||
|
k_biases: List[torch.Tensor],
|
||||||
|
v_biases: List[torch.Tensor],
|
||||||
|
layer_indexes: List[int],
|
||||||
|
intra_stages: int,
|
||||||
|
cached_cos: torch.Tensor,
|
||||||
|
cached_sin: torch.Tensor,
|
||||||
|
num_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
num_key_value_heads: int,
|
||||||
|
rms_norm_eps,
|
||||||
|
intermediate_size,
|
||||||
|
max_seq_len: int = 1024,
|
||||||
|
transpose_value: bool = False,
|
||||||
|
do_print: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.do_print = do_print
|
||||||
|
|
||||||
|
op_parameters = []
|
||||||
|
for w in parameters:
|
||||||
|
if isinstance(w, tuple): # from QuantizedLinear
|
||||||
|
op_parameters.append((w[0].numpy(), w[1].numpy()))
|
||||||
|
else:
|
||||||
|
op_parameters.append(w.to(torch.float16).numpy())
|
||||||
|
self.op_parameters = op_parameters
|
||||||
|
self.op_id = str(uuid.uuid4())
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.transpose_value = transpose_value
|
||||||
|
if isinstance(parameters[0], tuple):
|
||||||
|
np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
|
||||||
|
else: # FP16 Linear
|
||||||
|
np_dtype = np.float16
|
||||||
|
|
||||||
|
self.intra_stages = intra_stages
|
||||||
|
self.layer_indexes = layer_indexes
|
||||||
|
num_layers = len(self.layer_indexes) // intra_stages
|
||||||
|
self.layer_ranges = []
|
||||||
|
for i in range(intra_stages):
|
||||||
|
if i == intra_stages - 1:
|
||||||
|
self.layer_ranges.append((i * num_layers, len(self.layer_indexes)))
|
||||||
|
else:
|
||||||
|
self.layer_ranges.append((i * num_layers, (i + 1) * num_layers))
|
||||||
|
|
||||||
|
self.backend_decoders = []
|
||||||
|
|
||||||
|
for i in range(intra_stages):
|
||||||
|
start, end = self.layer_ranges[i]
|
||||||
|
lm_0 = input_laynorm_weights[start:end]
|
||||||
|
lm_1 = post_attn_layernorm_weights[start:end]
|
||||||
|
decoder = LowBitQwenMultiDecoderlayer(
|
||||||
|
[1, 1, num_heads * head_dim],
|
||||||
|
input_layernorm_weights=lm_0,
|
||||||
|
post_attn_layernorm_weights=lm_1,
|
||||||
|
q_biases=q_biases[start:end],
|
||||||
|
k_biases=k_biases[start:end],
|
||||||
|
v_biases=v_biases[start:end],
|
||||||
|
cached_cos=cached_cos,
|
||||||
|
cached_sin=cached_sin,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_key_value_heads=num_key_value_heads,
|
||||||
|
num_layers=end - start,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
rms_norm_eps=rms_norm_eps,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
mode="decode",
|
||||||
|
transpose_value=self.transpose_value,
|
||||||
|
dtype=np_dtype,
|
||||||
|
)
|
||||||
|
self.backend_decoders.append(decoder)
|
||||||
|
|
||||||
|
for i in range(intra_stages):
|
||||||
|
start, end = self.layer_ranges[i]
|
||||||
|
self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 7:end * 7])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
inputs = (
|
||||||
|
hidden_states.to(torch.float16),
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(self.intra_stages):
|
||||||
|
start, end = self.layer_ranges[i]
|
||||||
|
self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])
|
||||||
|
|
||||||
|
hidden_states, new_keys, new_values = LowBitQwenMultiDecoderlayer.run_decoders(
|
||||||
|
inputs,
|
||||||
|
decoders=self.backend_decoders)
|
||||||
|
|
||||||
|
if self.do_print:
|
||||||
|
print("outputs:", hidden_states)
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
outputs += (past_key_value, new_keys, new_values)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def post_forward(self, past_key_value, new_keys, new_values):
|
||||||
|
key_value_states = []
|
||||||
|
for i in range(self.intra_stages):
|
||||||
|
for j in range(1, len(self.backend_decoders[i].torch_out)):
|
||||||
|
key_value_states.append(self.backend_decoders[i].torch_out[j])
|
||||||
|
|
||||||
|
cache_kwargs = {
|
||||||
|
"max_seq_len": self.max_seq_len,
|
||||||
|
"transpose": self.transpose_value,
|
||||||
|
}
|
||||||
|
for i in range(len(self.layer_indexes)):
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
new_keys[i],
|
||||||
|
new_values[i],
|
||||||
|
self.layer_indexes[i],
|
||||||
|
cache_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(self.intra_stages):
|
||||||
|
self.backend_decoders[i].load_cache_async()
|
||||||
|
|
||||||
|
|
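# Prefill-path wrapper: replaces a single HF decoder layer, running its one-layer NPU
# graph through run_model() and pushing the returned key/value states into the fused
# KV cache.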
||||||
|
class FusedQwenLowBitDecoderlayer(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parameters: List[torch.Tensor],
|
||||||
|
cached_cos,
|
||||||
|
cached_sin,
|
||||||
|
layer_norm_0,
|
||||||
|
layer_norm_1,
|
||||||
|
q_bias,
|
||||||
|
k_bias,
|
||||||
|
v_bias,
|
||||||
|
num_heads: int,
|
||||||
|
num_key_value_heads: int,
|
||||||
|
layer_idx: int,
|
||||||
|
rms_norm_eps,
|
||||||
|
intermediate_size,
|
||||||
|
max_seq_len: int = 128,
|
||||||
|
transpose_value: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.op_parameters = parameters
|
||||||
|
self.op_id = str(uuid.uuid4())
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.transpose_value = transpose_value
|
||||||
|
# self.rotary_emb = rotary_emb
|
||||||
|
if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear
|
||||||
|
np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
|
||||||
|
else: # FP16 Linear
|
||||||
|
np_dtype = np.float16
|
||||||
|
|
||||||
|
self.backend_cls_prefill = partial(
|
||||||
|
LowBitQwenMultiDecoderlayer,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_key_value_heads=num_key_value_heads,
|
||||||
|
num_layers=1,
|
||||||
|
cached_cos=cached_cos,
|
||||||
|
cached_sin=cached_sin,
|
||||||
|
input_layernorm_weights=None,
|
||||||
|
post_attn_layernorm_weights=None,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
rms_norm_eps=rms_norm_eps,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
mode="prefill",
|
||||||
|
transpose_value=self.transpose_value,
|
||||||
|
dtype=np_dtype,
|
||||||
|
)
|
||||||
|
self.layer_norm_0 = layer_norm_0
|
||||||
|
self.layer_norm_1 = layer_norm_1
|
||||||
|
self.q_bias = q_bias
|
||||||
|
self.k_bias = k_bias
|
||||||
|
self.v_bias = v_bias
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Torch module forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (torch.Tensor): input hidden states
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: result
|
||||||
|
"""
|
||||||
|
|
||||||
|
seq_len = hidden_states.shape[1]
|
||||||
|
|
||||||
|
backend_cls = self.backend_cls_prefill
|
||||||
|
inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
|
||||||
|
inputs += (self.layer_norm_0, self.layer_norm_1)
|
||||||
|
inputs += (self.q_bias, self.k_bias, self.v_bias)
|
||||||
|
hidden_states, past_key, past_value = run_model(
|
||||||
|
inputs, self.op_parameters, backend_cls, self.op_id, replica=2
|
||||||
|
)
|
||||||
|
cache_kwargs = {"max_seq_len": self.max_seq_len, "transpose": self.transpose_value}
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
past_key, past_value, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
outputs += (past_key_value,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
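# Worker entry point for the decode pipeline: each rank owns a contiguous slice of
# decoder layers, waits on a broadcast control token (-1 = receive a new KV cache from
# the input queue, -2 = shut down, otherwise run one decode step), and moves hidden
# states between ranks with dist.recv / dist.send.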
||||||
|
def run_decode(
|
||||||
|
model,
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
port,
|
||||||
|
layer_start,
|
||||||
|
layer_end,
|
||||||
|
intra_stages,
|
||||||
|
max_seq_len,
|
||||||
|
transpose_value_cache,
|
||||||
|
input_queue,
|
||||||
|
result_queue,
|
||||||
|
):
|
||||||
|
|
||||||
|
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||||
|
os.environ["MASTER_PORT"] = port
|
||||||
|
os.environ["RANK"] = str(rank)
|
||||||
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
|
|
||||||
|
print("start init process group, rank: ", rank, "world_size: ", world_size)
|
||||||
|
|
||||||
|
dist.init_process_group()
|
||||||
|
my_rank = dist.get_rank()
|
||||||
|
my_size = dist.get_world_size()
|
||||||
|
logger.info(f"rank: {my_rank}, size: {my_size}")
|
||||||
|
|
||||||
|
num_heads = model.model.layers[layer_start].self_attn.num_heads
|
||||||
|
num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
|
||||||
|
head_dim = model.model.layers[layer_start].self_attn.head_dim
|
||||||
|
rms_norm_eps = model.config.rms_norm_eps
|
||||||
|
intermediate_size = model.config.intermediate_size
|
||||||
|
deocderlayers = []
|
||||||
|
layer_weights = []
|
||||||
|
input_layer_norm_weights = []
|
||||||
|
post_attn_layernorm_weights = []
|
||||||
|
q_biases = []
|
||||||
|
k_biases = []
|
||||||
|
v_biases = []
|
||||||
|
layer_indexs = range(layer_start, layer_end)
|
||||||
|
for layer_idx in layer_indexs:
|
||||||
|
curr_layer = model.model.layers[layer_idx]
|
||||||
|
attn_layer = curr_layer.self_attn
|
||||||
|
mlp_layer = curr_layer.mlp
|
||||||
|
|
||||||
|
weights = [
|
||||||
|
(attn_layer.q_proj.weight, attn_layer.q_proj.scale),
|
||||||
|
(attn_layer.k_proj.weight, attn_layer.k_proj.scale),
|
||||||
|
(attn_layer.v_proj.weight, attn_layer.v_proj.scale),
|
||||||
|
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
|
||||||
|
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
|
||||||
|
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
|
||||||
|
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
|
||||||
|
]
|
||||||
|
|
||||||
|
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||||
|
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||||
|
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
|
||||||
|
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
|
||||||
|
|
||||||
|
layer_weights.extend(weights)
|
||||||
|
input_layer_norm_weights.append(layer_norm_0)
|
||||||
|
post_attn_layernorm_weights.append(layer_norm_1)
|
||||||
|
q_biases.append(attn_layer.q_proj.bias.to(torch.float16))
|
||||||
|
k_biases.append(attn_layer.k_proj.bias.to(torch.float16))
|
||||||
|
v_biases.append(attn_layer.v_proj.bias.to(torch.float16))
|
||||||
|
|
||||||
|
multi_decoder = FusedQwenLowBitMultiDecoderlayer(
|
||||||
|
parameters=layer_weights,
|
||||||
|
input_laynorm_weights=input_layer_norm_weights,
|
||||||
|
post_attn_layernorm_weights=post_attn_layernorm_weights,
|
||||||
|
q_biases=q_biases,
|
||||||
|
k_biases=k_biases,
|
||||||
|
v_biases=v_biases,
|
||||||
|
layer_indexes=layer_indexs,
|
||||||
|
intra_stages=intra_stages,
|
||||||
|
cached_cos=cached_cos,
|
||||||
|
cached_sin=cached_sin,
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
num_key_value_heads=num_key_value_heads,
|
||||||
|
rms_norm_eps=rms_norm_eps,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
transpose_value=transpose_value_cache,
|
||||||
|
do_print=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
past_key_values = None
|
||||||
|
|
||||||
|
control = torch.empty((), dtype=torch.int)
|
||||||
|
hidden_states = torch.empty((1, 1, head_dim * num_heads), dtype=torch.float16)
|
||||||
|
with torch.inference_mode():
|
||||||
|
while True:
|
||||||
|
|
||||||
|
dist.broadcast(control, src=0)
|
||||||
|
if control.item() == -2:
|
||||||
|
break
|
||||||
|
elif control.item() == -1:
|
||||||
|
past_key_values = input_queue.get()
|
||||||
|
else:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length()
|
||||||
|
attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_seen_tokens,
|
||||||
|
1 + past_seen_tokens,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, 1)
|
||||||
|
|
||||||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
|
|
||||||
|
causal_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask,
|
||||||
|
(hidden_states.shape[0], hidden_states.shape[1]),
|
||||||
|
hidden_states,
|
||||||
|
past_seen_tokens,
|
||||||
|
sliding_window=model.model.config.sliding_window,
|
||||||
|
)
|
||||||
|
pad_len = multi_decoder.max_seq_len + 1 - causal_mask.size(-1)
|
||||||
|
|
||||||
|
causal_mask[:, :, :, -1] = torch.finfo(torch.float16).min
|
||||||
|
pad_mask = (0, pad_len)
|
||||||
|
padded_causal_mask = F.pad(
|
||||||
|
causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
|
||||||
|
)
|
||||||
|
padded_causal_mask[:, :, :, -1] = 0.0
|
||||||
|
dist.recv(hidden_states, src=rank - 1)
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
layer_outputs = multi_decoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=padded_causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=False,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
t2 = time.perf_counter()
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
t3 = time.perf_counter()
|
||||||
|
dist.send(hidden_states, dst=(rank + 1) % world_size)
|
||||||
|
t4 = time.perf_counter()
|
||||||
|
past_key_values = layer_outputs[1]
|
||||||
|
new_keys = layer_outputs[2]
|
||||||
|
new_values = layer_outputs[3]
|
||||||
|
multi_decoder.post_forward(past_key_values, new_keys, new_values)
|
||||||
|
|
||||||
|
|
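# Rank-0 driver for decoding: spawns one run_decode worker per inter-pipeline stage,
# joins the same torch.distributed group, then for every generated token broadcasts a
# control value, sends the hidden states to rank 1 and waits for the result from the
# last rank.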
||||||
|
class DecodeRunner:
|
||||||
|
def __init__(self, model, max_seq_len, intra_pp=2, inter_pp=2, transpose_value_cache=True):
|
||||||
|
self.model = model
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.transpose_value_cache = transpose_value_cache
|
||||||
|
world_size = inter_pp + 1
|
||||||
|
intra_stages = intra_pp
|
||||||
|
num_layers = self.model.model.config.num_hidden_layers
|
||||||
|
|
||||||
|
port = "54791"
|
||||||
|
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||||
|
os.environ["MASTER_PORT"] = port
|
||||||
|
os.environ["RANK"] = "0"
|
||||||
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
|
|
||||||
|
self.input_queues = []
|
||||||
|
self.output_queues = []
|
||||||
|
self.decoder_processes = []
|
||||||
|
|
||||||
|
for rank in range(1, world_size):
|
||||||
|
input_q = mp.Queue()
|
||||||
|
output_q = mp.Queue()
|
||||||
|
start_layer = (rank - 1) * (num_layers // (world_size - 1))
|
||||||
|
end_layer = (rank) * (num_layers // (world_size - 1))
|
||||||
|
if rank == world_size - 1:
|
||||||
|
end_layer = num_layers
|
||||||
|
p = mp.Process(
|
||||||
|
target=run_decode,
|
||||||
|
args=(
|
||||||
|
self.model,
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
port,
|
||||||
|
start_layer,
|
||||||
|
end_layer,
|
||||||
|
intra_stages,
|
||||||
|
self.max_seq_len,
|
||||||
|
self.transpose_value_cache,
|
||||||
|
input_q,
|
||||||
|
output_q,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.daemon = True
|
||||||
|
p.start()
|
||||||
|
self.input_queues.append(input_q)
|
||||||
|
self.output_queues.append(output_q)
|
||||||
|
self.decoder_processes.append(p)
|
||||||
|
|
||||||
|
dist.init_process_group()
|
||||||
|
my_rank = dist.get_rank()
|
||||||
|
self.world_size = dist.get_world_size()
|
||||||
|
logger.info(f"rank: {my_rank}, size: {self.world_size}")
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
self.cache_past_key_value = None
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
|
if self.cache_past_key_value != past_key_value:
|
||||||
|
control = torch.tensor(-1, dtype=torch.int)
|
||||||
|
dist.broadcast(control, src=0)
|
||||||
|
for i in range(len(self.decoder_processes)):
|
||||||
|
self.input_queues[i].put(past_key_value)
|
||||||
|
|
||||||
|
control = torch.tensor(0, dtype=torch.int)
|
||||||
|
dist.broadcast(control, src=0)
|
||||||
|
hidden_states = hidden_states.to(torch.float16)
|
||||||
|
dist.send(hidden_states, dst=1)
|
||||||
|
past_key_value.expand(self.transpose_value_cache)
|
||||||
|
dist.recv(hidden_states, src=self.world_size - 1)
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
return hidden_states, past_key_value
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
control = torch.tensor(-2, dtype=torch.int)
|
||||||
|
dist.broadcast(control, src=0)
|
||||||
|
for p in self.decoder_processes:
|
||||||
|
p.join(3)
|
||||||
|
for p in self.decoder_processes:
|
||||||
|
if p.exitcode is None:
|
||||||
|
p.kill()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.shutdown()
|
||||||
|
|
||||||
|
|
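# Worker entry point for prefill: rebuilds every decoder layer as a
# FusedQwenLowBitDecoderlayer, reports readiness through the result queue, then serves
# (hidden_states, position_ids, causal_mask, past_key_values) work items until it
# receives "stop".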
||||||
|
def run_prefill(
|
||||||
|
model, max_output_len, max_prompt_len, transpose_value_cache, input_queue, result_queue
|
||||||
|
):
|
||||||
|
|
||||||
|
layer_start = 0
|
||||||
|
layer_end = len(model.model.layers)
|
||||||
|
num_heads = model.model.layers[layer_start].self_attn.num_heads
|
||||||
|
num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
|
||||||
|
head_dim = model.model.layers[layer_start].self_attn.head_dim
|
||||||
|
rms_norm_eps = model.config.rms_norm_eps
|
||||||
|
intermediate_size = model.config.intermediate_size
|
||||||
|
deocderlayers = []
|
||||||
|
layer_weights = []
|
||||||
|
input_layer_norm_weights = []
|
||||||
|
post_attn_layernorm_weights = []
|
||||||
|
layer_indexs = range(layer_start, layer_end)
|
||||||
|
for layer_idx in layer_indexs:
|
||||||
|
curr_layer = model.model.layers[layer_idx]
|
||||||
|
attn_layer = curr_layer.self_attn
|
||||||
|
mlp_layer = curr_layer.mlp
|
||||||
|
|
||||||
|
weights = [
|
||||||
|
(attn_layer.q_proj.weight, attn_layer.q_proj.scale),
|
||||||
|
(attn_layer.k_proj.weight, attn_layer.k_proj.scale),
|
||||||
|
(attn_layer.v_proj.weight, attn_layer.v_proj.scale),
|
||||||
|
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
|
||||||
|
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
|
||||||
|
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
|
||||||
|
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
|
||||||
|
]
|
||||||
|
|
||||||
|
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||||
|
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||||
|
|
||||||
|
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
|
||||||
|
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
|
||||||
|
|
||||||
|
new_decoderlayer = FusedQwenLowBitDecoderlayer(
|
||||||
|
weights,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_key_value_heads=num_key_value_heads,
|
||||||
|
cached_cos=cached_cos,
|
||||||
|
cached_sin=cached_sin,
|
||||||
|
layer_norm_0=layer_norm_0,
|
||||||
|
layer_norm_1=layer_norm_1,
|
||||||
|
q_bias=attn_layer.q_proj.bias.to(torch.float16),
|
||||||
|
k_bias=attn_layer.k_proj.bias.to(torch.float16),
|
||||||
|
v_bias=attn_layer.v_proj.bias.to(torch.float16),
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
rms_norm_eps=rms_norm_eps,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
max_seq_len=max_output_len,
|
||||||
|
transpose_value=transpose_value_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
layer_weights.extend(weights)
|
||||||
|
input_layer_norm_weights.append(layer_norm_0)
|
||||||
|
post_attn_layernorm_weights.append(layer_norm_1)
|
||||||
|
model.model.layers[layer_idx] = new_decoderlayer
|
||||||
|
deocderlayers.append(new_decoderlayer)
|
||||||
|
|
||||||
|
print("finish creating all decode layers in prefill")
|
||||||
|
result_queue.put("loading finish")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
result = input_queue.get()
|
||||||
|
if result == "stop":
|
||||||
|
break
|
||||||
|
|
||||||
|
hidden_states, position_ids, causal_mask, past_key_values = result
|
||||||
|
with torch.inference_mode():
|
||||||
|
for decoder_layer in deocderlayers:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=False,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
next_decoder_cache = layer_outputs[1]
|
||||||
|
|
||||||
|
result_queue.put((hidden_states, next_decoder_cache))
|
||||||
|
|
||||||
|
|
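# Driver for the prompt phase: starts a single run_prefill worker, validates that the
# prompt fits within max_prompt_len, and exchanges tensors with the worker through
# multiprocessing queues.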
||||||
|
class PrefillRunner:
|
||||||
|
def __init__(self, model, max_output_len, max_prompt_len, transpose_value_cache):
|
||||||
|
self.model = model
|
||||||
|
self.max_output_len = max_output_len
|
||||||
|
self.max_prompt_len = max_prompt_len
|
||||||
|
self.transpose_value_cache = transpose_value_cache
|
||||||
|
|
||||||
|
self.prefill_result_queue = mp.Queue()
|
||||||
|
self.prefill_input_queue = mp.Queue()
|
||||||
|
|
||||||
|
self.p = mp.Process(
|
||||||
|
target=run_prefill,
|
||||||
|
args=(
|
||||||
|
model,
|
||||||
|
max_output_len,
|
||||||
|
max_prompt_len,
|
||||||
|
transpose_value_cache,
|
||||||
|
self.prefill_input_queue,
|
||||||
|
self.prefill_result_queue,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.p.daemon = True
|
||||||
|
self.p.start()
|
||||||
|
output = self.prefill_result_queue.get()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
seq_len = hidden_states.size(1)
|
||||||
|
invalidInputError(
|
||||||
|
seq_len <= self.max_prompt_len,
|
||||||
|
(
f"seq_len: {seq_len} should be less than or equal"
f" to max_prompt_len {self.max_prompt_len}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.prefill_input_queue.put((hidden_states, position_ids, attention_mask, past_key_value))
|
||||||
|
return self.prefill_result_queue.get()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self.prefill_input_queue.put("stop")
|
||||||
|
self.p.join(3)
|
||||||
|
if self.p.exitcode is None:
|
||||||
|
self.p.kill()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.shutdown()
|
||||||
|
|
||||||
|
|
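# Returns a replacement for Qwen2Model.forward that keeps the usual embedding, mask and
# position-id handling but routes the whole decoder stack either to prefill_runner
# (prompt, seq_length > 1) or decode_runner (single token), with KV state kept in a
# DynamicFusedNormalCache.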
||||||
|
def gen_qwen2_fused_model_forward(prefill_runner, decode_runner):
|
||||||
|
|
||||||
|
def qwen2_fused_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
invalidInputError(False,
|
||||||
|
"You cannot specify both decoder_input_ids and "
|
||||||
|
"decoder_inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
invalidInputError(False,
|
||||||
|
"You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
|
||||||
|
|
||||||
|
if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
|
||||||
|
past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_seq_length()
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length,
|
||||||
|
seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
|
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
if seq_length == 1:
|
||||||
|
layers_runner = decode_runner
|
||||||
|
else:
|
||||||
|
layers_runner = prefill_runner
|
||||||
|
layer_outputs = layers_runner.forward(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
next_decoder_cache = layer_outputs[1]
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
return qwen2_fused_model_forward
|
||||||