optimize phi3 again: use sdp if possible (#10951)
This commit is contained in:
parent
c11170b96f
commit
aa2fa9fde1
1 changed files with 18 additions and 14 deletions
|
|
@ -39,7 +39,7 @@ from ipex_llm.transformers.models.utils import (
|
|||
rotate_half, should_use_fuse_rope,
|
||||
apply_rotary_pos_emb_cache_freq_xpu
|
||||
)
|
||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
|
||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU, use_new_esimd_sdp_fp16
|
||||
from ipex_llm.transformers.kv import DynamicNormalCache
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
|
|
@ -93,22 +93,26 @@ def attention_forward(
|
|||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
if use_new_esimd_sdp_fp16(q_len, kv_seq_len, self.head_dim, query_states):
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
|
||||
else:
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states,
|
||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
attn_weights = torch.matmul(query_states,
|
||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(value_states.dtype)
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
|
||||
training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
# upcast attention to fp32
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(value_states.dtype)
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
|
||||
training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
|
|
|||
Loading…
Reference in a new issue