From a7b66683f19961a6cc36739291a386ceabe1a250 Mon Sep 17 00:00:00 2001 From: SONG Ge <38711238+sgwhat@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:21:40 +0800 Subject: [PATCH] [NPU] Add Optimized Support for Llama3.2-1B/3B on NPU (#12339) * Add initial support for llama3.2-1b/3b * move llama3.2 support into current llama_mp impl --- .../HF-Transformers-AutoModels/LLM/README.md | 19 + .../NPU/HF-Transformers-AutoModels/README.md | 2 + .../transformers/npu_models/convert_mp.py | 18 +- .../ipex_llm/transformers/npu_models/kv.py | 5 +- .../transformers/npu_models/llama_mp.py | 432 +++++++++++++----- .../transformers/npu_models/mp_models_base.py | 11 +- 6 files changed, 360 insertions(+), 127 deletions(-) diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md index ceae4e9a..54d37ae5 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md @@ -7,6 +7,8 @@ In this directory, you will find examples on how to directly run HuggingFace `tr |------------|----------------------------------------------------------------| | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | +| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) | +| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) | | Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) | | Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) | | Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | @@ -33,6 +35,9 @@ conda activate llm :: install ipex-llm with 'npu' option pip install --pre --upgrade ipex-llm[npu] + +:: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct +pip install transformers==4.45.0 accelerate==0.33.0 ``` ## 2. 
Runtime Configurations @@ -82,6 +87,8 @@ done The examples below show how to run the **_optimized HuggingFace model implementations_** on Intel NPU, including - [Llama2-7B](./llama.py) - [Llama3-8B](./llama.py) +- [Llama3.2-1B](./llama.py) +- [Llama3.2-3B](./llama.py) - [Qwen2-1.5B](./qwen.py) - [Qwen2.5-7B](./qwen.py) - [MiniCPM-1B](./minicpm.py) @@ -106,6 +113,12 @@ python llama.py :: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715) python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct +:: to run Llama-3.2-1B-Instruct +python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct + +:: to run Llama-3.2-3B-Instruct +python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct + :: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715) python qwen.py @@ -145,6 +158,12 @@ python llama.py --disable-transpose-value-cache :: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715) python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --disable-transpose-value-cache +:: to run Llama-3.2-1B-Instruct +python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --disable-transpose-value-cache + +:: to run Llama-3.2-3B-Instruct +python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --disable-transpose-value-cache + :: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715) python qwen.py --disable-transpose-value-cache diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/README.md index a92bff08..dc326c15 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/README.md @@ -10,6 +10,8 @@ This folder contains examples of running IPEX-LLM on Intel NPU: |------------|----------------------------------------------------------------| | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | +| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) | +| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) | | Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) | | Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) | | Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index 4b8581e4..d00a0885 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -173,7 +173,8 @@ def convert_llama( intra_pp=None, transpose_value_cache=True, ): - from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward + from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward,\ + gen_llama_32_fused_model_forward from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner from transformers.models.llama.modeling_llama import LlamaModel @@ -193,9 +194,18 @@ def convert_llama( max_prompt_len=max_prompt_len, transpose_value_cache=transpose_value_cache, ) - 
llama_model_forward = gen_llama_fused_model_forward( - prefill_runner=prefill_runner, decode_runner=decode_runner - ) + from packaging import version + import transformers + trans_version = transformers.__version__ + if version.parse(trans_version) == version.parse("4.45.0"): + # llama-3.2-3B & llama-3.2-1B + llama_model_forward = gen_llama_32_fused_model_forward( + prefill_runner=prefill_runner, decode_runner=decode_runner + ) + else: + llama_model_forward = gen_llama_fused_model_forward( + prefill_runner=prefill_runner, decode_runner=decode_runner + ) convert_forward(model, LlamaModel, llama_model_forward) from transformers.models.llama.modeling_llama import LlamaForCausalLM from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward diff --git a/python/llm/src/ipex_llm/transformers/npu_models/kv.py b/python/llm/src/ipex_llm/transformers/npu_models/kv.py index 4f112a1c..56b03d4d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/kv.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/kv.py @@ -145,7 +145,7 @@ class DynamicFusedNormalCache(DynamicCache): # Experimental support for fused decoderlayer implementation on NPU # Currently only for llama2 - def __init__(self) -> None: + def __init__(self, num_hidden_layers: Optional[int] = None) -> None: self.key_cache: Dict[int, torch.Tensor] = {} self.value_cache: Dict[int, torch.Tensor] = {} self.min_layer_idx = sys.maxsize @@ -158,6 +158,9 @@ class DynamicFusedNormalCache(DynamicCache): cache_kwargs: Optional[Dict[str, Any]]=None, ) -> Tuple[torch.Tensor, torch.Tensor]: + if key_states == []: + return key_states, value_states + batch_size, num_heads, seq_len, head_dim = key_states.shape max_seq_length = cache_kwargs["max_seq_len"] if "max_seq_len" in cache_kwargs else None diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py index 8373aab7..f595d111 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py @@ -69,7 +69,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): intermediate_size, n_splits_linear: int = 1, n_splits_down_proj: int = 1, - group_size: int = 0 + group_size: int = 0, + cos_len: int = 1, ): super().__init__(max_seq_len=max_seq_len, transpose_value=transpose_value, @@ -84,18 +85,13 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): self.dtype = dtype self.cached_cos = cached_cos self.cached_sin = cached_sin + self.cos_len = cos_len self.batch_size, self.seq_len, self.hidden_size = hidden_shape self.mode = mode self.rms_norm_eps = rms_norm_eps self.transpose_value = transpose_value self.num_layers = num_layers - cos = self.constant(self.cached_cos) - self.cos = self.unsqueeze(cos, axis=0) - - sin = self.constant(self.cached_sin) - self.sin = self.unsqueeze(sin, axis=0) - if mode == "decode": self.kv_seq_len = self.max_seq_len + 1 else: @@ -111,6 +107,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size)) # llama2/3 use ov sdp, other models need to test + use_prefill_sdp = self.intermediate_size in [11008, 14336] # Self Attention @@ -124,8 +121,20 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.int64) - - position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) + if self.cached_cos is None: + if mode == 
"prefill": + position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) + self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim)) + self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim)) + else: + self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim)) + self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim)) + else: + position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) + cos = self.constant(self.cached_cos) + self.cos = self.unsqueeze(cos, axis=0) + sin = self.constant(self.cached_sin) + self.sin = self.unsqueeze(sin, axis=0) if input_layernorm_weights is None: input_layernorm_weights = [] @@ -179,12 +188,15 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): hidden_states, new_key_states, new_value_states = self.build_decoder( hidden_states=hidden_states, attention_mask=attention_mask, - position_ids=position_ids, + position_ids=position_ids if (cached_cos is not None + or mode == "prefill") else None, input_layernorm_weight=input_layernorm_weights[i], post_attention_layernorm_weight=post_attn_layernorm_weights[i], past_key=past_keys[i], past_value=past_values[i], use_prefill_sdp=use_prefill_sdp, + cos=self.cos, + sin=self.sin, ) curr_key_values.append((new_key_states, new_value_states)) @@ -205,12 +217,14 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): self, hidden_states, attention_mask, - position_ids, input_layernorm_weight, post_attention_layernorm_weight, + position_ids=None, past_key=None, past_value=None, use_prefill_sdp=False, + cos=None, + sin=None, ): residual = hidden_states @@ -222,8 +236,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): attention_mask=attention_mask, past_key=past_key, past_value=past_value, - cos=self.cos, - sin=self.sin, + cos=cos, + sin=sin, mode=self.mode, num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads, @@ -282,6 +296,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): self.op_id = str(uuid.uuid4()) self.max_seq_len = max_seq_len self.transpose_value = transpose_value + self.cached_cos = cached_cos if isinstance(parameters[0], tuple): np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 elif parameters[0].dtype == torch.int8: @@ -341,15 +356,21 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + cos: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: inputs = ( hidden_states.to(torch.float16), attention_mask.to(torch.int64), - position_ids.to(torch.int64), ) + if self.cached_cos is None: + inputs += (cos.to(torch.float16), sin.to(torch.float16)) + else: + inputs += (position_ids.to(torch.int64),) + for i in range(self.intra_stages): start, end = self.layer_ranges[i] self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end]) @@ -402,7 +423,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): transpose_value: bool = False, n_splits_linear: int = 1, n_splits_down_proj: int = 1, - group_size: int = 0 + group_size: int = 0, + cos_len: int = 1, ): super().__init__() self.op_parameters = parameters @@ -410,6 +432,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): self.layer_idx = layer_idx self.max_seq_len = max_seq_len self.transpose_value = transpose_value + self.cached_cos = cached_cos # self.rotary_emb = rotary_emb if 
isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 @@ -433,7 +456,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): dtype=np_dtype, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size + group_size=group_size, + cos_len=cos_len, ) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 @@ -448,6 +472,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + cos=None, + sin=None, **kwargs, ) -> torch.Tensor: """Torch module forward method. @@ -469,6 +495,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): inputs = (hidden_states.to(torch.float16), attention_mask.to(torch.int64), position_ids.to(torch.int64)) + if self.cached_cos is None: + inputs += (cos.to(torch.float16), sin.to(torch.float16),) inputs += (self.layer_norm_0, self.layer_norm_1) hidden_states, past_key, past_value = run_model( inputs, self.op_parameters, backend_cls, self.op_id, replica=2 @@ -566,8 +594,12 @@ def run_decode( scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) - cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + else: + cached_cos = None + cached_sin = None layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) @@ -599,9 +631,13 @@ def run_decode( dist.barrier() past_key_values = None + output_attentions = False control = torch.empty((), dtype=torch.int) hidden_states = torch.empty((1, 1, head_dim * num_heads), dtype=torch.float16) + if cached_cos is None: + cos = torch.zeros((1, 1, head_dim), dtype=torch.float16) + sin = torch.zeros((1, 1, head_dim), dtype=torch.float16) with torch.inference_mode(): while True: @@ -618,9 +654,15 @@ def run_decode( ) position_ids = position_ids = cache_position.unsqueeze(0) - causal_mask = model.model._update_causal_mask( - attention_mask, hidden_states, cache_position, past_seen_tokens - ) + if cached_cos is None: + causal_mask = model.model._update_causal_mask( + attention_mask, hidden_states, cache_position, + past_key_values, output_attentions + ) + else: + causal_mask = model.model._update_causal_mask( + attention_mask, hidden_states, cache_position, past_seen_tokens + ) pad_len = multi_decoder.max_seq_len + 1 - causal_mask.size(-1) pad_mask = (0, pad_len) @@ -629,6 +671,9 @@ def run_decode( ) padded_causal_mask[:, :, :, -1] = 0 dist.recv(hidden_states, src=rank - 1) + if cached_cos is None: + dist.recv(cos, src=rank - 1) + dist.recv(sin, src=rank - 1) layer_outputs = multi_decoder( hidden_states, attention_mask=padded_causal_mask, @@ -637,9 +682,14 @@ def run_decode( output_attentions=False, use_cache=True, cache_position=cache_position, + cos=cos if cached_cos is None else None, + sin=sin if cached_sin is None else None, ) hidden_states = layer_outputs[0] dist.send(hidden_states, dst=(rank + 1) % world_size) + if cached_cos is None: + dist.send(cos, dst=(rank + 1) % world_size) + dist.send(sin, dst=(rank + 1) % world_size) past_key_values = layer_outputs[1] 
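# --------------------------------------------------------------------------
# Editor's note: the send/recv calls above relay the rotary tables between
# pipeline stages because Llama-3.2 (transformers 4.45) no longer exposes
# per-layer cos/sin caches, so `cached_cos is None` and the tables computed
# once by the model-level rotary embedding must travel with the activations.
# The following is an illustrative sketch of that relay pattern only -- it is
# not part of the patch; `relay_step` and `run_layers` are hypothetical names,
# and it assumes a torch.distributed process group is already initialized.
import torch
import torch.distributed as dist

def relay_step(hidden_states, cos, sin, rank, world_size, run_layers):
    # receive activations and rotary tables from the previous pipeline stage
    dist.recv(hidden_states, src=rank - 1)
    dist.recv(cos, src=rank - 1)
    dist.recv(sin, src=rank - 1)
    # run this stage's decoder layers with the shared rotary tables
    hidden_states = run_layers(hidden_states, cos=cos, sin=sin)
    # forward everything to the next stage (rank 0 receives the final output)
    dist.send(hidden_states, dst=(rank + 1) % world_size)
    dist.send(cos, dst=(rank + 1) % world_size)
    dist.send(sin, dst=(rank + 1) % world_size)
    return hidden_states
# --------------------------------------------------------------------------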
new_keys = layer_outputs[2] new_values = layer_outputs[3] @@ -717,6 +767,8 @@ class DecodeRunner: output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + cos: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None, **kwargs, ): @@ -727,9 +779,18 @@ class DecodeRunner: self.input_queues[i].put(past_key_value) dist.broadcast(self.forward_signal, src=0, async_op=True) hidden_states = hidden_states.to(torch.float16) + if cos is not None: + cos = cos.to(torch.float16) + sin = sin.to(torch.float16) dist.send(hidden_states, dst=1) + if cos is not None: + dist.send(cos, dst=1) + dist.send(sin, dst=1) past_key_value.expand(self.transpose_value_cache) dist.recv(hidden_states, src=self.world_size - 1) + if cos is not None: + dist.recv(cos, src=self.world_size - 1) + dist.recv(sin, src=self.world_size - 1) return hidden_states, past_key_value def shutdown(self): @@ -749,103 +810,113 @@ def run_prefill( model, max_output_len, max_prompt_len, transpose_value_cache, input_queue, result_queue ): - layer_start = 0 - layer_end = len(model.model.layers) - num_heads = model.model.layers[layer_start].self_attn.num_heads - num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads - head_dim = model.model.layers[layer_start].self_attn.head_dim - rms_norm_eps = model.config.rms_norm_eps - intermediate_size = model.config.intermediate_size - group_size = getattr(model.config, "group_size", 0) - deocderlayers = [] - layer_weights = [] - input_layer_norm_weights = [] - post_attn_layernorm_weights = [] - layer_indexs = range(layer_start, layer_end) - n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) - n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) - for layer_idx in layer_indexs: - curr_layer = model.model.layers[layer_idx] - attn_layer = curr_layer.self_attn - mlp_layer = curr_layer.mlp - - weights = [] - - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: - l_weights = [] - scales = [] - for l in mlp_layer.down_proj_dq_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - - cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) - cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) - - layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) - layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) - - new_decoderlayer = FusedLlamaLowBitDecoderlayer( - weights, - num_heads=num_heads, - num_key_value_heads=num_key_value_heads, - 
cached_cos=cached_cos, - cached_sin=cached_sin, - layer_norm_0=layer_norm_0, - layer_norm_1=layer_norm_1, - layer_idx=layer_idx, - rms_norm_eps=rms_norm_eps, - intermediate_size=intermediate_size, - max_seq_len=max_output_len, - transpose_value=transpose_value_cache, - n_splits_linear=n_splits_linear, - n_splits_down_proj=n_splits_down_proj, - group_size=group_size - ) - - layer_weights.extend(weights) - input_layer_norm_weights.append(layer_norm_0) - post_attn_layernorm_weights.append(layer_norm_1) - model.model.layers[layer_idx] = new_decoderlayer - deocderlayers.append(new_decoderlayer) - - print("finish creating all decode layers in prefill") - result_queue.put("loading finish") + deocderlayers = None while True: - result = input_queue.get() if result == "stop": break - hidden_states, position_ids, causal_mask, past_key_values, cache_position = result + hidden_states, position_ids, causal_mask, past_key_values, cache_position, cos, sin = result + + if deocderlayers is None: + cos_len = cos.shape[1] if cos is not None else None + layer_start = 0 + layer_end = len(model.model.layers) + num_heads = model.model.layers[layer_start].self_attn.num_heads + num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads + head_dim = model.model.layers[layer_start].self_attn.head_dim + rms_norm_eps = model.config.rms_norm_eps + intermediate_size = model.config.intermediate_size + group_size = getattr(model.config, "group_size", 0) + deocderlayers = [] + layer_weights = [] + input_layer_norm_weights = [] + post_attn_layernorm_weights = [] + layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + for layer_idx in layer_indexs: + curr_layer = model.model.layers[layer_idx] + attn_layer = curr_layer.self_attn + mlp_layer = curr_layer.mlp + + weights = [] + + if n_splits_linear == 1: + for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, + attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, + attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, + mlp_layer.up_proj_dq_list): + weights.append((q.weight, q.scale)) + weights.append((k.weight, k.scale)) + weights.append((v.weight, v.scale)) + weights.append((o.weight, o.scale)) + weights.append((g.weight, g.scale)) + weights.append((u.weight, u.scale)) + else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), + torch.stack(scales, axis=0))) + + if n_splits_down_proj == 1: + for l in mlp_layer.down_proj_dq_list: + weights.append((l.weight, l.scale)) + else: + l_weights = [] + scales = [] + for l in mlp_layer.down_proj_dq_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + else: + cached_cos = None + cached_sin = None + + layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) + layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + + new_decoderlayer = 
FusedLlamaLowBitDecoderlayer( + weights, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + cached_cos=cached_cos, + cached_sin=cached_sin, + layer_norm_0=layer_norm_0, + layer_norm_1=layer_norm_1, + layer_idx=layer_idx, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + max_seq_len=max_output_len, + transpose_value=transpose_value_cache, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size, + cos_len=cos_len, + ) + + layer_weights.extend(weights) + input_layer_norm_weights.append(layer_norm_0) + post_attn_layernorm_weights.append(layer_norm_1) + model.model.layers[layer_idx] = new_decoderlayer + deocderlayers.append(new_decoderlayer) + + print("finish creating all decode layers in prefill") + result_queue.put("loading finish") + with torch.inference_mode(): for decoder_layer in deocderlayers: layer_outputs = decoder_layer( @@ -856,6 +927,8 @@ def run_prefill( output_attentions=False, use_cache=True, cache_position=cache_position, + cos=cos, + sin=sin, ) hidden_states = layer_outputs[0] @@ -887,9 +960,6 @@ class PrefillRunner: ) self.p.daemon = True self.p.start() - output = self.prefill_result_queue.get() - print(Fore.GREEN + f"prefill process output: {output}") - print(Style.RESET_ALL) def forward( self, @@ -900,6 +970,8 @@ class PrefillRunner: output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + cos=None, + sin=None, **kwargs, ): seq_len = hidden_states.size(1) @@ -919,9 +991,16 @@ class PrefillRunner: value=torch.iinfo(torch.int64).min, ) - args = (hidden_states, position_ids, attention_mask, past_key_value, cache_position) + args = (hidden_states, position_ids, attention_mask, past_key_value, + cache_position, cos, sin) self.prefill_input_queue.put(args) - hidden_states, past_key_value = self.prefill_result_queue.get() + + output = self.prefill_result_queue.get() + if output == "loading finish": + hidden_states, past_key_value = self.prefill_result_queue.get() + else: + hidden_states, past_key_value = output + past_key_value.shrink(seq_len, self.transpose_value_cache) hidden_states = hidden_states[:, :seq_len, :] return hidden_states, past_key_value @@ -1051,6 +1130,125 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner): return llama_fused_model_forward +def gen_llama_32_fused_model_forward(prefill_runner, decode_runner): + + def llama_32_fused_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + msg = ( + "You cannot specify both input_ids and inputs_embeds at the same time," + " and must 
specify either one" + ) + invalidInputError(False, msg) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # ipex-llm changes start + from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache + if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache): + past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() \ + if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device + ) + # ipex-llm changes end + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + cos, sin = position_embeddings + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + seq_len = hidden_states.size(1) + if seq_len == 1: + layers_runner = decode_runner + else: + layers_runner = prefill_runner + + layer_outputs = layers_runner.forward( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + cos=cos, + sin=sin, + ) + + hidden_states = layer_outputs[0] + next_decoder_cache = layer_outputs[1] + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # ipex-llm changes start + next_cache = next_decoder_cache if use_cache else None + # ipex-llm changes end + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return llama_32_fused_model_forward + + def llama2_casullm_forward( self, input_ids: torch.LongTensor = None, diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 5aa195a0..2997c0e8 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -498,11 +498,12 @@ class LLMBaseNNFactory(NNFactory): def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids, num_heads, seq_len, head_dim): - position_ids = self.squeeze(position_ids) - cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) - sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) - cos = self.unsqueeze(cos, [1]) - sin = self.unsqueeze(sin, [1]) + if position_ids is not None: + position_ids = self.squeeze(position_ids) + cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) + sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) + cos = self.unsqueeze(cos, 
[1]) + sin = self.unsqueeze(sin, [1]) rotate_half_q = self.rotate_half(q, num_heads=num_heads,