From aa2fa9fde18e424c96b3ce34d5337fd28c43c0eb Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 7 May 2024 15:53:08 +0800
Subject: [PATCH] optimize phi3 again: use sdp if possible (#10951)

---
 .../src/ipex_llm/transformers/models/phi3.py | 32 +++++++++++--------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 01c0c34c..a293a37c 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -39,7 +39,7 @@
 from ipex_llm.transformers.models.utils import (
     rotate_half, should_use_fuse_rope, apply_rotary_pos_emb_cache_freq_xpu
 )
-from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
+from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU, use_new_esimd_sdp_fp16
 from ipex_llm.transformers.kv import DynamicNormalCache
 
 from typing import Optional, Tuple, List
@@ -93,22 +93,26 @@ def attention_forward(
     key_states, value_states = past_key_value.update(key_states, value_states,
                                                       self.layer_idx, None)
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    if use_new_esimd_sdp_fp16(q_len, kv_seq_len, self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+    else:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    attn_weights = torch.matmul(query_states,
-                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                               dtype=torch.float32).to(value_states.dtype)
-    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                               training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)
+        # upcast attention to fp32
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                                   training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
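
Note (not part of the patch): the change above turns the unconditional eager attention in phi3's attention_forward into a runtime dispatch. When use_new_esimd_sdp_fp16 reports that the shapes and dtype suit the fused ESIMD kernel, linear_q4_0.sdp_fp16 computes the whole softmax(QK^T / sqrt(d))V in one call; otherwise the original repeat_kv + matmul/softmax/dropout path runs. The sketch below illustrates only that dispatch pattern, under stated assumptions: it uses stock PyTorch, torch.nn.functional.scaled_dot_product_attention stands in for the XPU-only fused kernel, can_use_fused_sdp is a made-up placeholder rather than the library's actual check, and kv heads are repeated before both branches for simplicity, whereas in the patch repeat_kv stays inside the eager branch only.

import math

import torch
import torch.nn.functional as F


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Simplified stand-in for transformers' repeat_kv: duplicate each kv head
    # n_rep times along the head dimension so it matches the query head count.
    if n_rep == 1:
        return hidden_states
    return hidden_states.repeat_interleave(n_rep, dim=1)


def can_use_fused_sdp(q_len: int, kv_len: int, head_dim: int, query: torch.Tensor) -> bool:
    # Placeholder heuristic (hypothetical): only take the fused path for
    # single-token fp16 decode; the real use_new_esimd_sdp_fp16 applies its own checks.
    return q_len == 1 and query.dtype == torch.float16


def attention(query, key, value, attention_mask, num_key_value_groups,
              head_dim, attention_dropout=0.0, training=False):
    # query: (bsz, num_heads, q_len, head_dim); key/value: (bsz, num_kv_heads, kv_len, head_dim)
    q_len, kv_len = query.size(2), key.size(2)
    key = repeat_kv(key, num_key_value_groups)
    value = repeat_kv(value, num_key_value_groups)

    if can_use_fused_sdp(q_len, kv_len, head_dim, query):
        # Fused path: one kernel computes softmax(Q K^T / sqrt(d)) V.
        # (In the patch this is linear_q4_0.sdp_fp16 on XPU; here stock PyTorch SDPA stands in.)
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

    # Eager fallback, mirroring the else branch of the patch.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # upcast attention to fp32 for a numerically stable softmax, then cast back
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value.dtype)
    attn_weights = F.dropout(attn_weights, p=attention_dropout, training=training)
    return torch.matmul(attn_weights, value)


if __name__ == "__main__":
    # Tiny smoke test on CPU with fp32, so the eager fallback is exercised.
    bsz, heads, kv_heads, q_len, kv_len, head_dim = 1, 4, 2, 8, 8, 16
    q = torch.randn(bsz, heads, q_len, head_dim)
    k = torch.randn(bsz, kv_heads, kv_len, head_dim)
    v = torch.randn(bsz, kv_heads, kv_len, head_dim)
    out = attention(q, k, v, attention_mask=None,
                    num_key_value_groups=heads // kv_heads, head_dim=head_dim)
    print(out.shape)  # torch.Size([1, 4, 8, 16])

The placeholder gate on q_len == 1 and fp16 is only meant to evoke the single-token decode case where a fused SDP kernel typically pays off most; the actual conditions checked by use_new_esimd_sdp_fp16 are defined in ipex_llm.transformers.models.utils and are not reproduced here.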