[NPU] Add Optimized Support for Llama3.2-1B/3B on NPU (#12339)

* Add initial support for llama3.2-1b/3b

* move llama3.2 support into current llama_mp impl
This commit is contained in:
SONG Ge 2024-11-06 19:21:40 +08:00 committed by GitHub
parent 872a74481a
commit a7b66683f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 360 additions and 127 deletions

View file

@ -7,6 +7,8 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
|------------|----------------------------------------------------------------|
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) |
| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
| Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
| Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
@ -33,6 +35,9 @@ conda activate llm
:: install ipex-llm with 'npu' option
pip install --pre --upgrade ipex-llm[npu]
:: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
pip install transformers==4.45.0 accelerate==0.33.0
```
## 2. Runtime Configurations
@ -82,6 +87,8 @@ done
The examples below show how to run the **_optimized HuggingFace model implementations_** on Intel NPU, including
- [Llama2-7B](./llama.py)
- [Llama3-8B](./llama.py)
- [Llama3.2-1B](./llama.py)
- [Llama3.2-3B](./llama.py)
- [Qwen2-1.5B](./qwen.py)
- [Qwen2.5-7B](./qwen.py)
- [MiniCPM-1B](./minicpm.py)
@ -106,6 +113,12 @@ python llama.py
:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct
:: to run Llama-3.2-1B-Instruct
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct
:: to run Llama-3.2-3B-Instruct
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct
:: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
python qwen.py
@ -145,6 +158,12 @@ python llama.py --disable-transpose-value-cache
:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --disable-transpose-value-cache
:: to run Llama-3.2-1B-Instruct
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --disable-transpose-value-cache
:: to run Llama-3.2-3B-Instruct
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --disable-transpose-value-cache
:: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
python qwen.py --disable-transpose-value-cache

View file

@ -10,6 +10,8 @@ This folder contains examples of running IPEX-LLM on Intel NPU:
|------------|----------------------------------------------------------------|
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) |
| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
| Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
| Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |

View file

@ -173,7 +173,8 @@ def convert_llama(
intra_pp=None,
transpose_value_cache=True,
):
from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward,\
gen_llama_32_fused_model_forward
from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
from transformers.models.llama.modeling_llama import LlamaModel
@ -193,9 +194,18 @@ def convert_llama(
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
)
llama_model_forward = gen_llama_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
from packaging import version
import transformers
trans_version = transformers.__version__
if version.parse(trans_version) == version.parse("4.45.0"):
# llama-3.2-3B & llama-3.2-1B
llama_model_forward = gen_llama_32_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
else:
llama_model_forward = gen_llama_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, LlamaModel, llama_model_forward)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward

View file

@ -145,7 +145,7 @@ class DynamicFusedNormalCache(DynamicCache):
# Experimental support for fused decoderlayer implementation on NPU
# Currently only for llama2
def __init__(self) -> None:
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
self.key_cache: Dict[int, torch.Tensor] = {}
self.value_cache: Dict[int, torch.Tensor] = {}
self.min_layer_idx = sys.maxsize
@ -158,6 +158,9 @@ class DynamicFusedNormalCache(DynamicCache):
cache_kwargs: Optional[Dict[str, Any]]=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if key_states == []:
return key_states, value_states
batch_size, num_heads, seq_len, head_dim = key_states.shape
max_seq_length = cache_kwargs["max_seq_len"] if "max_seq_len" in cache_kwargs else None

View file

@ -69,7 +69,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
intermediate_size,
n_splits_linear: int = 1,
n_splits_down_proj: int = 1,
group_size: int = 0
group_size: int = 0,
cos_len: int = 1,
):
super().__init__(max_seq_len=max_seq_len,
transpose_value=transpose_value,
@ -84,18 +85,13 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
self.dtype = dtype
self.cached_cos = cached_cos
self.cached_sin = cached_sin
self.cos_len = cos_len
self.batch_size, self.seq_len, self.hidden_size = hidden_shape
self.mode = mode
self.rms_norm_eps = rms_norm_eps
self.transpose_value = transpose_value
self.num_layers = num_layers
cos = self.constant(self.cached_cos)
self.cos = self.unsqueeze(cos, axis=0)
sin = self.constant(self.cached_sin)
self.sin = self.unsqueeze(sin, axis=0)
if mode == "decode":
self.kv_seq_len = self.max_seq_len + 1
else:
@ -111,6 +107,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
# llama2/3 use ov sdp, other models need to test
use_prefill_sdp = self.intermediate_size in [11008, 14336]
# Self Attention
@ -124,8 +121,20 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
self.seq_len),
dtype=np.int64)
position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
if self.cached_cos is None:
if mode == "prefill":
position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
else:
self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
else:
position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
cos = self.constant(self.cached_cos)
self.cos = self.unsqueeze(cos, axis=0)
sin = self.constant(self.cached_sin)
self.sin = self.unsqueeze(sin, axis=0)
if input_layernorm_weights is None:
input_layernorm_weights = []
@ -179,12 +188,15 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
hidden_states, new_key_states, new_value_states = self.build_decoder(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_ids=position_ids if (cached_cos is not None
or mode == "prefill") else None,
input_layernorm_weight=input_layernorm_weights[i],
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
past_key=past_keys[i],
past_value=past_values[i],
use_prefill_sdp=use_prefill_sdp,
cos=self.cos,
sin=self.sin,
)
curr_key_values.append((new_key_states, new_value_states))
@ -205,12 +217,14 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
self,
hidden_states,
attention_mask,
position_ids,
input_layernorm_weight,
post_attention_layernorm_weight,
position_ids=None,
past_key=None,
past_value=None,
use_prefill_sdp=False,
cos=None,
sin=None,
):
residual = hidden_states
@ -222,8 +236,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
attention_mask=attention_mask,
past_key=past_key,
past_value=past_value,
cos=self.cos,
sin=self.sin,
cos=cos,
sin=sin,
mode=self.mode,
num_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
@ -282,6 +296,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
self.op_id = str(uuid.uuid4())
self.max_seq_len = max_seq_len
self.transpose_value = transpose_value
self.cached_cos = cached_cos
if isinstance(parameters[0], tuple):
np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
elif parameters[0].dtype == torch.int8:
@ -341,15 +356,21 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cos: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
inputs = (
hidden_states.to(torch.float16),
attention_mask.to(torch.int64),
position_ids.to(torch.int64),
)
if self.cached_cos is None:
inputs += (cos.to(torch.float16), sin.to(torch.float16))
else:
inputs += (position_ids.to(torch.int64),)
for i in range(self.intra_stages):
start, end = self.layer_ranges[i]
self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])
@ -402,7 +423,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
transpose_value: bool = False,
n_splits_linear: int = 1,
n_splits_down_proj: int = 1,
group_size: int = 0
group_size: int = 0,
cos_len: int = 1,
):
super().__init__()
self.op_parameters = parameters
@ -410,6 +432,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
self.layer_idx = layer_idx
self.max_seq_len = max_seq_len
self.transpose_value = transpose_value
self.cached_cos = cached_cos
# self.rotary_emb = rotary_emb
if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear
np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
@ -433,7 +456,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
dtype=np_dtype,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size
group_size=group_size,
cos_len=cos_len,
)
self.layer_norm_0 = layer_norm_0
self.layer_norm_1 = layer_norm_1
@ -448,6 +472,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cos=None,
sin=None,
**kwargs,
) -> torch.Tensor:
"""Torch module forward method.
@ -469,6 +495,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
inputs = (hidden_states.to(torch.float16),
attention_mask.to(torch.int64),
position_ids.to(torch.int64))
if self.cached_cos is None:
inputs += (cos.to(torch.float16), sin.to(torch.float16),)
inputs += (self.layer_norm_0, self.layer_norm_1)
hidden_states, past_key, past_value = run_model(
inputs, self.op_parameters, backend_cls, self.op_id, replica=2
@ -566,8 +594,12 @@ def run_decode(
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
else:
cached_cos = None
cached_sin = None
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
@ -599,9 +631,13 @@ def run_decode(
dist.barrier()
past_key_values = None
output_attentions = False
control = torch.empty((), dtype=torch.int)
hidden_states = torch.empty((1, 1, head_dim * num_heads), dtype=torch.float16)
if cached_cos is None:
cos = torch.zeros((1, 1, head_dim), dtype=torch.float16)
sin = torch.zeros((1, 1, head_dim), dtype=torch.float16)
with torch.inference_mode():
while True:
@ -618,9 +654,15 @@ def run_decode(
)
position_ids = position_ids = cache_position.unsqueeze(0)
causal_mask = model.model._update_causal_mask(
attention_mask, hidden_states, cache_position, past_seen_tokens
)
if cached_cos is None:
causal_mask = model.model._update_causal_mask(
attention_mask, hidden_states, cache_position,
past_key_values, output_attentions
)
else:
causal_mask = model.model._update_causal_mask(
attention_mask, hidden_states, cache_position, past_seen_tokens
)
pad_len = multi_decoder.max_seq_len + 1 - causal_mask.size(-1)
pad_mask = (0, pad_len)
@ -629,6 +671,9 @@ def run_decode(
)
padded_causal_mask[:, :, :, -1] = 0
dist.recv(hidden_states, src=rank - 1)
if cached_cos is None:
dist.recv(cos, src=rank - 1)
dist.recv(sin, src=rank - 1)
layer_outputs = multi_decoder(
hidden_states,
attention_mask=padded_causal_mask,
@ -637,9 +682,14 @@ def run_decode(
output_attentions=False,
use_cache=True,
cache_position=cache_position,
cos=cos if cached_cos is None else None,
sin=sin if cached_sin is None else None,
)
hidden_states = layer_outputs[0]
dist.send(hidden_states, dst=(rank + 1) % world_size)
if cached_cos is None:
dist.send(cos, dst=(rank + 1) % world_size)
dist.send(sin, dst=(rank + 1) % world_size)
past_key_values = layer_outputs[1]
new_keys = layer_outputs[2]
new_values = layer_outputs[3]
@ -717,6 +767,8 @@ class DecodeRunner:
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cos: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
**kwargs,
):
@ -727,9 +779,18 @@ class DecodeRunner:
self.input_queues[i].put(past_key_value)
dist.broadcast(self.forward_signal, src=0, async_op=True)
hidden_states = hidden_states.to(torch.float16)
if cos is not None:
cos = cos.to(torch.float16)
sin = sin.to(torch.float16)
dist.send(hidden_states, dst=1)
if cos is not None:
dist.send(cos, dst=1)
dist.send(sin, dst=1)
past_key_value.expand(self.transpose_value_cache)
dist.recv(hidden_states, src=self.world_size - 1)
if cos is not None:
dist.recv(cos, src=self.world_size - 1)
dist.recv(sin, src=self.world_size - 1)
return hidden_states, past_key_value
def shutdown(self):
@ -749,103 +810,113 @@ def run_prefill(
model, max_output_len, max_prompt_len, transpose_value_cache, input_queue, result_queue
):
layer_start = 0
layer_end = len(model.model.layers)
num_heads = model.model.layers[layer_start].self_attn.num_heads
num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
head_dim = model.model.layers[layer_start].self_attn.head_dim
rms_norm_eps = model.config.rms_norm_eps
intermediate_size = model.config.intermediate_size
group_size = getattr(model.config, "group_size", 0)
deocderlayers = []
layer_weights = []
input_layer_norm_weights = []
post_attn_layernorm_weights = []
layer_indexs = range(layer_start, layer_end)
n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
for layer_idx in layer_indexs:
curr_layer = model.model.layers[layer_idx]
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp
weights = []
if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
new_decoderlayer = FusedLlamaLowBitDecoderlayer(
weights,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
cached_cos=cached_cos,
cached_sin=cached_sin,
layer_norm_0=layer_norm_0,
layer_norm_1=layer_norm_1,
layer_idx=layer_idx,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
max_seq_len=max_output_len,
transpose_value=transpose_value_cache,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size
)
layer_weights.extend(weights)
input_layer_norm_weights.append(layer_norm_0)
post_attn_layernorm_weights.append(layer_norm_1)
model.model.layers[layer_idx] = new_decoderlayer
deocderlayers.append(new_decoderlayer)
print("finish creating all decode layers in prefill")
result_queue.put("loading finish")
deocderlayers = None
while True:
result = input_queue.get()
if result == "stop":
break
hidden_states, position_ids, causal_mask, past_key_values, cache_position = result
hidden_states, position_ids, causal_mask, past_key_values, cache_position, cos, sin = result
if deocderlayers is None:
cos_len = cos.shape[1] if cos is not None else None
layer_start = 0
layer_end = len(model.model.layers)
num_heads = model.model.layers[layer_start].self_attn.num_heads
num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
head_dim = model.model.layers[layer_start].self_attn.head_dim
rms_norm_eps = model.config.rms_norm_eps
intermediate_size = model.config.intermediate_size
group_size = getattr(model.config, "group_size", 0)
deocderlayers = []
layer_weights = []
input_layer_norm_weights = []
post_attn_layernorm_weights = []
layer_indexs = range(layer_start, layer_end)
n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
for layer_idx in layer_indexs:
curr_layer = model.model.layers[layer_idx]
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp
weights = []
if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0),
torch.stack(scales, axis=0)))
if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
else:
cached_cos = None
cached_sin = None
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
new_decoderlayer = FusedLlamaLowBitDecoderlayer(
weights,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
cached_cos=cached_cos,
cached_sin=cached_sin,
layer_norm_0=layer_norm_0,
layer_norm_1=layer_norm_1,
layer_idx=layer_idx,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
max_seq_len=max_output_len,
transpose_value=transpose_value_cache,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size,
cos_len=cos_len,
)
layer_weights.extend(weights)
input_layer_norm_weights.append(layer_norm_0)
post_attn_layernorm_weights.append(layer_norm_1)
model.model.layers[layer_idx] = new_decoderlayer
deocderlayers.append(new_decoderlayer)
print("finish creating all decode layers in prefill")
result_queue.put("loading finish")
with torch.inference_mode():
for decoder_layer in deocderlayers:
layer_outputs = decoder_layer(
@ -856,6 +927,8 @@ def run_prefill(
output_attentions=False,
use_cache=True,
cache_position=cache_position,
cos=cos,
sin=sin,
)
hidden_states = layer_outputs[0]
@ -887,9 +960,6 @@ class PrefillRunner:
)
self.p.daemon = True
self.p.start()
output = self.prefill_result_queue.get()
print(Fore.GREEN + f"prefill process output: {output}")
print(Style.RESET_ALL)
def forward(
self,
@ -900,6 +970,8 @@ class PrefillRunner:
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cos=None,
sin=None,
**kwargs,
):
seq_len = hidden_states.size(1)
@ -919,9 +991,16 @@ class PrefillRunner:
value=torch.iinfo(torch.int64).min,
)
args = (hidden_states, position_ids, attention_mask, past_key_value, cache_position)
args = (hidden_states, position_ids, attention_mask, past_key_value,
cache_position, cos, sin)
self.prefill_input_queue.put(args)
hidden_states, past_key_value = self.prefill_result_queue.get()
output = self.prefill_result_queue.get()
if output == "loading finish":
hidden_states, past_key_value = self.prefill_result_queue.get()
else:
hidden_states, past_key_value = output
past_key_value.shrink(seq_len, self.transpose_value_cache)
hidden_states = hidden_states[:, :seq_len, :]
return hidden_states, past_key_value
@ -1051,6 +1130,125 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
return llama_fused_model_forward
def gen_llama_32_fused_model_forward(prefill_runner, decode_runner):
def llama_32_fused_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
msg = (
"You cannot specify both input_ids and inputs_embeds at the same time,"
" and must specify either one"
)
invalidInputError(False, msg)
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# ipex-llm changes start
from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() \
if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device
)
# ipex-llm changes end
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
cos, sin = position_embeddings
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
seq_len = hidden_states.size(1)
if seq_len == 1:
layers_runner = decode_runner
else:
layers_runner = prefill_runner
layer_outputs = layers_runner.forward(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
cos=cos,
sin=sin,
)
hidden_states = layer_outputs[0]
next_decoder_cache = layer_outputs[1]
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
# ipex-llm changes start
next_cache = next_decoder_cache if use_cache else None
# ipex-llm changes end
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return llama_32_fused_model_forward
def llama2_casullm_forward(
self,
input_ids: torch.LongTensor = None,

View file

@ -498,11 +498,12 @@ class LLMBaseNNFactory(NNFactory):
def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
num_heads, seq_len, head_dim):
position_ids = self.squeeze(position_ids)
cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
cos = self.unsqueeze(cos, [1])
sin = self.unsqueeze(sin, [1])
if position_ids is not None:
position_ids = self.squeeze(position_ids)
cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
cos = self.unsqueeze(cos, [1])
sin = self.unsqueeze(sin, [1])
rotate_half_q = self.rotate_half(q,
num_heads=num_heads,