refactor yuan2 and starcoder2 and fix (#12589)
This commit is contained in:
		
							parent
							
								
									6ea8033635
								
							
						
					
					
						commit
						b050368efc
					
				
					 6 changed files with 28 additions and 83 deletions
				
			
		| 
						 | 
				
			
			@ -234,7 +234,7 @@ def llama_attention_forward(
 | 
			
		|||
    attn_weights = None
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == key_states.size(2), math.sqrt(self.head_dim)
 | 
			
		||||
        attention_mask, q_len == key_states.size(2)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,15 +38,13 @@
 | 
			
		|||
 | 
			
		||||
import torch
 | 
			
		||||
import warnings
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from typing import Optional, Tuple, Union, List
 | 
			
		||||
import math
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, get_compresskv_attn_mask
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_compresskv, should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.llama import repeat_kv
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
 | 
			
		||||
    DynamicCompressCache, DynamicCompressFp8Cache
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
| 
						 | 
				
			
			@ -127,11 +125,10 @@ def minicpm_attention_forward(
 | 
			
		|||
            key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                             self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
 | 
			
		||||
        attention_mask, q_len == kv_seq_len
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -28,6 +28,7 @@ from typing import Optional, List
 | 
			
		|||
from torch.nn.functional import linear
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
 | 
			
		||||
from ipex_llm.transformers.models.common import attention_softmax
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from transformers import AutoProcessor, TextIteratorStreamer
 | 
			
		||||
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -72,10 +73,11 @@ def siglip_attention_forward(
 | 
			
		|||
            72, 80
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
        attn_output = scaled_dot_product_attention(query_states, key_states, value_states,
 | 
			
		||||
                                                   attention_mask, False, math.sqrt(self.head_dim))
 | 
			
		||||
        attn_output = scaled_dot_product_attention(
 | 
			
		||||
            query_states, key_states, value_states,
 | 
			
		||||
            attention_mask, False, 1 / math.sqrt(self.head_dim)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        attn_output = attn_output[:, :, :, :self.head_dim]
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -595,7 +595,7 @@ def qwen2_attention_forward(
 | 
			
		|||
    else:
 | 
			
		||||
        attn_output = scaled_dot_product_attention(
 | 
			
		||||
            query_states, key_states, value_states,
 | 
			
		||||
            attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
 | 
			
		||||
            attention_mask, q_len == kv_seq_len
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -40,17 +40,15 @@ import math
 | 
			
		|||
import torch
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 | 
			
		||||
from ipex_llm.transformers.models.utils import (
 | 
			
		||||
    use_quantize_kv_cache, restore_fp8_kv_cache,
 | 
			
		||||
    should_use_fuse_rope, use_sdp, use_sdp_causal
 | 
			
		||||
)
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Tuple, List
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
from transformers.models.starcoder2.modeling_starcoder2 import repeat_kv, apply_rotary_pos_emb
 | 
			
		||||
from transformers.models.starcoder2.modeling_starcoder2 import apply_rotary_pos_emb
 | 
			
		||||
from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model, Starcoder2Attention
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -103,41 +101,11 @@ def attention_forward(
 | 
			
		|||
                                                     self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                        attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, attention_mask)
 | 
			
		||||
    else:
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
        attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
 | 
			
		||||
                                                   training=self.training)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == kv_seq_len
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,12 +26,12 @@ from typing import Optional, Tuple
 | 
			
		|||
import torch
 | 
			
		||||
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.common import attention_softmax
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 | 
			
		||||
    mlp_fusion_check, fp16_fusion_check
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import SILU, update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_qk(module: torch.nn.Module):
 | 
			
		||||
| 
						 | 
				
			
			@ -214,34 +214,12 @@ def yuan_attention_forward(
 | 
			
		|||
    )
 | 
			
		||||
    past_key_value = (key_states, value_states, before_hidden_states) if use_cache else None
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                        attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, attention_mask)
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
    # IPEX-LLM OPT: sdpa
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == kv_seq_len
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue