[NPU] Add Optimized Support for Llama3.2-1B/3B on NPU (#12339)
* Add initial support for llama3.2-1b/3b * move llama3.2 support into current llama_mp impl
This commit is contained in:
parent
872a74481a
commit
a7b66683f1
6 changed files with 360 additions and 127 deletions
|
|
@ -7,6 +7,8 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
|
|||
|------------|----------------------------------------------------------------|
|
||||
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
|
||||
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
|
||||
| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) |
|
||||
| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
|
||||
| Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
|
||||
| Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
|
||||
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
|
||||
|
|
@ -33,6 +35,9 @@ conda activate llm
|
|||
|
||||
:: install ipex-llm with 'npu' option
|
||||
pip install --pre --upgrade ipex-llm[npu]
|
||||
|
||||
:: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
|
||||
pip install transformers==4.45.0 accelerate==0.33.0
|
||||
```
|
||||
|
||||
## 2. Runtime Configurations
|
||||
|
|
@ -82,6 +87,8 @@ done
|
|||
The examples below show how to run the **_optimized HuggingFace model implementations_** on Intel NPU, including
|
||||
- [Llama2-7B](./llama.py)
|
||||
- [Llama3-8B](./llama.py)
|
||||
- [Llama3.2-1B](./llama.py)
|
||||
- [Llama3.2-3B](./llama.py)
|
||||
- [Qwen2-1.5B](./qwen.py)
|
||||
- [Qwen2.5-7B](./qwen.py)
|
||||
- [MiniCPM-1B](./minicpm.py)
|
||||
|
|
@ -106,6 +113,12 @@ python llama.py
|
|||
:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
|
||||
python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
:: to run Llama-3.2-1B-Instruct
|
||||
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct
|
||||
|
||||
:: to run Llama-3.2-3B-Instruct
|
||||
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct
|
||||
|
||||
:: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
|
||||
python qwen.py
|
||||
|
||||
|
|
@ -145,6 +158,12 @@ python llama.py --disable-transpose-value-cache
|
|||
:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
|
||||
python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --disable-transpose-value-cache
|
||||
|
||||
:: to run Llama-3.2-1B-Instruct
|
||||
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --disable-transpose-value-cache
|
||||
|
||||
:: to run Llama-3.2-3B-Instruct
|
||||
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --disable-transpose-value-cache
|
||||
|
||||
:: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
|
||||
python qwen.py --disable-transpose-value-cache
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ This folder contains examples of running IPEX-LLM on Intel NPU:
|
|||
|------------|----------------------------------------------------------------|
|
||||
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
|
||||
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
|
||||
| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) |
|
||||
| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
|
||||
| Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
|
||||
| Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
|
||||
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
|
||||
|
|
|
|||
|
|
@ -173,7 +173,8 @@ def convert_llama(
|
|||
intra_pp=None,
|
||||
transpose_value_cache=True,
|
||||
):
|
||||
from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
|
||||
from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward,\
|
||||
gen_llama_32_fused_model_forward
|
||||
from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
|
||||
from transformers.models.llama.modeling_llama import LlamaModel
|
||||
|
||||
|
|
@ -193,6 +194,15 @@ def convert_llama(
|
|||
max_prompt_len=max_prompt_len,
|
||||
transpose_value_cache=transpose_value_cache,
|
||||
)
|
||||
from packaging import version
|
||||
import transformers
|
||||
trans_version = transformers.__version__
|
||||
if version.parse(trans_version) == version.parse("4.45.0"):
|
||||
# llama-3.2-3B & llama-3.2-1B
|
||||
llama_model_forward = gen_llama_32_fused_model_forward(
|
||||
prefill_runner=prefill_runner, decode_runner=decode_runner
|
||||
)
|
||||
else:
|
||||
llama_model_forward = gen_llama_fused_model_forward(
|
||||
prefill_runner=prefill_runner, decode_runner=decode_runner
|
||||
)
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class DynamicFusedNormalCache(DynamicCache):
|
|||
# Experimental support for fused decoderlayer implementation on NPU
|
||||
# Currently only for llama2
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
|
||||
self.key_cache: Dict[int, torch.Tensor] = {}
|
||||
self.value_cache: Dict[int, torch.Tensor] = {}
|
||||
self.min_layer_idx = sys.maxsize
|
||||
|
|
@ -158,6 +158,9 @@ class DynamicFusedNormalCache(DynamicCache):
|
|||
cache_kwargs: Optional[Dict[str, Any]]=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
if key_states == []:
|
||||
return key_states, value_states
|
||||
|
||||
batch_size, num_heads, seq_len, head_dim = key_states.shape
|
||||
|
||||
max_seq_length = cache_kwargs["max_seq_len"] if "max_seq_len" in cache_kwargs else None
|
||||
|
|
|
|||
|
|
@ -69,7 +69,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
|
|||
intermediate_size,
|
||||
n_splits_linear: int = 1,
|
||||
n_splits_down_proj: int = 1,
|
||||
group_size: int = 0
|
||||
group_size: int = 0,
|
||||
cos_len: int = 1,
|
||||
):
|
||||
super().__init__(max_seq_len=max_seq_len,
|
||||
transpose_value=transpose_value,
|
||||
|
|
@ -84,18 +85,13 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
|
|||
self.dtype = dtype
|
||||
self.cached_cos = cached_cos
|
||||
self.cached_sin = cached_sin
|
||||
self.cos_len = cos_len
|
||||
self.batch_size, self.seq_len, self.hidden_size = hidden_shape
|
||||
self.mode = mode
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.transpose_value = transpose_value
|
||||
self.num_layers = num_layers
|
||||
|
||||
cos = self.constant(self.cached_cos)
|
||||
self.cos = self.unsqueeze(cos, axis=0)
|
||||
|
||||
sin = self.constant(self.cached_sin)
|
||||
self.sin = self.unsqueeze(sin, axis=0)
|
||||
|
||||
if mode == "decode":
|
||||
self.kv_seq_len = self.max_seq_len + 1
|
||||
else:
|
||||
|
|
@ -111,6 +107,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
|
|||
input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
|
||||
|
||||
# llama2/3 use ov sdp, other models need to test
|
||||
|
||||
use_prefill_sdp = self.intermediate_size in [11008, 14336]
|
||||
|
||||
# Self Attention
|
||||
|
|
@ -124,8 +121,20 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
|
|||
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
|
||||
self.seq_len),
|
||||
dtype=np.int64)
|
||||
|
||||
if self.cached_cos is None:
|
||||
if mode == "prefill":
|
||||
position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
|
||||
self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
|
||||
self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
|
||||
else:
|
||||
self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
|
||||
self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
|
||||
else:
|
||||
position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
|
||||
cos = self.constant(self.cached_cos)
|
||||
self.cos = self.unsqueeze(cos, axis=0)
|
||||
sin = self.constant(self.cached_sin)
|
||||
self.sin = self.unsqueeze(sin, axis=0)
|
||||
|
||||
if input_layernorm_weights is None:
|
||||
input_layernorm_weights = []
|
||||
|
|
@ -179,12 +188,15 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
|
|||
hidden_states, new_key_states, new_value_states = self.build_decoder(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
position_ids=position_ids if (cached_cos is not None
|
||||
or mode == "prefill") else None,
|
||||
input_layernorm_weight=input_layernorm_weights[i],
|
||||
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
|
||||
past_key=past_keys[i],
|
||||
past_value=past_values[i],
|
||||
use_prefill_sdp=use_prefill_sdp,
|
||||
cos=self.cos,
|
||||
sin=self.sin,
|
||||
)
|
||||
curr_key_values.append((new_key_states, new_value_states))
|
||||
|
||||
|
|
@ -205,12 +217,14 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
|
|||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
input_layernorm_weight,
|
||||
post_attention_layernorm_weight,
|
||||
position_ids=None,
|
||||
past_key=None,
|
||||
past_value=None,
|
||||
use_prefill_sdp=False,
|
||||
cos=None,
|
||||
sin=None,
|
||||
):
|
||||
|
||||
residual = hidden_states
|
||||
|
|
@ -222,8 +236,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
|
|||
attention_mask=attention_mask,
|
||||
past_key=past_key,
|
||||
past_value=past_value,
|
||||
cos=self.cos,
|
||||
sin=self.sin,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
mode=self.mode,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
|
|
@ -282,6 +296,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
|
|||
self.op_id = str(uuid.uuid4())
|
||||
self.max_seq_len = max_seq_len
|
||||
self.transpose_value = transpose_value
|
||||
self.cached_cos = cached_cos
|
||||
if isinstance(parameters[0], tuple):
|
||||
np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
|
||||
elif parameters[0].dtype == torch.int8:
|
||||
|
|
@ -341,15 +356,21 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
|
|||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
cos: Optional[torch.Tensor] = None,
|
||||
sin: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
inputs = (
|
||||
hidden_states.to(torch.float16),
|
||||
attention_mask.to(torch.int64),
|
||||
position_ids.to(torch.int64),
|
||||
)
|
||||
|
||||
if self.cached_cos is None:
|
||||
inputs += (cos.to(torch.float16), sin.to(torch.float16))
|
||||
else:
|
||||
inputs += (position_ids.to(torch.int64),)
|
||||
|
||||
for i in range(self.intra_stages):
|
||||
start, end = self.layer_ranges[i]
|
||||
self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])
|
||||
|
|
@ -402,7 +423,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
|
|||
transpose_value: bool = False,
|
||||
n_splits_linear: int = 1,
|
||||
n_splits_down_proj: int = 1,
|
||||
group_size: int = 0
|
||||
group_size: int = 0,
|
||||
cos_len: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.op_parameters = parameters
|
||||
|
|
@ -410,6 +432,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
|
|||
self.layer_idx = layer_idx
|
||||
self.max_seq_len = max_seq_len
|
||||
self.transpose_value = transpose_value
|
||||
self.cached_cos = cached_cos
|
||||
# self.rotary_emb = rotary_emb
|
||||
if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear
|
||||
np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
|
||||
|
|
@ -433,7 +456,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
|
|||
dtype=np_dtype,
|
||||
n_splits_linear=n_splits_linear,
|
||||
n_splits_down_proj=n_splits_down_proj,
|
||||
group_size=group_size
|
||||
group_size=group_size,
|
||||
cos_len=cos_len,
|
||||
)
|
||||
self.layer_norm_0 = layer_norm_0
|
||||
self.layer_norm_1 = layer_norm_1
|
||||
|
|
@ -448,6 +472,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
|
|||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
cos=None,
|
||||
sin=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Torch module forward method.
|
||||
|
|
@ -469,6 +495,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
|
|||
inputs = (hidden_states.to(torch.float16),
|
||||
attention_mask.to(torch.int64),
|
||||
position_ids.to(torch.int64))
|
||||
if self.cached_cos is None:
|
||||
inputs += (cos.to(torch.float16), sin.to(torch.float16),)
|
||||
inputs += (self.layer_norm_0, self.layer_norm_1)
|
||||
hidden_states, past_key, past_value = run_model(
|
||||
inputs, self.op_parameters, backend_cls, self.op_id, replica=2
|
||||
|
|
@ -566,8 +594,12 @@ def run_decode(
|
|||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
else:
|
||||
cached_cos = None
|
||||
cached_sin = None
|
||||
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
|
||||
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
|
||||
|
||||
|
|
@ -599,9 +631,13 @@ def run_decode(
|
|||
dist.barrier()
|
||||
|
||||
past_key_values = None
|
||||
output_attentions = False
|
||||
|
||||
control = torch.empty((), dtype=torch.int)
|
||||
hidden_states = torch.empty((1, 1, head_dim * num_heads), dtype=torch.float16)
|
||||
if cached_cos is None:
|
||||
cos = torch.zeros((1, 1, head_dim), dtype=torch.float16)
|
||||
sin = torch.zeros((1, 1, head_dim), dtype=torch.float16)
|
||||
with torch.inference_mode():
|
||||
while True:
|
||||
|
||||
|
|
@ -618,6 +654,12 @@ def run_decode(
|
|||
)
|
||||
|
||||
position_ids = position_ids = cache_position.unsqueeze(0)
|
||||
if cached_cos is None:
|
||||
causal_mask = model.model._update_causal_mask(
|
||||
attention_mask, hidden_states, cache_position,
|
||||
past_key_values, output_attentions
|
||||
)
|
||||
else:
|
||||
causal_mask = model.model._update_causal_mask(
|
||||
attention_mask, hidden_states, cache_position, past_seen_tokens
|
||||
)
|
||||
|
|
@ -629,6 +671,9 @@ def run_decode(
|
|||
)
|
||||
padded_causal_mask[:, :, :, -1] = 0
|
||||
dist.recv(hidden_states, src=rank - 1)
|
||||
if cached_cos is None:
|
||||
dist.recv(cos, src=rank - 1)
|
||||
dist.recv(sin, src=rank - 1)
|
||||
layer_outputs = multi_decoder(
|
||||
hidden_states,
|
||||
attention_mask=padded_causal_mask,
|
||||
|
|
@ -637,9 +682,14 @@ def run_decode(
|
|||
output_attentions=False,
|
||||
use_cache=True,
|
||||
cache_position=cache_position,
|
||||
cos=cos if cached_cos is None else None,
|
||||
sin=sin if cached_sin is None else None,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
dist.send(hidden_states, dst=(rank + 1) % world_size)
|
||||
if cached_cos is None:
|
||||
dist.send(cos, dst=(rank + 1) % world_size)
|
||||
dist.send(sin, dst=(rank + 1) % world_size)
|
||||
past_key_values = layer_outputs[1]
|
||||
new_keys = layer_outputs[2]
|
||||
new_values = layer_outputs[3]
|
||||
|
|
@ -717,6 +767,8 @@ class DecodeRunner:
|
|||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
cos: Optional[torch.Tensor] = None,
|
||||
sin: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
|
|
@ -727,9 +779,18 @@ class DecodeRunner:
|
|||
self.input_queues[i].put(past_key_value)
|
||||
dist.broadcast(self.forward_signal, src=0, async_op=True)
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
if cos is not None:
|
||||
cos = cos.to(torch.float16)
|
||||
sin = sin.to(torch.float16)
|
||||
dist.send(hidden_states, dst=1)
|
||||
if cos is not None:
|
||||
dist.send(cos, dst=1)
|
||||
dist.send(sin, dst=1)
|
||||
past_key_value.expand(self.transpose_value_cache)
|
||||
dist.recv(hidden_states, src=self.world_size - 1)
|
||||
if cos is not None:
|
||||
dist.recv(cos, src=self.world_size - 1)
|
||||
dist.recv(sin, src=self.world_size - 1)
|
||||
return hidden_states, past_key_value
|
||||
|
||||
def shutdown(self):
|
||||
|
|
@ -749,6 +810,17 @@ def run_prefill(
|
|||
model, max_output_len, max_prompt_len, transpose_value_cache, input_queue, result_queue
|
||||
):
|
||||
|
||||
deocderlayers = None
|
||||
|
||||
while True:
|
||||
result = input_queue.get()
|
||||
if result == "stop":
|
||||
break
|
||||
|
||||
hidden_states, position_ids, causal_mask, past_key_values, cache_position, cos, sin = result
|
||||
|
||||
if deocderlayers is None:
|
||||
cos_len = cos.shape[1] if cos is not None else None
|
||||
layer_start = 0
|
||||
layer_end = len(model.model.layers)
|
||||
num_heads = model.model.layers[layer_start].self_attn.num_heads
|
||||
|
|
@ -793,7 +865,8 @@ def run_prefill(
|
|||
for l in layer_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
weights.append((torch.stack(l_weights, axis=0),
|
||||
torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
|
|
@ -806,8 +879,12 @@ def run_prefill(
|
|||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
else:
|
||||
cached_cos = None
|
||||
cached_sin = None
|
||||
|
||||
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
|
||||
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
|
||||
|
|
@ -827,7 +904,8 @@ def run_prefill(
|
|||
transpose_value=transpose_value_cache,
|
||||
n_splits_linear=n_splits_linear,
|
||||
n_splits_down_proj=n_splits_down_proj,
|
||||
group_size=group_size
|
||||
group_size=group_size,
|
||||
cos_len=cos_len,
|
||||
)
|
||||
|
||||
layer_weights.extend(weights)
|
||||
|
|
@ -839,13 +917,6 @@ def run_prefill(
|
|||
print("finish creating all decode layers in prefill")
|
||||
result_queue.put("loading finish")
|
||||
|
||||
while True:
|
||||
|
||||
result = input_queue.get()
|
||||
if result == "stop":
|
||||
break
|
||||
|
||||
hidden_states, position_ids, causal_mask, past_key_values, cache_position = result
|
||||
with torch.inference_mode():
|
||||
for decoder_layer in deocderlayers:
|
||||
layer_outputs = decoder_layer(
|
||||
|
|
@ -856,6 +927,8 @@ def run_prefill(
|
|||
output_attentions=False,
|
||||
use_cache=True,
|
||||
cache_position=cache_position,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
|
@ -887,9 +960,6 @@ class PrefillRunner:
|
|||
)
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
output = self.prefill_result_queue.get()
|
||||
print(Fore.GREEN + f"prefill process output: {output}")
|
||||
print(Style.RESET_ALL)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -900,6 +970,8 @@ class PrefillRunner:
|
|||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
cos=None,
|
||||
sin=None,
|
||||
**kwargs,
|
||||
):
|
||||
seq_len = hidden_states.size(1)
|
||||
|
|
@ -919,9 +991,16 @@ class PrefillRunner:
|
|||
value=torch.iinfo(torch.int64).min,
|
||||
)
|
||||
|
||||
args = (hidden_states, position_ids, attention_mask, past_key_value, cache_position)
|
||||
args = (hidden_states, position_ids, attention_mask, past_key_value,
|
||||
cache_position, cos, sin)
|
||||
self.prefill_input_queue.put(args)
|
||||
|
||||
output = self.prefill_result_queue.get()
|
||||
if output == "loading finish":
|
||||
hidden_states, past_key_value = self.prefill_result_queue.get()
|
||||
else:
|
||||
hidden_states, past_key_value = output
|
||||
|
||||
past_key_value.shrink(seq_len, self.transpose_value_cache)
|
||||
hidden_states = hidden_states[:, :seq_len, :]
|
||||
return hidden_states, past_key_value
|
||||
|
|
@ -1051,6 +1130,125 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
|
|||
return llama_fused_model_forward
|
||||
|
||||
|
||||
def gen_llama_32_fused_model_forward(prefill_runner, decode_runner):
|
||||
|
||||
def llama_32_fused_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
msg = (
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time,"
|
||||
" and must specify either one"
|
||||
)
|
||||
invalidInputError(False, msg)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# ipex-llm changes start
|
||||
from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
|
||||
if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
|
||||
past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() \
|
||||
if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device
|
||||
)
|
||||
# ipex-llm changes end
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
cos, sin = position_embeddings
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
seq_len = hidden_states.size(1)
|
||||
if seq_len == 1:
|
||||
layers_runner = decode_runner
|
||||
else:
|
||||
layers_runner = prefill_runner
|
||||
|
||||
layer_outputs = layers_runner.forward(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
next_decoder_cache = layer_outputs[1]
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
# ipex-llm changes start
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
# ipex-llm changes end
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
return llama_32_fused_model_forward
|
||||
|
||||
|
||||
def llama2_casullm_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
|
|
|||
|
|
@ -498,6 +498,7 @@ class LLMBaseNNFactory(NNFactory):
|
|||
|
||||
def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
|
||||
num_heads, seq_len, head_dim):
|
||||
if position_ids is not None:
|
||||
position_ids = self.squeeze(position_ids)
|
||||
cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
|
||||
sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
|
||||
|
|
|
|||
Loading…
Reference in a new issue