[NPU] Add Optimized Support for Llama3.2-1B/3B on NPU (#12339)
* Add initial support for llama3.2-1b/3b * move llama3.2 support into current llama_mp impl
This commit is contained in:
		
							parent
							
								
									872a74481a
								
							
						
					
					
						commit
						a7b66683f1
					
				
					 6 changed files with 360 additions and 127 deletions
				
			
		| 
						 | 
				
			
			@ -7,6 +7,8 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 | 
			
		|||
|------------|----------------------------------------------------------------|
 | 
			
		||||
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | 
			
		||||
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
 | 
			
		||||
| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) |
 | 
			
		||||
| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 | 
			
		||||
| Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
 | 
			
		||||
| Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
 | 
			
		||||
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | 
			
		||||
| 
						 | 
				
			
			@ -33,6 +35,9 @@ conda activate llm
 | 
			
		|||
 | 
			
		||||
:: install ipex-llm with 'npu' option
 | 
			
		||||
pip install --pre --upgrade ipex-llm[npu]
 | 
			
		||||
 | 
			
		||||
:: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
 | 
			
		||||
pip install transformers==4.45.0 accelerate==0.33.0
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## 2. Runtime Configurations
 | 
			
		||||
| 
						 | 
				
			
			@ -82,6 +87,8 @@ done
 | 
			
		|||
The examples below show how to run the **_optimized HuggingFace model implementations_** on Intel NPU, including
 | 
			
		||||
- [Llama2-7B](./llama.py)
 | 
			
		||||
- [Llama3-8B](./llama.py)
 | 
			
		||||
- [Llama3.2-1B](./llama.py)
 | 
			
		||||
- [Llama3.2-3B](./llama.py)
 | 
			
		||||
- [Qwen2-1.5B](./qwen.py)
 | 
			
		||||
- [Qwen2.5-7B](./qwen.py)
 | 
			
		||||
- [MiniCPM-1B](./minicpm.py)
 | 
			
		||||
| 
						 | 
				
			
			@ -106,6 +113,12 @@ python llama.py
 | 
			
		|||
:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
 | 
			
		||||
python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct
 | 
			
		||||
 | 
			
		||||
:: to run Llama-3.2-1B-Instruct
 | 
			
		||||
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct
 | 
			
		||||
 | 
			
		||||
:: to run Llama-3.2-3B-Instruct
 | 
			
		||||
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct
 | 
			
		||||
 | 
			
		||||
:: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
 | 
			
		||||
python qwen.py
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -145,6 +158,12 @@ python llama.py --disable-transpose-value-cache
 | 
			
		|||
:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
 | 
			
		||||
python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --disable-transpose-value-cache
 | 
			
		||||
 | 
			
		||||
:: to run Llama-3.2-1B-Instruct
 | 
			
		||||
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --disable-transpose-value-cache
 | 
			
		||||
 | 
			
		||||
:: to run Llama-3.2-3B-Instruct
 | 
			
		||||
python llama.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --disable-transpose-value-cache
 | 
			
		||||
 | 
			
		||||
:: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715)
 | 
			
		||||
python qwen.py --disable-transpose-value-cache
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,6 +10,8 @@ This folder contains examples of running IPEX-LLM on Intel NPU:
 | 
			
		|||
|------------|----------------------------------------------------------------|
 | 
			
		||||
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | 
			
		||||
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
 | 
			
		||||
| Llama3.2-1B | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) |
 | 
			
		||||
| Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 | 
			
		||||
| Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
 | 
			
		||||
| Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
 | 
			
		||||
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -173,7 +173,8 @@ def convert_llama(
 | 
			
		|||
        intra_pp=None,
 | 
			
		||||
        transpose_value_cache=True,
 | 
			
		||||
):
 | 
			
		||||
    from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
 | 
			
		||||
    from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward,\
 | 
			
		||||
        gen_llama_32_fused_model_forward
 | 
			
		||||
    from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
 | 
			
		||||
    from transformers.models.llama.modeling_llama import LlamaModel
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -193,9 +194,18 @@ def convert_llama(
 | 
			
		|||
        max_prompt_len=max_prompt_len,
 | 
			
		||||
        transpose_value_cache=transpose_value_cache,
 | 
			
		||||
    )
 | 
			
		||||
    llama_model_forward = gen_llama_fused_model_forward(
 | 
			
		||||
        prefill_runner=prefill_runner, decode_runner=decode_runner
 | 
			
		||||
    )
 | 
			
		||||
    from packaging import version
 | 
			
		||||
    import transformers
 | 
			
		||||
    trans_version = transformers.__version__
 | 
			
		||||
    if version.parse(trans_version) == version.parse("4.45.0"):
 | 
			
		||||
        # llama-3.2-3B & llama-3.2-1B
 | 
			
		||||
        llama_model_forward = gen_llama_32_fused_model_forward(
 | 
			
		||||
            prefill_runner=prefill_runner, decode_runner=decode_runner
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        llama_model_forward = gen_llama_fused_model_forward(
 | 
			
		||||
            prefill_runner=prefill_runner, decode_runner=decode_runner
 | 
			
		||||
        )
 | 
			
		||||
    convert_forward(model, LlamaModel, llama_model_forward)
 | 
			
		||||
    from transformers.models.llama.modeling_llama import LlamaForCausalLM
 | 
			
		||||
    from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -145,7 +145,7 @@ class DynamicFusedNormalCache(DynamicCache):
 | 
			
		|||
    # Experimental support for fused decoderlayer implementation on NPU
 | 
			
		||||
    # Currently only for llama2
 | 
			
		||||
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
 | 
			
		||||
        self.key_cache: Dict[int, torch.Tensor] = {}
 | 
			
		||||
        self.value_cache: Dict[int, torch.Tensor] = {}
 | 
			
		||||
        self.min_layer_idx = sys.maxsize
 | 
			
		||||
| 
						 | 
				
			
			@ -158,6 +158,9 @@ class DynamicFusedNormalCache(DynamicCache):
 | 
			
		|||
        cache_kwargs: Optional[Dict[str, Any]]=None,
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
 | 
			
		||||
        if key_states == []:
 | 
			
		||||
            return key_states, value_states
 | 
			
		||||
 | 
			
		||||
        batch_size, num_heads, seq_len, head_dim = key_states.shape
 | 
			
		||||
 | 
			
		||||
        max_seq_length = cache_kwargs["max_seq_len"] if "max_seq_len" in cache_kwargs else None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -69,7 +69,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
        intermediate_size,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
        group_size: int = 0,
 | 
			
		||||
        cos_len: int = 1,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(max_seq_len=max_seq_len,
 | 
			
		||||
                         transpose_value=transpose_value,
 | 
			
		||||
| 
						 | 
				
			
			@ -84,18 +85,13 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
        self.dtype = dtype
 | 
			
		||||
        self.cached_cos = cached_cos
 | 
			
		||||
        self.cached_sin = cached_sin
 | 
			
		||||
        self.cos_len = cos_len
 | 
			
		||||
        self.batch_size, self.seq_len, self.hidden_size = hidden_shape
 | 
			
		||||
        self.mode = mode
 | 
			
		||||
        self.rms_norm_eps = rms_norm_eps
 | 
			
		||||
        self.transpose_value = transpose_value
 | 
			
		||||
        self.num_layers = num_layers
 | 
			
		||||
 | 
			
		||||
        cos = self.constant(self.cached_cos)
 | 
			
		||||
        self.cos = self.unsqueeze(cos, axis=0)
 | 
			
		||||
 | 
			
		||||
        sin = self.constant(self.cached_sin)
 | 
			
		||||
        self.sin = self.unsqueeze(sin, axis=0)
 | 
			
		||||
 | 
			
		||||
        if mode == "decode":
 | 
			
		||||
            self.kv_seq_len = self.max_seq_len + 1
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -111,6 +107,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
        input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
 | 
			
		||||
 | 
			
		||||
        # llama2/3 use ov sdp, other models need to test
 | 
			
		||||
 | 
			
		||||
        use_prefill_sdp = self.intermediate_size in [11008, 14336]
 | 
			
		||||
 | 
			
		||||
        # Self Attention
 | 
			
		||||
| 
						 | 
				
			
			@ -124,8 +121,20 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
 | 
			
		||||
                                                       self.seq_len),
 | 
			
		||||
                                                      dtype=np.int64)
 | 
			
		||||
 | 
			
		||||
        position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
 | 
			
		||||
        if self.cached_cos is None:
 | 
			
		||||
            if mode == "prefill":
 | 
			
		||||
                position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
 | 
			
		||||
                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
 | 
			
		||||
                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
 | 
			
		||||
            else:
 | 
			
		||||
                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
 | 
			
		||||
                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
 | 
			
		||||
        else:
 | 
			
		||||
            position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
 | 
			
		||||
            cos = self.constant(self.cached_cos)
 | 
			
		||||
            self.cos = self.unsqueeze(cos, axis=0)
 | 
			
		||||
            sin = self.constant(self.cached_sin)
 | 
			
		||||
            self.sin = self.unsqueeze(sin, axis=0)
 | 
			
		||||
 | 
			
		||||
        if input_layernorm_weights is None:
 | 
			
		||||
            input_layernorm_weights = []
 | 
			
		||||
| 
						 | 
				
			
			@ -179,12 +188,15 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
            hidden_states, new_key_states, new_value_states = self.build_decoder(
 | 
			
		||||
                hidden_states=hidden_states,
 | 
			
		||||
                attention_mask=attention_mask,
 | 
			
		||||
                position_ids=position_ids,
 | 
			
		||||
                position_ids=position_ids if (cached_cos is not None
 | 
			
		||||
                                              or mode == "prefill") else None,
 | 
			
		||||
                input_layernorm_weight=input_layernorm_weights[i],
 | 
			
		||||
                post_attention_layernorm_weight=post_attn_layernorm_weights[i],
 | 
			
		||||
                past_key=past_keys[i],
 | 
			
		||||
                past_value=past_values[i],
 | 
			
		||||
                use_prefill_sdp=use_prefill_sdp,
 | 
			
		||||
                cos=self.cos,
 | 
			
		||||
                sin=self.sin,
 | 
			
		||||
            )
 | 
			
		||||
            curr_key_values.append((new_key_states, new_value_states))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -205,12 +217,14 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
        self,
 | 
			
		||||
        hidden_states,
 | 
			
		||||
        attention_mask,
 | 
			
		||||
        position_ids,
 | 
			
		||||
        input_layernorm_weight,
 | 
			
		||||
        post_attention_layernorm_weight,
 | 
			
		||||
        position_ids=None,
 | 
			
		||||
        past_key=None,
 | 
			
		||||
        past_value=None,
 | 
			
		||||
        use_prefill_sdp=False,
 | 
			
		||||
        cos=None,
 | 
			
		||||
        sin=None,
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
| 
						 | 
				
			
			@ -222,8 +236,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
            attention_mask=attention_mask,
 | 
			
		||||
            past_key=past_key,
 | 
			
		||||
            past_value=past_value,
 | 
			
		||||
            cos=self.cos,
 | 
			
		||||
            sin=self.sin,
 | 
			
		||||
            cos=cos,
 | 
			
		||||
            sin=sin,
 | 
			
		||||
            mode=self.mode,
 | 
			
		||||
            num_heads=self.num_heads,
 | 
			
		||||
            num_key_value_heads=self.num_key_value_heads,
 | 
			
		||||
| 
						 | 
				
			
			@ -282,6 +296,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        self.op_id = str(uuid.uuid4())
 | 
			
		||||
        self.max_seq_len = max_seq_len
 | 
			
		||||
        self.transpose_value = transpose_value
 | 
			
		||||
        self.cached_cos = cached_cos
 | 
			
		||||
        if isinstance(parameters[0], tuple):
 | 
			
		||||
            np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
 | 
			
		||||
        elif parameters[0].dtype == torch.int8:
 | 
			
		||||
| 
						 | 
				
			
			@ -341,15 +356,21 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        output_attentions: bool = False,
 | 
			
		||||
        use_cache: bool = False,
 | 
			
		||||
        cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
        cos: Optional[torch.Tensor] = None,
 | 
			
		||||
        sin: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
        inputs = (
 | 
			
		||||
            hidden_states.to(torch.float16),
 | 
			
		||||
            attention_mask.to(torch.int64),
 | 
			
		||||
            position_ids.to(torch.int64),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if self.cached_cos is None:
 | 
			
		||||
            inputs += (cos.to(torch.float16), sin.to(torch.float16))
 | 
			
		||||
        else:
 | 
			
		||||
            inputs += (position_ids.to(torch.int64),)
 | 
			
		||||
 | 
			
		||||
        for i in range(self.intra_stages):
 | 
			
		||||
            start, end = self.layer_ranges[i]
 | 
			
		||||
            self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])
 | 
			
		||||
| 
						 | 
				
			
			@ -402,7 +423,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        transpose_value: bool = False,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
        group_size: int = 0,
 | 
			
		||||
        cos_len: int = 1,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.op_parameters = parameters
 | 
			
		||||
| 
						 | 
				
			
			@ -410,6 +432,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        self.layer_idx = layer_idx
 | 
			
		||||
        self.max_seq_len = max_seq_len
 | 
			
		||||
        self.transpose_value = transpose_value
 | 
			
		||||
        self.cached_cos = cached_cos
 | 
			
		||||
        # self.rotary_emb = rotary_emb
 | 
			
		||||
        if isinstance(parameters[0], tuple):  # weight, scale from QuantizedLinear
 | 
			
		||||
            np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
 | 
			
		||||
| 
						 | 
				
			
			@ -433,7 +456,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
            dtype=np_dtype,
 | 
			
		||||
            n_splits_linear=n_splits_linear,
 | 
			
		||||
            n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
            group_size=group_size
 | 
			
		||||
            group_size=group_size,
 | 
			
		||||
            cos_len=cos_len,
 | 
			
		||||
        )
 | 
			
		||||
        self.layer_norm_0 = layer_norm_0
 | 
			
		||||
        self.layer_norm_1 = layer_norm_1
 | 
			
		||||
