refactor qwen2 and llama3 (#12587)
This commit is contained in:

parent 51ff9ebd8a
commit f3b5fad3be

4 changed files with 16 additions and 103 deletions
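The change replaces the hand-rolled attention dispatch in the llama and qwen2 forwards (the xe_addons.sdp / sdp_fp8 / sdp_causal branches plus an eager matmul-and-softmax fallback) with the shared ipex_llm.transformers.models.common.scaled_dot_product_attention helper, and removes the unused use_fused_layer_norm utility together with an import of it.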
				
			
@@ -37,7 +37,6 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
-from ipex_llm.transformers.models.utils import use_fused_layer_norm
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 
 import os

@@ -42,14 +42,12 @@ import torch
 from typing import Optional, Tuple, Union
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import repeat_kv
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.common import attention_softmax
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_compresskv, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache, \

@@ -233,44 +231,11 @@ def llama_attention_forward(
             key_states, value_states = past_key_value.update(key_states, value_states,
                                                              self.layer_idx, None)
 
-    kv_seq_len = key_states.size(2)
-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, :kv_seq_len]
-    else:
-        causal_mask = None
-
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, causal_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if causal_mask is not None:
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == key_states.size(2), math.sqrt(self.head_dim)
+    )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, -1)
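For orientation, here is a minimal sketch of what a helper with this call shape can look like. It is not the actual ipex_llm implementation: only the positional signature used above, (query, key, value, mask, is-causal flag, math.sqrt(head_dim)), comes from the diff. The body below, which forwards to torch's fused SDPA and treats the last argument as the softmax scale denominator, is an assumption.

# A minimal sketch, NOT the real ipex_llm helper: only the call-site
# signature (q, k, v, mask, is_causal, scale denominator) is taken from
# the diff above; the dispatch below is an assumption.
# Requires torch >= 2.1 for the `scale` keyword.
from typing import Optional

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(query: torch.Tensor,
                                 key: torch.Tensor,
                                 value: torch.Tensor,
                                 mask: Optional[torch.Tensor],
                                 is_causal: bool,
                                 scale_denom: float) -> torch.Tensor:
    # repeat k/v heads if n_kv_heads < n_heads (grouped-query attention),
    # mirroring the repeat_kv calls in the deleted branches
    n_rep = query.size(1) // key.size(1)
    if n_rep > 1:
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
    # the call sites pass math.sqrt(self.head_dim), i.e. the denominator
    # of the attention-score scale, hence the reciprocal here
    scale = 1.0 / scale_denom
    if is_causal and mask is None:
        # prefill with no explicit mask: let the kernel apply causality
        return F.scaled_dot_product_attention(query, key, value,
                                              is_causal=True, scale=scale)
    # decode, or an explicit (possibly pre-sliced) additive mask
    return F.scaled_dot_product_attention(query, key, value,
                                          attn_mask=mask, scale=scale)

Note that the llama call site passes q_len == key_states.size(2) as the causal flag, so prefill takes the causal path while single-token decode takes the masked path.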
			
@@ -46,11 +46,12 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
 from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
-    should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, \
+    should_use_compresskv, is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.models.utils import use_flash_attention
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, \
     DynamicCompressCache, DynamicCompressFp8Cache
 from ipex_llm.utils.common import invalidInputError

@@ -532,7 +533,6 @@ def qwen2_attention_forward(
     # [CompressKV]
-    from ipex_llm.transformers.kv import DynamicCompressCache
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
     use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
 
     if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
         qkv = self.qkv_proj(hidden_states)

@@ -583,18 +583,8 @@ def qwen2_attention_forward(
                                                              self.layer_idx, None)
 
     attn_weights = None
-    if query_states.device.type == "cpu":
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_output = sdpa(query_states,
-                           key_states,
-                           value_states,
-                           attn_mask=attention_mask,
-                           dropout_p=self.attention_dropout if self.training else 0.0,
-                           is_causal=self.is_causal and attention_mask is None and q_len > 1)
-    elif not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
+    if query_states.device.type == 'xpu' \
+            and use_flash_attention(query_states, key_states, attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

@@ -602,42 +592,11 @@ def qwen2_attention_forward(
                            key_states.to(device, dtype=torch.float16),
                            value_states.to(device, dtype=torch.float16),
                            is_causal=True).to(hidden_states.dtype)
-    elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if use_compresskv:
-            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
-        if use_quantizekv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantizekv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
     else:
-        if use_quantizekv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(query_states.dtype)
-        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                   training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = scaled_dot_product_attention(
+            query_states, key_states, value_states,
+            attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
+        )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
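The eager branch deleted above (matmul, additive mask, fp32 softmax, dropout, matmul) is the reference any unified path has to reproduce. Below is a self-contained sanity check, with hypothetical shapes and at inference settings (dropout 0), that torch's fused SDPA matches that reference:

# Sanity check: fused SDPA vs. the eager path removed above.
# Shapes and values are hypothetical; dropout_p=0 matches inference.
import math

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)
mask = torch.zeros(bsz, 1, q_len, q_len)  # additive attention mask

# eager reference, as in the deleted else-branch
w = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
w = w + mask
w = F.softmax(w, dim=-1, dtype=torch.float32).to(q.dtype)
ref = torch.matmul(w, v)

# fused path with the same additive mask and default 1/sqrt(head_dim) scale
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

assert torch.allclose(ref, out, atol=1e-5)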

@@ -358,16 +358,6 @@ def use_xmx(x: torch.Tensor, qtype: int):
     )
 
 
-def use_fused_layer_norm(x: torch.Tensor, training: bool):
-    device = get_xpu_device_type(x)
-    return (
-        not training
-        and not x.requires_grad
-        and device in ["arc", "flex", "pvc", "mtl", "lnl"]  # fused layer norm cannot run on UHD
-        and x.numel() // x.size(-1) == 1  # fused layer norm is slower in first token
-    )
-
-
 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
     if proj is None: