refactor yuan2 and starcoder2 and fix (#12589)

Yishuo Wang 2024-12-20 16:41:50 +08:00 committed by GitHub
parent 6ea8033635
commit b050368efc
6 changed files with 28 additions and 83 deletions


@@ -234,7 +234,7 @@ def llama_attention_forward(
     attn_weights = None
     attn_output = scaled_dot_product_attention(
         query_states, key_states, value_states,
-        attention_mask, q_len == key_states.size(2), math.sqrt(self.head_dim)
+        attention_mask, q_len == key_states.size(2)
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
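Reviewer note: at every refactored call site the trailing math.sqrt(self.head_dim) argument is simply dropped, which suggests the shared helper in ipex_llm.transformers.models.common now falls back to the conventional 1/sqrt(head_dim) softmax scale when no scale is supplied. The helper itself is not shown in this diff, so the stand-in below is only an assumed equivalent built on PyTorch's public torch.nn.functional.scaled_dot_product_attention, with the positional order (query, key, value, attention_mask, is_causal, scale) inferred from the call sites.

# Hypothetical stand-in for the common helper; inferred from call sites, not ipex-llm code.
import math
from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_stand_in(query: torch.Tensor,    # [bsz, n_heads, q_len, head_dim]
                  key: torch.Tensor,      # [bsz, n_kv_heads, kv_len, head_dim]
                  value: torch.Tensor,    # [bsz, n_kv_heads, kv_len, head_dim]
                  attention_mask: Optional[torch.Tensor],
                  is_causal: bool,
                  scale: Optional[float] = None) -> torch.Tensor:
    # Dropping the explicit scale at the call sites relies on this default.
    if scale is None:
        scale = 1 / math.sqrt(query.size(-1))
    # Expand grouped KV heads so q/k/v head counts match (what repeat_kv did in the removed paths).
    n_rep = query.size(1) // key.size(1)
    if n_rep > 1:
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
    # PyTorch rejects an explicit attn_mask combined with is_causal=True, so pass only one of them.
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=None if is_causal else attention_mask,
        is_causal=is_causal,
        scale=scale,
    )

Whether the real helper also dispatches to the fused xe_addons kernels that the deleted branches called directly is not visible in this diff.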


@@ -38,15 +38,13 @@
 import torch
 import warnings
-import torch.nn as nn
 from typing import Optional, Tuple, Union, List
 import math
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import should_use_compresskv, should_use_fuse_rope
-from ipex_llm.transformers.models.llama import repeat_kv
 from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
     DynamicCompressCache, DynamicCompressFp8Cache
 from transformers.cache_utils import Cache
@@ -127,11 +125,10 @@ def minicpm_attention_forward(
     key_states, value_states = past_key_value.update(key_states, value_states,
                                                      self.layer_idx, None)
-    from ipex_llm.transformers.models.common import scaled_dot_product_attention
     attn_weights = None
     attn_output = scaled_dot_product_attention(
         query_states, key_states, value_states,
-        attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
+        attention_mask, q_len == kv_seq_len
     )
     attn_output = attn_output.transpose(1, 2).contiguous()


@@ -28,6 +28,7 @@ from typing import Optional, List
 from torch.nn.functional import linear
 from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
 from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from transformers import AutoProcessor, TextIteratorStreamer
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
@@ -72,10 +73,11 @@ def siglip_attention_forward(
         72, 80
     )
-    from ipex_llm.transformers.models.common import scaled_dot_product_attention
     attn_weights = None
-    attn_output = scaled_dot_product_attention(query_states, key_states, value_states,
-                                               attention_mask, False, math.sqrt(self.head_dim))
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, False, 1 / math.sqrt(self.head_dim)
+    )
     attn_output = attn_output[:, :, :, :self.head_dim]
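Reviewer note: this SigLIP call site is the only one that still passes an explicit scale, and it changes from math.sqrt(self.head_dim) to 1 / math.sqrt(self.head_dim); that inversion looks like the "fix" in the commit title, since the softmax scale in scaled dot-product attention is the inverse square root of the head dimension. The "72, 80" context above suggests the heads are padded from 72 to 80 for the kernel, so the scale is supplied from the original head_dim instead of being derived from the padded one. A quick numeric check, assuming head_dim = 72:

# Illustration only: the attention softmax scale should be 1/sqrt(head_dim).
import math

head_dim, padded_dim = 72, 80        # assumed from the padding call above
print(1 / math.sqrt(head_dim))       # ~0.1179: what the new call passes
print(math.sqrt(head_dim))           # ~8.4853: what the old call passed
print(1 / math.sqrt(padded_dim))     # ~0.1118: what a default based on the padded dim would give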


@@ -595,7 +595,7 @@ def qwen2_attention_forward(
     else:
         attn_output = scaled_dot_product_attention(
             query_states, key_states, value_states,
-            attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
+            attention_mask, q_len == kv_seq_len
         )
     attn_output = attn_output.transpose(1, 2).contiguous()


@@ -40,17 +40,15 @@ import math
 import torch
 import warnings
-from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import (
-    use_quantize_kv_cache, restore_fp8_kv_cache,
-    should_use_fuse_rope, use_sdp, use_sdp_causal
-)
+from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
 from typing import Optional, Tuple, List
 from transformers.cache_utils import Cache
-from transformers.models.starcoder2.modeling_starcoder2 import repeat_kv, apply_rotary_pos_emb
+from transformers.models.starcoder2.modeling_starcoder2 import apply_rotary_pos_emb
 from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model, Starcoder2Attention
@@ -103,41 +101,11 @@ def attention_forward(
                                                      self.layer_idx, None)
     # IPEX-LLM OPT: sdp
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                   training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_weights = None
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
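Reviewer note: the deleted block hand-dispatched between the fused xe_addons SDP kernels and an eager matmul-softmax fallback (restoring the fp8 KV cache and repeating KV heads first); the unified helper now owns that choice. For reference, a generic sketch of what the removed eager fallback computed (not ipex-llm code):

# Generic eager attention mirroring the removed fallback branch; sketch, not ipex-llm code.
import math

import torch


def eager_attention(query, key, value, attention_mask=None, dropout_p=0.0, training=False):
    # query/key/value: [bsz, n_heads, seq, head_dim], with KV heads already expanded to n_heads.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.size(-1))
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # Softmax in fp32 for stability, then cast back (the job attention_softmax did).
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p, training=training)
    return torch.matmul(attn_weights, value)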


@@ -26,12 +26,12 @@ from typing import Optional, Tuple
 import torch
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     mlp_fusion_check, fp16_fusion_check
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import SILU, update_past_key_value
-from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import should_use_fuse_rope

 def merge_qk(module: torch.nn.Module):
@@ -214,34 +214,12 @@ def yuan_attention_forward(
     )
     past_key_value = (key_states, value_states, before_hidden_states) if use_cache else None
-    # IPEX-LLM OPT: sdp
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    # IPEX-LLM OPT: sdpa
+    attn_weights = None
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
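Reviewer note: every refactored call site derives the causal flag as q_len == kv_seq_len (or q_len == key_states.size(2)), so causal masking applies during prefill, while single-token decode steps fall through to whatever attention_mask is provided. A small shape illustration, with sizes invented for the example:

# Shape illustration of the causal-flag convention; sizes are made up for the example.
import torch

bsz, n_heads, head_dim, prompt_len = 1, 8, 64, 16

# Prefill: the query spans the whole prompt, so q_len == kv_seq_len -> causal masking.
query = torch.randn(bsz, n_heads, prompt_len, head_dim)
kv_seq_len = prompt_len
print(query.size(2) == kv_seq_len)   # True

# Decode: one new token attends to the grown cache, so q_len < kv_seq_len -> use attention_mask.
query = torch.randn(bsz, n_heads, 1, head_dim)
kv_seq_len = prompt_len + 1
print(query.size(2) == kv_seq_len)   # False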