| 
						 | 
				
			
			@ -448,6 +472,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        output_attentions: bool = False,
 | 
			
		||||
        use_cache: bool = False,
 | 
			
		||||
        cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
        cos=None,
 | 
			
		||||
        sin=None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """Torch module forward method.
 | 
			
		||||
| 
						 | 
				
			
			@ -469,6 +495,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
            inputs = (hidden_states.to(torch.float16),
 | 
			
		||||
                      attention_mask.to(torch.int64),
 | 
			
		||||
                      position_ids.to(torch.int64))
 | 
			
		||||
        if self.cached_cos is None:
 | 
			
		||||
            inputs += (cos.to(torch.float16), sin.to(torch.float16),)
 | 
			
		||||
        inputs += (self.layer_norm_0, self.layer_norm_1)
 | 
			
		||||
        hidden_states, past_key, past_value = run_model(
 | 
			
		||||
            inputs, self.op_parameters, backend_cls, self.op_id, replica=2
 | 
			
		||||
| 
						 | 
				
			
			@ -566,8 +594,12 @@ def run_decode(
 | 
			
		|||
                scales.append(l.scale)
 | 
			
		||||
            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
        if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
 | 
			
		||||
            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
        else:
 | 
			
		||||
            cached_cos = None
 | 
			
		||||
            cached_sin = None
 | 
			
		||||
        layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
 | 
			
		||||
        layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -599,9 +631,13 @@ def run_decode(
 | 
			
		|||
    dist.barrier()
 | 
			
		||||
 | 
			
		||||
    past_key_values = None
 | 
			
		||||
    output_attentions = False
 | 
			
		||||
 | 
			
		||||
    control = torch.empty((), dtype=torch.int)
 | 
			
		||||
    hidden_states = torch.empty((1, 1, head_dim * num_heads), dtype=torch.float16)
 | 
			
		||||
    if cached_cos is None:
 | 
			
		||||
        cos = torch.zeros((1, 1, head_dim), dtype=torch.float16)
 | 
			
		||||
        sin = torch.zeros((1, 1, head_dim), dtype=torch.float16)
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        while True:
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -618,9 +654,15 @@ def run_decode(
 | 
			
		|||
                )
 | 
			
		||||
 | 
			
		||||
                position_ids = position_ids = cache_position.unsqueeze(0)
 | 
			
		||||
                causal_mask = model.model._update_causal_mask(
 | 
			
		||||
                    attention_mask, hidden_states, cache_position, past_seen_tokens
 | 
			
		||||
                )
 | 
			
		||||
                if cached_cos is None:
 | 
			
		||||
                    causal_mask = model.model._update_causal_mask(
 | 
			
		||||
                        attention_mask, hidden_states, cache_position,
 | 
			
		||||
                        past_key_values, output_attentions
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    causal_mask = model.model._update_causal_mask(
 | 
			
		||||
                        attention_mask, hidden_states, cache_position, past_seen_tokens
 | 
			
		||||
                    )
 | 
			
		||||
                pad_len = multi_decoder.max_seq_len + 1 - causal_mask.size(-1)
 | 
			
		||||
 | 
			
		||||
                pad_mask = (0, pad_len)
 | 
			
		||||
| 
						 | 
				
			
			@ -629,6 +671,9 @@ def run_decode(
 | 
			
		|||
                )
 | 
			
		||||
                padded_causal_mask[:, :, :, -1] = 0
 | 
			
		||||
                dist.recv(hidden_states, src=rank - 1)
 | 
			
		||||
                if cached_cos is None:
 | 
			
		||||
                    dist.recv(cos, src=rank - 1)
 | 
			
		||||
                    dist.recv(sin, src=rank - 1)
 | 
			
		||||
                layer_outputs = multi_decoder(
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask=padded_causal_mask,
 | 
			
		||||
| 
						 | 
				
			
			@ -637,9 +682,14 @@ def run_decode(
 | 
			
		|||
                    output_attentions=False,
 | 
			
		||||
                    use_cache=True,
 | 
			
		||||
                    cache_position=cache_position,
 | 
			
		||||
                    cos=cos if cached_cos is None else None,
 | 
			
		||||
                    sin=sin if cached_sin is None else None,
 | 
			
		||||
                )
 | 
			
		||||
                hidden_states = layer_outputs[0]
 | 
			
		||||
                dist.send(hidden_states, dst=(rank + 1) % world_size)
 | 
			
		||||
                if cached_cos is None:
 | 
			
		||||
                    dist.send(cos, dst=(rank + 1) % world_size)
 | 
			
		||||
                    dist.send(sin, dst=(rank + 1) % world_size)
 | 
			
		||||
                past_key_values = layer_outputs[1]
 | 
			
		||||
                new_keys = layer_outputs[2]
 | 
			
		||||
                new_values = layer_outputs[3]
 | 
			
		||||
| 
						 | 
				
			
			@ -717,6 +767,8 @@ class DecodeRunner:
 | 
			
		|||
        output_attentions: bool = False,
 | 
			
		||||
        use_cache: bool = False,
 | 
			
		||||
        cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
        cos: Optional[torch.Tensor] = None,
 | 
			
		||||
        sin: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -727,9 +779,18 @@ class DecodeRunner:
 | 
			
		|||
                self.input_queues[i].put(past_key_value)
 | 
			
		||||
        dist.broadcast(self.forward_signal, src=0, async_op=True)
 | 
			
		||||
        hidden_states = hidden_states.to(torch.float16)
 | 
			
		||||
        if cos is not None:
 | 
			
		||||
            cos = cos.to(torch.float16)
 | 
			
		||||
            sin = sin.to(torch.float16)
 | 
			
		||||
        dist.send(hidden_states, dst=1)
 | 
			
		||||
        if cos is not None:
 | 
			
		||||
            dist.send(cos, dst=1)
 | 
			
		||||
            dist.send(sin, dst=1)
 | 
			
		||||
        past_key_value.expand(self.transpose_value_cache)
 | 
			
		||||
        dist.recv(hidden_states, src=self.world_size - 1)
 | 
			
		||||
        if cos is not None:
 | 
			
		||||
            dist.recv(cos, src=self.world_size - 1)
 | 
			
		||||
            dist.recv(sin, src=self.world_size - 1)
 | 
			
		||||
        return hidden_states, past_key_value
 | 
			
		||||
 | 
			
		||||
    def shutdown(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -749,103 +810,113 @@ def run_prefill(
 | 
			
		|||
    model, max_output_len, max_prompt_len, transpose_value_cache, input_queue, result_queue
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    layer_start = 0
 | 
			
		||||
    layer_end = len(model.model.layers)
 | 
			
		||||
    num_heads = model.model.layers[layer_start].self_attn.num_heads
 | 
			
		||||
    num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
 | 
			
		||||
    head_dim = model.model.layers[layer_start].self_attn.head_dim
 | 
			
		||||
    rms_norm_eps = model.config.rms_norm_eps
 | 
			
		||||
    intermediate_size = model.config.intermediate_size
 | 
			
		||||
    group_size = getattr(model.config, "group_size", 0)
 | 
			
		||||
    deocderlayers = []
 | 
			
		||||
    layer_weights = []
 | 
			
		||||
    input_layer_norm_weights = []
 | 
			
		||||
    post_attn_layernorm_weights = []
 | 
			
		||||
    layer_indexs = range(layer_start, layer_end)
 | 
			
		||||
    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
 | 
			
		||||
    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
 | 
			
		||||
    for layer_idx in layer_indexs:
 | 
			
		||||
        curr_layer = model.model.layers[layer_idx]
 | 
			
		||||
        attn_layer = curr_layer.self_attn
 | 
			
		||||
        mlp_layer = curr_layer.mlp
 | 
			
		||||
 | 
			
		||||
        weights = []
 | 
			
		||||
 | 
			
		||||
        if n_splits_linear == 1:
 | 
			
		||||
            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
 | 
			
		||||
                                        attn_layer.k_proj_dq_list,
 | 
			
		||||
                                        attn_layer.v_proj_dq_list,
 | 
			
		||||
                                        attn_layer.o_proj_dq_list,
 | 
			
		||||
                                        mlp_layer.gate_proj_dq_list,
 | 
			
		||||
                                        mlp_layer.up_proj_dq_list):
 | 
			
		||||
                weights.append((q.weight, q.scale))
 | 
			
		||||
                weights.append((k.weight, k.scale))
 | 
			
		||||
                weights.append((v.weight, v.scale))
 | 
			
		||||
                weights.append((o.weight, o.scale))
 | 
			
		||||
                weights.append((g.weight, g.scale))
 | 
			
		||||
                weights.append((u.weight, u.scale))
 | 
			
		||||
        else:
 | 
			
		||||
            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
 | 
			
		||||
                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
 | 
			
		||||
                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
 | 
			
		||||
                l_weights = []
 | 
			
		||||
                scales = []
 | 
			
		||||
                for l in layer_list:
 | 
			
		||||
                    l_weights.append(l.weight)
 | 
			
		||||
                    scales.append(l.scale)
 | 
			
		||||
                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
        if n_splits_down_proj == 1:
 | 
			
		||||
            for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
                weights.append((l.weight, l.scale))
 | 
			
		||||
        else:
 | 
			
		||||
            l_weights = []
 | 
			
		||||
            scales = []
 | 
			
		||||
            for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
                l_weights.append(l.weight)
 | 
			
		||||
                scales.append(l.scale)
 | 
			
		||||
            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
 | 
			
		||||
        layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
 | 
			
		||||
        layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 | 
			
		||||
 | 
			
		||||
        new_decoderlayer = FusedLlamaLowBitDecoderlayer(
 | 
			
		||||
            weights,
 | 
			
		||||
            num_heads=num_heads,
 | 
			
		||||
            num_key_value_heads=num_key_value_heads,
 | 
			
		||||
            cached_cos=cached_cos,
 | 
			
		||||
            cached_sin=cached_sin,
 | 
			
		||||
            layer_norm_0=layer_norm_0,
 | 
			
		||||
            layer_norm_1=layer_norm_1,
 | 
			
		||||
            layer_idx=layer_idx,
 | 
			
		||||
            rms_norm_eps=rms_norm_eps,
 | 
			
		||||
            intermediate_size=intermediate_size,
 | 
			
		||||
            max_seq_len=max_output_len,
 | 
			
		||||
            transpose_value=transpose_value_cache,
 | 
			
		||||
            n_splits_linear=n_splits_linear,
 | 
			
		||||
            n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
            group_size=group_size
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        layer_weights.extend(weights)
 | 
			
		||||
        input_layer_norm_weights.append(layer_norm_0)
 | 
			
		||||
        post_attn_layernorm_weights.append(layer_norm_1)
 | 
			
		||||
        model.model.layers[layer_idx] = new_decoderlayer
 | 
			
		||||
        deocderlayers.append(new_decoderlayer)
 | 
			
		||||
 | 
			
		||||
    print("finish creating all decode layers in prefill")
 | 
			
		||||
    result_queue.put("loading finish")
 | 
			
		||||
    deocderlayers = None
 | 
			
		||||
 | 
			
		||||
    while True:
 | 
			
		||||
 | 
			
		||||
        result = input_queue.get()
 | 
			
		||||
        if result == "stop":
 | 
			
		||||
            break
 | 
			
		||||
 | 
			
		||||
        hidden_states, position_ids, causal_mask, past_key_values, cache_position = result
 | 
			
		||||
        hidden_states, position_ids, causal_mask, past_key_values, cache_position, cos, sin = result
 | 
			
		||||
 | 
			
		||||
        if deocderlayers is None:
 | 
			
		||||
            cos_len = cos.shape[1] if cos is not None else None
 | 
			
		||||
            layer_start = 0
 | 
			
		||||
            layer_end = len(model.model.layers)
 | 
			
		||||
            num_heads = model.model.layers[layer_start].self_attn.num_heads
 | 
			
		||||
            num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads
 | 
			
		||||
            head_dim = model.model.layers[layer_start].self_attn.head_dim
 | 
			
		||||
            rms_norm_eps = model.config.rms_norm_eps
 | 
			
		||||
            intermediate_size = model.config.intermediate_size
 | 
			
		||||
            group_size = getattr(model.config, "group_size", 0)
 | 
			
		||||
            deocderlayers = []
 | 
			
		||||
            layer_weights = []
 | 
			
		||||
            input_layer_norm_weights = []
 | 
			
		||||
            post_attn_layernorm_weights = []
 | 
			
		||||
            layer_indexs = range(layer_start, layer_end)
 | 
			
		||||
            n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
 | 
			
		||||
            n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
 | 
			
		||||
            for layer_idx in layer_indexs:
 | 
			
		||||
                curr_layer = model.model.layers[layer_idx]
 | 
			
		||||
                attn_layer = curr_layer.self_attn
 | 
			
		||||
                mlp_layer = curr_layer.mlp
 | 
			
		||||
 | 
			
		||||
                weights = []
 | 
			
		||||
 | 
			
		||||
                if n_splits_linear == 1:
 | 
			
		||||
                    for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
 | 
			
		||||
                                                attn_layer.k_proj_dq_list,
 | 
			
		||||
                                                attn_layer.v_proj_dq_list,
 | 
			
		||||
                                                attn_layer.o_proj_dq_list,
 | 
			
		||||
                                                mlp_layer.gate_proj_dq_list,
 | 
			
		||||
                                                mlp_layer.up_proj_dq_list):
 | 
			
		||||
                        weights.append((q.weight, q.scale))
 | 
			
		||||
                        weights.append((k.weight, k.scale))
 | 
			
		||||
                        weights.append((v.weight, v.scale))
 | 
			
		||||
                        weights.append((o.weight, o.scale))
 | 
			
		||||
                        weights.append((g.weight, g.scale))
 | 
			
		||||
                        weights.append((u.weight, u.scale))
 | 
			
		||||
                else:
 | 
			
		||||
                    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
 | 
			
		||||
                                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
 | 
			
		||||
                                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
 | 
			
		||||
                        l_weights = []
 | 
			
		||||
                        scales = []
 | 
			
		||||
                        for l in layer_list:
 | 
			
		||||
                            l_weights.append(l.weight)
 | 
			
		||||
                            scales.append(l.scale)
 | 
			
		||||
                        weights.append((torch.stack(l_weights, axis=0),
 | 
			
		||||
                                        torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
                if n_splits_down_proj == 1:
 | 
			
		||||
                    for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
                        weights.append((l.weight, l.scale))
 | 
			
		||||
                else:
 | 
			
		||||
                    l_weights = []
 | 
			
		||||
                    scales = []
 | 
			
		||||
                    for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
                        l_weights.append(l.weight)
 | 
			
		||||
                        scales.append(l.scale)
 | 
			
		||||
                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
                if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
 | 
			
		||||
                    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
                    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
                else:
 | 
			
		||||
                    cached_cos = None
 | 
			
		||||
                    cached_sin = None
 | 
			
		||||
 | 
			
		||||
                layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
 | 
			
		||||
                layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 | 
			
		||||
 | 
			
		||||
                new_decoderlayer = FusedLlamaLowBitDecoderlayer(
 | 
			
		||||
                    weights,
 | 
			
		||||
                    num_heads=num_heads,
 | 
			
		||||
                    num_key_value_heads=num_key_value_heads,
 | 
			
		||||
                    cached_cos=cached_cos,
 | 
			
		||||
                    cached_sin=cached_sin,
 | 
			
		||||
                    layer_norm_0=layer_norm_0,
 | 
			
		||||
                    layer_norm_1=layer_norm_1,
 | 
			
		||||
                    layer_idx=layer_idx,
 | 
			
		||||
                    rms_norm_eps=rms_norm_eps,
 | 
			
		||||
                    intermediate_size=intermediate_size,
 | 
			
		||||
                    max_seq_len=max_output_len,
 | 
			
		||||
                    transpose_value=transpose_value_cache,
 | 
			
		||||
                    n_splits_linear=n_splits_linear,
 | 
			
		||||
                    n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
                    group_size=group_size,
 | 
			
		||||
                    cos_len=cos_len,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                layer_weights.extend(weights)
 | 
			
		||||
                input_layer_norm_weights.append(layer_norm_0)
 | 
			
		||||
                post_attn_layernorm_weights.append(layer_norm_1)
 | 
			
		||||
                model.model.layers[layer_idx] = new_decoderlayer
 | 
			
		||||
                deocderlayers.append(new_decoderlayer)
 | 
			
		||||
 | 
			
		||||
            print("finish creating all decode layers in prefill")
 | 
			
		||||
            result_queue.put("loading finish")
 | 
			
		||||
 | 
			
		||||
        with torch.inference_mode():
 | 
			
		||||
            for decoder_layer in deocderlayers:
 | 
			
		||||
                layer_outputs = decoder_layer(
 | 
			
		||||
| 
						 | 
				
			
			@ -856,6 +927,8 @@ def run_prefill(
 | 
			
		|||
                    output_attentions=False,
 | 
			
		||||
                    use_cache=True,
 | 
			
		||||
                    cache_position=cache_position,
 | 
			
		||||
                    cos=cos,
 | 
			
		||||
                    sin=sin,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                hidden_states = layer_outputs[0]
 | 
			
		||||
| 
						 | 
				
			
			@ -887,9 +960,6 @@ class PrefillRunner:
 | 
			
		|||
        )
 | 
			
		||||
        self.p.daemon = True
 | 
			
		||||
        self.p.start()
 | 
			
		||||
        output = self.prefill_result_queue.get()
 | 
			
		||||
        print(Fore.GREEN + f"prefill process output: {output}")
 | 
			
		||||
        print(Style.RESET_ALL)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
| 
						 | 
				
			
			@ -900,6 +970,8 @@ class PrefillRunner:
 | 
			
		|||
        output_attentions: bool = False,
 | 
			
		||||
        use_cache: bool = False,
 | 
			
		||||
        cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
        cos=None,
 | 
			
		||||
        sin=None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        seq_len = hidden_states.size(1)
 | 
			
		||||
| 
						 | 
				
			
			@ -919,9 +991,16 @@ class PrefillRunner:
 | 
			
		|||
            value=torch.iinfo(torch.int64).min,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        args = (hidden_states, position_ids, attention_mask, past_key_value, cache_position)
 | 
			
		||||
        args = (hidden_states, position_ids, attention_mask, past_key_value,
 | 
			
		||||
                cache_position, cos, sin)
 | 
			
		||||
        self.prefill_input_queue.put(args)
 | 
			
		||||
        hidden_states, past_key_value = self.prefill_result_queue.get()
 | 
			
		||||
 | 
			
		||||
        output = self.prefill_result_queue.get()
 | 
			
		||||
        if output == "loading finish":
 | 
			
		||||
            hidden_states, past_key_value = self.prefill_result_queue.get()
 | 
			
		||||
        else:
 | 
			
		||||
            hidden_states, past_key_value = output
 | 
			
		||||
 | 
			
		||||
        past_key_value.shrink(seq_len, self.transpose_value_cache)
 | 
			
		||||
        hidden_states = hidden_states[:, :seq_len, :]
 | 
			
		||||
        return hidden_states, past_key_value
 | 
			
		||||
| 
						 | 
				
			
			@ -1051,6 +1130,125 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
 | 
			
		|||
    return llama_fused_model_forward
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gen_llama_32_fused_model_forward(prefill_runner, decode_runner):
 | 
			
		||||
 | 
			
		||||
    def llama_32_fused_model_forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
        cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
    ) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
        output_attentions = (
 | 
			
		||||
            output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        )
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states
 | 
			
		||||
            if output_hidden_states is not None
 | 
			
		||||
            else self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        if (input_ids is None) ^ (inputs_embeds is not None):
 | 
			
		||||
            msg = (
 | 
			
		||||
                "You cannot specify both input_ids and inputs_embeds at the same time,"
 | 
			
		||||
                " and must specify either one"
 | 
			
		||||
            )
 | 
			
		||||
            invalidInputError(False, msg)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training and use_cache:
 | 
			
		||||
            use_cache = False
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is None:
 | 
			
		||||
            inputs_embeds = self.embed_tokens(input_ids)
 | 
			
		||||
 | 
			
		||||
        # ipex-llm changes start
 | 
			
		||||
        from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
 | 
			
		||||
        if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
 | 
			
		||||
            past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
        if cache_position is None:
 | 
			
		||||
            past_seen_tokens = past_key_values.get_seq_length() \
 | 
			
		||||
                if past_key_values is not None else 0
 | 
			
		||||
            cache_position = torch.arange(
 | 
			
		||||
                past_seen_tokens,
 | 
			
		||||
                past_seen_tokens + inputs_embeds.shape[1],
 | 
			
		||||
                device=inputs_embeds.device
 | 
			
		||||
            )
 | 
			
		||||
        # ipex-llm changes end
 | 
			
		||||
 | 
			
		||||
        if position_ids is None:
 | 
			
		||||
            position_ids = cache_position.unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
        causal_mask = self._update_causal_mask(
 | 
			
		||||
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # embed positions
 | 
			
		||||
        hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
        # create position embeddings to be shared across the decoder layers
 | 
			
		||||
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
 | 
			
		||||
        cos, sin = position_embeddings
 | 
			
		||||
 | 
			
		||||
        # decoder layers
 | 
			
		||||
        all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
        all_self_attns = () if output_attentions else None
 | 
			
		||||
        next_decoder_cache = None
 | 
			
		||||
 | 
			
		||||
        seq_len = hidden_states.size(1)
 | 
			
		||||
        if seq_len == 1:
 | 
			
		||||
            layers_runner = decode_runner
 | 
			
		||||
        else:
 | 
			
		||||
            layers_runner = prefill_runner
 | 
			
		||||
 | 
			
		||||
        layer_outputs = layers_runner.forward(
 | 
			
		||||
            hidden_states,
 | 
			
		||||
            attention_mask=causal_mask,
 | 
			
		||||
            position_ids=position_ids,
 | 
			
		||||
            past_key_value=past_key_values,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            cache_position=cache_position,
 | 
			
		||||
            cos=cos,
 | 
			
		||||
            sin=sin,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hidden_states = layer_outputs[0]
 | 
			
		||||
        next_decoder_cache = layer_outputs[1]
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # add hidden states from the last decoder layer
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        # ipex-llm changes start
 | 
			
		||||
        next_cache = next_decoder_cache if use_cache else None
 | 
			
		||||
        # ipex-llm changes end
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            return tuple(
 | 
			
		||||
                v
 | 
			
		||||
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
 | 
			
		||||
                if v is not None
 | 
			
		||||
            )
 | 
			
		||||
        return BaseModelOutputWithPast(
 | 
			
		||||
            last_hidden_state=hidden_states,
 | 
			
		||||
            past_key_values=next_cache,
 | 
			
		||||
            hidden_states=all_hidden_states,
 | 
			
		||||
            attentions=all_self_attns,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return llama_32_fused_model_forward
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def llama2_casullm_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    input_ids: torch.LongTensor = None,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -498,11 +498,12 @@ class LLMBaseNNFactory(NNFactory):
 | 
			
		|||
 | 
			
		||||
    def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
 | 
			
		||||
                             num_heads, seq_len, head_dim):
 | 
			
		||||
        position_ids = self.squeeze(position_ids)
 | 
			
		||||
        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
			
		||||
        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
			
		||||
        cos = self.unsqueeze(cos, [1])
 | 
			
		||||
        sin = self.unsqueeze(sin, [1])
 | 
			
		||||
        if position_ids is not None:
 | 
			
		||||
            position_ids = self.squeeze(position_ids)
 | 
			
		||||
            cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
			
		||||
            sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
			
		||||
            cos = self.unsqueeze(cos, [1])
 | 
			
		||||
            sin = self.unsqueeze(sin, [1])
 | 
			
		||||
 | 
			
		||||
        rotate_half_q = self.rotate_half(q,
 | 
			
		||||
                                         num_heads=num_heads,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue