[NPU] Add Optimized Support for Llama3.2-1B/3B on NPU (#12339)
* Add initial support for Llama3.2-1B/3B
* Move Llama3.2 support into the current llama_mp implementation
parent 872a74481a
commit a7b66683f1
6 changed files with 360 additions and 127 deletions
@@ -7,6 +7,8 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 |------------|----------------------------------------------------------------|
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
+| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) |
+| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 | Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
 | Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
 | Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |

@@ -33,6 +35,9 @@ conda activate llm

 :: install ipex-llm with 'npu' option
 pip install --pre --upgrade ipex-llm[npu]
+
+:: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
+pip install transformers==4.45.0 accelerate==0.33.0
 ```

 ## 2. Runtime Configurations

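Note on the optional pin: as the `convert_llama` change further below shows, the optimized Llama3.2 forward is only selected when transformers reports exactly 4.45.0. A quick environment self-check (a sketch, not part of the example scripts):

```python
# Sketch: verify the environment will take the Llama3.2-optimized NPU path,
# which this PR gates on transformers == 4.45.0.
import transformers
from packaging import version

assert version.parse(transformers.__version__) == version.parse("4.45.0"), \
    f"found transformers {transformers.__version__}; the Llama3.2 NPU optimization expects 4.45.0"
```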
@@ -82,6 +87,8 @@ done
 The examples below show how to run the **_optimized HuggingFace model implementations_** on Intel NPU, including
 - [Llama2-7B](./llama.py)
 - [Llama3-8B](./llama.py)
+- [Llama3.2-1B](./llama.py)
+- [Llama3.2-3B](./llama.py)
 - [Qwen2-1.5B](./qwen.py)
 - [Qwen2.5-7B](./qwen.py)
 - [MiniCPM-1B](./minicpm.py)

@@ -106,6 +113,12 @@ python llama.py

 :: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
 python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct

+:: to run Llama-3.2-1B-Instruct
+python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct
+
+:: to run Llama-3.2-3B-Instruct
+python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct
+
 :: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
 python qwen.py

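For orientation, a minimal sketch of roughly what `llama.py` does when pointed at one of these checkpoints. The entry point `ipex_llm.transformers.npu_model.AutoModelForCausalLM` and the keyword arguments shown are assumptions based on the NPU example conventions; the shipped `llama.py` is the authoritative reference.

```python
# Rough sketch of the example flow (assumed API surface; see llama.py for the real script).
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    optimize_model=True,            # enable the fused NPU decoder path
    load_in_low_bit="sym_int4",     # low-bit weights for the NPU
    transpose_value_cache=True,     # False mirrors --disable-transpose-value-cache
    attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```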
@@ -145,6 +158,12 @@ python llama.py --disable-transpose-value-cache

 :: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
 python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --disable-transpose-value-cache

+:: to run Llama-3.2-1B-Instruct
+python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --disable-transpose-value-cache
+
+:: to run Llama-3.2-3B-Instruct
+python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --disable-transpose-value-cache
+
 :: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
 python qwen.py --disable-transpose-value-cache

@@ -10,6 +10,8 @@ This folder contains examples of running IPEX-LLM on Intel NPU:
 |------------|----------------------------------------------------------------|
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
+| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) |
+| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 | Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
 | Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
 | Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |

@@ -173,7 +173,8 @@ def convert_llama(
     intra_pp=None,
     transpose_value_cache=True,
 ):
-    from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
+    from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward,\
+        gen_llama_32_fused_model_forward
     from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
     from transformers.models.llama.modeling_llama import LlamaModel

@@ -193,9 +194,18 @@ def convert_llama(
         max_prompt_len=max_prompt_len,
         transpose_value_cache=transpose_value_cache,
     )
-    llama_model_forward = gen_llama_fused_model_forward(
-        prefill_runner=prefill_runner, decode_runner=decode_runner
-    )
+    from packaging import version
+    import transformers
+    trans_version = transformers.__version__
+    if version.parse(trans_version) == version.parse("4.45.0"):
+        # llama-3.2-3B & llama-3.2-1B
+        llama_model_forward = gen_llama_32_fused_model_forward(
+            prefill_runner=prefill_runner, decode_runner=decode_runner
+        )
+    else:
+        llama_model_forward = gen_llama_fused_model_forward(
+            prefill_runner=prefill_runner, decode_runner=decode_runner
+        )
     convert_forward(model, LlamaModel, llama_model_forward)
     from transformers.models.llama.modeling_llama import LlamaForCausalLM
     from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward

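The gate above decides once, at conversion time, whether `LlamaModel` runs the Llama3.2 variant (rotary cos/sin computed at the model level and passed down to the runners) or the original fused forward (per-layer cached cos/sin). `convert_forward` then swaps the bound `forward` of every matching module; a generic sketch of that patching pattern, not the actual ipex-llm implementation:

```python
import types
import torch

# Generic sketch of forward-patching: every instance of target_cls inside the model
# gets new_forward bound as its forward method (illustrative only).
def patch_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
```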
@@ -145,7 +145,7 @@ class DynamicFusedNormalCache(DynamicCache):
     # Experimental support for fused decoderlayer implementation on NPU
     # Currently only for llama2

-    def __init__(self) -> None:
+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
         self.key_cache: Dict[int, torch.Tensor] = {}
         self.value_cache: Dict[int, torch.Tensor] = {}
         self.min_layer_idx = sys.maxsize

@@ -158,6 +158,9 @@ class DynamicFusedNormalCache(DynamicCache):
         cache_kwargs: Optional[Dict[str, Any]]=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

+        if key_states == []:
+            return key_states, value_states
+
         batch_size, num_heads, seq_len, head_dim = key_states.shape

         max_seq_length = cache_kwargs["max_seq_len"] if "max_seq_len" in cache_kwargs else None

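The new guard simply passes empty key/value updates straight through: the very next statement unpacks `key_states.shape`, which would fail on an empty list. A minimal sketch of the guard semantics in a layer-indexed cache (illustrative, not the real class):

```python
from typing import Dict, Tuple
import torch

# Sketch: a layer-indexed cache update that tolerates "empty" updates.
def update(key_cache: Dict[int, torch.Tensor], value_cache: Dict[int, torch.Tensor],
           layer_idx: int, key_states, value_states) -> Tuple:
    if key_states == []:                      # nothing to store for this layer
        return key_states, value_states       # pass the empty update back unchanged
    if layer_idx in key_cache:
        key_cache[layer_idx] = torch.cat([key_cache[layer_idx], key_states], dim=-2)
        value_cache[layer_idx] = torch.cat([value_cache[layer_idx], value_states], dim=-2)
    else:
        key_cache[layer_idx], value_cache[layer_idx] = key_states, value_states
    return key_cache[layer_idx], value_cache[layer_idx]
```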
@@ -69,7 +69,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         intermediate_size,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        cos_len: int = 1,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,

@@ -84,18 +85,13 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         self.dtype = dtype
         self.cached_cos = cached_cos
         self.cached_sin = cached_sin
+        self.cos_len = cos_len
         self.batch_size, self.seq_len, self.hidden_size = hidden_shape
         self.mode = mode
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
         self.num_layers = num_layers

-        cos = self.constant(self.cached_cos)
-        self.cos = self.unsqueeze(cos, axis=0)
-
-        sin = self.constant(self.cached_sin)
-        self.sin = self.unsqueeze(sin, axis=0)
-
         if mode == "decode":
             self.kv_seq_len = self.max_seq_len + 1
         else:

@@ -111,6 +107,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))

         # llama2/3 use ov sdp, other models need to test

         use_prefill_sdp = self.intermediate_size in [11008, 14336]

         # Self Attention

@@ -124,8 +121,20 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
                                                    self.seq_len),
                                                   dtype=np.int64)
-        position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
+        if self.cached_cos is None:
+            if mode == "prefill":
+                position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
+                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
+                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
+            else:
+                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
+                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
+        else:
+            position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
+            cos = self.constant(self.cached_cos)
+            self.cos = self.unsqueeze(cos, axis=0)
+            sin = self.constant(self.cached_sin)
+            self.sin = self.unsqueeze(sin, axis=0)

         if input_layernorm_weights is None:
             input_layernorm_weights = []

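This is the key structural change for Llama3.2: when `cached_cos` is `None` (the transformers 4.45 rotary embedding no longer exposes `cos_cached`/`sin_cached`), the cos/sin tables become runtime inputs of the NPU graph, sized by `cos_len`, instead of compile-time constants. For reference, a standard RoPE cos/sin computation of the kind the model-level rotary embedding performs; the names here are illustrative, not the ipex-llm helpers:

```python
import torch

# Standard rotary-embedding tables for given positions (sketch).
def rope_cos_sin(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    # position_ids: (batch, seq_len) integer positions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # (batch, seq, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                             # (batch, seq, head_dim)
    return emb.cos().to(torch.float16), emb.sin().to(torch.float16)

# Example: tables for a single decode step at position 41 with head_dim 64.
cos, sin = rope_cos_sin(torch.tensor([[41]]), head_dim=64)  # shapes: (1, 1, 64)
```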
@@ -179,12 +188,15 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             hidden_states, new_key_states, new_value_states = self.build_decoder(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
-                position_ids=position_ids,
+                position_ids=position_ids if (cached_cos is not None
+                                              or mode == "prefill") else None,
                 input_layernorm_weight=input_layernorm_weights[i],
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
                 past_value=past_values[i],
                 use_prefill_sdp=use_prefill_sdp,
+                cos=self.cos,
+                sin=self.sin,
             )
             curr_key_values.append((new_key_states, new_value_states))

@@ -205,12 +217,14 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         self,
         hidden_states,
         attention_mask,
-        position_ids,
         input_layernorm_weight,
         post_attention_layernorm_weight,
+        position_ids=None,
         past_key=None,
         past_value=None,
         use_prefill_sdp=False,
+        cos=None,
+        sin=None,
     ):

         residual = hidden_states

|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
past_key=past_key,
|
past_key=past_key,
|
||||||
past_value=past_value,
|
past_value=past_value,
|
||||||
cos=self.cos,
|
cos=cos,
|
||||||
sin=self.sin,
|
sin=sin,
|
||||||
mode=self.mode,
|
mode=self.mode,
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
num_key_value_heads=self.num_key_value_heads,
|
num_key_value_heads=self.num_key_value_heads,
|
||||||
|
|
@@ -282,6 +296,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         self.op_id = str(uuid.uuid4())
         self.max_seq_len = max_seq_len
         self.transpose_value = transpose_value
+        self.cached_cos = cached_cos
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
         elif parameters[0].dtype == torch.int8:

@@ -341,15 +356,21 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:

         inputs = (
             hidden_states.to(torch.float16),
             attention_mask.to(torch.int64),
-            position_ids.to(torch.int64),
         )

+        if self.cached_cos is None:
+            inputs += (cos.to(torch.float16), sin.to(torch.float16))
+        else:
+            inputs += (position_ids.to(torch.int64),)
+
         for i in range(self.intra_stages):
             start, end = self.layer_ranges[i]
             self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])

@@ -402,7 +423,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         transpose_value: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        cos_len: int = 1,
     ):
         super().__init__()
         self.op_parameters = parameters

@@ -410,6 +432,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         self.layer_idx = layer_idx
         self.max_seq_len = max_seq_len
         self.transpose_value = transpose_value
+        self.cached_cos = cached_cos
         # self.rotary_emb = rotary_emb
         if isinstance(parameters[0], tuple):  # weight, scale from QuantizedLinear
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8

@@ -433,7 +456,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            cos_len=cos_len,
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1

@@ -448,6 +472,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        cos=None,
+        sin=None,
         **kwargs,
     ) -> torch.Tensor:
         """Torch module forward method.

@@ -469,6 +495,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         inputs = (hidden_states.to(torch.float16),
                   attention_mask.to(torch.int64),
                   position_ids.to(torch.int64))
+        if self.cached_cos is None:
+            inputs += (cos.to(torch.float16), sin.to(torch.float16),)
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2

@@ -566,8 +594,12 @@ def run_decode(
                     scales.append(l.scale)
                 weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+        if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+        else:
+            cached_cos = None
+            cached_sin = None
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
         layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

|
|
@ -599,9 +631,13 @@ def run_decode(
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
|
output_attentions = False
|
||||||
|
|
||||||
control = torch.empty((), dtype=torch.int)
|
control = torch.empty((), dtype=torch.int)
|
||||||
hidden_states = torch.empty((1, 1, head_dim * num_heads), dtype=torch.float16)
|
hidden_states = torch.empty((1, 1, head_dim * num_heads), dtype=torch.float16)
|
||||||
|
if cached_cos is None:
|
||||||
|
cos = torch.zeros((1, 1, head_dim), dtype=torch.float16)
|
||||||
|
sin = torch.zeros((1, 1, head_dim), dtype=torch.float16)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
|
|
@@ -618,9 +654,15 @@ def run_decode(
                 )

                 position_ids = position_ids = cache_position.unsqueeze(0)
-                causal_mask = model.model._update_causal_mask(
-                    attention_mask, hidden_states, cache_position, past_seen_tokens
-                )
+                if cached_cos is None:
+                    causal_mask = model.model._update_causal_mask(
+                        attention_mask, hidden_states, cache_position,
+                        past_key_values, output_attentions
+                    )
+                else:
+                    causal_mask = model.model._update_causal_mask(
+                        attention_mask, hidden_states, cache_position, past_seen_tokens
+                    )
                 pad_len = multi_decoder.max_seq_len + 1 - causal_mask.size(-1)

                 pad_mask = (0, pad_len)

@@ -629,6 +671,9 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0
                 dist.recv(hidden_states, src=rank - 1)
+                if cached_cos is None:
+                    dist.recv(cos, src=rank - 1)
+                    dist.recv(sin, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,

@@ -637,9 +682,14 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                     cache_position=cache_position,
+                    cos=cos if cached_cos is None else None,
+                    sin=sin if cached_sin is None else None,
                 )
                 hidden_states = layer_outputs[0]
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
+                if cached_cos is None:
+                    dist.send(cos, dst=(rank + 1) % world_size)
+                    dist.send(sin, dst=(rank + 1) % world_size)
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]

@@ -717,6 +767,8 @@ class DecodeRunner:
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
         **kwargs,
     ):

@@ -727,9 +779,18 @@ class DecodeRunner:
             self.input_queues[i].put(past_key_value)
         dist.broadcast(self.forward_signal, src=0, async_op=True)
         hidden_states = hidden_states.to(torch.float16)
+        if cos is not None:
+            cos = cos.to(torch.float16)
+            sin = sin.to(torch.float16)
         dist.send(hidden_states, dst=1)
+        if cos is not None:
+            dist.send(cos, dst=1)
+            dist.send(sin, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
+        if cos is not None:
+            dist.recv(cos, src=self.world_size - 1)
+            dist.recv(sin, src=self.world_size - 1)
         return hidden_states, past_key_value

     def shutdown(self):

@@ -749,103 +810,113 @@ def run_prefill(
     model, max_output_len, max_prompt_len, transpose_value_cache, input_queue, result_queue
 ):

-    layer_start = 0
-    layer_end = len(model.model.layers)
-    num_heads = model.model.layers[layer_start].self_attn.num_heads
-    num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
-    head_dim = model.model.layers[layer_start].self_attn.head_dim
-    rms_norm_eps = model.config.rms_norm_eps
-    intermediate_size = model.config.intermediate_size
-    group_size = getattr(model.config, "group_size", 0)
-    deocderlayers = []
-    layer_weights = []
-    input_layer_norm_weights = []
-    post_attn_layernorm_weights = []
-    layer_indexs = range(layer_start, layer_end)
-    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
-    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
-    for layer_idx in layer_indexs:
-        curr_layer = model.model.layers[layer_idx]
-        attn_layer = curr_layer.self_attn
-        mlp_layer = curr_layer.mlp
-
-        weights = []
-
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
-            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
-                l_weights = []
-                scales = []
-                for l in layer_list:
-                    l_weights.append(l.weight)
-                    scales.append(l.scale)
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
-        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
-
-        layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
-        layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
-
-        new_decoderlayer = FusedLlamaLowBitDecoderlayer(
-            weights,
-            num_heads=num_heads,
-            num_key_value_heads=num_key_value_heads,
-            cached_cos=cached_cos,
-            cached_sin=cached_sin,
-            layer_norm_0=layer_norm_0,
-            layer_norm_1=layer_norm_1,
-            layer_idx=layer_idx,
-            rms_norm_eps=rms_norm_eps,
-            intermediate_size=intermediate_size,
-            max_seq_len=max_output_len,
-            transpose_value=transpose_value_cache,
-            n_splits_linear=n_splits_linear,
-            n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
-        )
-
-        layer_weights.extend(weights)
-        input_layer_norm_weights.append(layer_norm_0)
-        post_attn_layernorm_weights.append(layer_norm_1)
-        model.model.layers[layer_idx] = new_decoderlayer
-        deocderlayers.append(new_decoderlayer)
-
-    print("finish creating all decode layers in prefill")
-    result_queue.put("loading finish")
-
+    deocderlayers = None
     while True:
         result = input_queue.get()
         if result == "stop":
             break

-        hidden_states, position_ids, causal_mask, past_key_values, cache_position = result
+        hidden_states, position_ids, causal_mask, past_key_values, cache_position, cos, sin = result
+
+        if deocderlayers is None:
+            cos_len = cos.shape[1] if cos is not None else None
+            layer_start = 0
+            layer_end = len(model.model.layers)
+            num_heads = model.model.layers[layer_start].self_attn.num_heads
+            num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
+            head_dim = model.model.layers[layer_start].self_attn.head_dim
+            rms_norm_eps = model.config.rms_norm_eps
+            intermediate_size = model.config.intermediate_size
+            group_size = getattr(model.config, "group_size", 0)
+            deocderlayers = []
+            layer_weights = []
+            input_layer_norm_weights = []
+            post_attn_layernorm_weights = []
+            layer_indexs = range(layer_start, layer_end)
+            n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+            n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+            for layer_idx in layer_indexs:
+                curr_layer = model.model.layers[layer_idx]
+                attn_layer = curr_layer.self_attn
+                mlp_layer = curr_layer.mlp
+
+                weights = []
+
+                if n_splits_linear == 1:
+                    for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                                attn_layer.k_proj_dq_list,
+                                                attn_layer.v_proj_dq_list,
+                                                attn_layer.o_proj_dq_list,
+                                                mlp_layer.gate_proj_dq_list,
+                                                mlp_layer.up_proj_dq_list):
+                        weights.append((q.weight, q.scale))
+                        weights.append((k.weight, k.scale))
+                        weights.append((v.weight, v.scale))
+                        weights.append((o.weight, o.scale))
+                        weights.append((g.weight, g.scale))
+                        weights.append((u.weight, u.scale))
+                else:
+                    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                        l_weights = []
+                        scales = []
+                        for l in layer_list:
+                            l_weights.append(l.weight)
+                            scales.append(l.scale)
+                        weights.append((torch.stack(l_weights, axis=0),
+                                        torch.stack(scales, axis=0)))
+
+                if n_splits_down_proj == 1:
+                    for l in mlp_layer.down_proj_dq_list:
+                        weights.append((l.weight, l.scale))
+                else:
+                    l_weights = []
+                    scales = []
+                    for l in mlp_layer.down_proj_dq_list:
+                        l_weights.append(l.weight)
+                        scales.append(l.scale)
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+                if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+                    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+                    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+                else:
+                    cached_cos = None
+                    cached_sin = None
+
+                layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
+                layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
+
+                new_decoderlayer = FusedLlamaLowBitDecoderlayer(
+                    weights,
+                    num_heads=num_heads,
+                    num_key_value_heads=num_key_value_heads,
+                    cached_cos=cached_cos,
+                    cached_sin=cached_sin,
+                    layer_norm_0=layer_norm_0,
+                    layer_norm_1=layer_norm_1,
+                    layer_idx=layer_idx,
+                    rms_norm_eps=rms_norm_eps,
+                    intermediate_size=intermediate_size,
+                    max_seq_len=max_output_len,
+                    transpose_value=transpose_value_cache,
+                    n_splits_linear=n_splits_linear,
+                    n_splits_down_proj=n_splits_down_proj,
+                    group_size=group_size,
+                    cos_len=cos_len,
+                )
+
+                layer_weights.extend(weights)
+                input_layer_norm_weights.append(layer_norm_0)
+                post_attn_layernorm_weights.append(layer_norm_1)
+                model.model.layers[layer_idx] = new_decoderlayer
+                deocderlayers.append(new_decoderlayer)
+
+            print("finish creating all decode layers in prefill")
+            result_queue.put("loading finish")
+
         with torch.inference_mode():
             for decoder_layer in deocderlayers:
                 layer_outputs = decoder_layer(

@@ -856,6 +927,8 @@ def run_prefill(
                     output_attentions=False,
                     use_cache=True,
                     cache_position=cache_position,
+                    cos=cos,
+                    sin=sin,
                 )

                 hidden_states = layer_outputs[0]

@@ -887,9 +960,6 @@ class PrefillRunner:
         )
         self.p.daemon = True
         self.p.start()
-        output = self.prefill_result_queue.get()
-        print(Fore.GREEN + f"prefill process output: {output}")
-        print(Style.RESET_ALL)

     def forward(
         self,

@@ -900,6 +970,8 @@ class PrefillRunner:
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        cos=None,
+        sin=None,
         **kwargs,
     ):
         seq_len = hidden_states.size(1)

@@ -919,9 +991,16 @@ class PrefillRunner:
             value=torch.iinfo(torch.int64).min,
         )

-        args = (hidden_states, position_ids, attention_mask, past_key_value, cache_position)
+        args = (hidden_states, position_ids, attention_mask, past_key_value,
+                cache_position, cos, sin)
         self.prefill_input_queue.put(args)
-        hidden_states, past_key_value = self.prefill_result_queue.get()
+        output = self.prefill_result_queue.get()
+        if output == "loading finish":
+            hidden_states, past_key_value = self.prefill_result_queue.get()
+        else:
+            hidden_states, past_key_value = output

         past_key_value.shrink(seq_len, self.transpose_value_cache)
         hidden_states = hidden_states[:, :seq_len, :]
         return hidden_states, past_key_value

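Because the prefill worker now builds its decoder layers only after it has seen the first request (it needs `cos_len` from the incoming cos tensor), the "loading finish" message arrives on the result queue ahead of the first real result, and the caller skips over it as shown above. A self-contained sketch of that two-phase queue handshake, with trivial stand-in logic in place of the layer construction and prefill compute:

```python
import multiprocessing as mp

def worker(inp: mp.Queue, out: mp.Queue):
    ready = False
    while True:
        req = inp.get()
        if req == "stop":
            break
        if not ready:
            # deferred, shape-dependent setup would happen here
            ready = True
            out.put("loading finish")   # extra message, sent once, before the first result
        out.put(req * 2)                # stand-in for the real prefill output

if __name__ == "__main__":
    inp, out = mp.Queue(), mp.Queue()
    p = mp.Process(target=worker, args=(inp, out), daemon=True)
    p.start()
    inp.put(21)
    first = out.get()
    result = out.get() if first == "loading finish" else first   # mirrors PrefillRunner.forward
    print(result)                       # 42
    inp.put("stop")
    p.join()
```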
@@ -1051,6 +1130,125 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
     return llama_fused_model_forward


+def gen_llama_32_fused_model_forward(prefill_runner, decode_runner):
+
+    def llama_32_fused_model_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            msg = (
+                "You cannot specify both input_ids and inputs_embeds at the same time,"
+                " and must specify either one"
+            )
+            invalidInputError(False, msg)
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        # ipex-llm changes start
+        from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
+        if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
+            past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() \
+                if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device
+            )
+        # ipex-llm changes end
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        cos, sin = position_embeddings
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        seq_len = hidden_states.size(1)
+        if seq_len == 1:
+            layers_runner = decode_runner
+        else:
+            layers_runner = prefill_runner
+
+        layer_outputs = layers_runner.forward(
+            hidden_states,
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            cos=cos,
+            sin=sin,
+        )
+
+        hidden_states = layer_outputs[0]
+        next_decoder_cache = layer_outputs[1]
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        # ipex-llm changes start
+        next_cache = next_decoder_cache if use_cache else None
+        # ipex-llm changes end
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    return llama_32_fused_model_forward
+
+
 def llama2_casullm_forward(
     self,
     input_ids: torch.LongTensor = None,

@@ -498,11 +498,12 @@ class LLMBaseNNFactory(NNFactory):

     def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
                              num_heads, seq_len, head_dim):
-        position_ids = self.squeeze(position_ids)
-        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
-        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
-        cos = self.unsqueeze(cos, [1])
-        sin = self.unsqueeze(sin, [1])
+        if position_ids is not None:
+            position_ids = self.squeeze(position_ids)
+            cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
+            sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
+            cos = self.unsqueeze(cos, [1])
+            sin = self.unsqueeze(sin, [1])

         rotate_half_q = self.rotate_half(q,
                                          num_heads=num_heads,