diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 47b0701e..dbcce5a4 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -42,6 +42,7 @@ from ipex_llm.transformers.models.utils import (
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_new_esimd_sdp_fp16, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_sdp_causal
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache

 from typing import Optional, Tuple, List
@@ -148,22 +149,26 @@ def attention_forward(
         if isinstance(past_key_value, DynamicFp8Cache):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        if use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
+            import linear_q4_0
+            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
+        else:
+            # repeat k/v heads if n_kv_heads < n_heads
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            attn_weights = torch.matmul(query_states,
+                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask

-        # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
-        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                   training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+            # upcast attention to fp32
+            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                       dtype=torch.float32).to(value_states.dtype)
+            attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                                       training=self.training)
+            attn_output = torch.matmul(attn_weights, value_states)

         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 21412806..c7a29bc8 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -384,6 +384,14 @@ def use_sdp_fp8(q_len, k_len, query_states):
     return True


+def use_sdp_causal(q_len, kv_len, query_states, training):
+    return (
+        q_len == kv_len  # first token
+        and query_states.device.type == "xpu"  # GPU
+        and not query_states.requires_grad and not training  # no training
+    )
+
+
 def mlp_fusion_check(x, qtype, training):
     invalidInputError(x.dim() == 2,
                       "Here input x's dim should be 2.")
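
Note (not part of the patch): the new use_sdp_causal gate routes the first-token prefill (q_len == kv_len, XPU, inference only) to the fused linear_q4_0.sdp_causal kernel instead of the eager matmul/softmax path kept under the else branch. The sketch below is a pure-PyTorch reference of what that fused path is expected to compute; the exact signature and internals of linear_q4_0.sdp_causal are an assumption, and sdp_causal_reference is a hypothetical helper used only for illustration.

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    def sdp_causal_reference(query_states, key_states, value_states, num_key_value_groups):
        # repeat k/v heads if n_kv_heads < n_heads (a fused GQA kernel would do this implicitly)
        key_states = key_states.repeat_interleave(num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(num_key_value_groups, dim=1)
        # causal attention over the whole prompt; default scale is 1/sqrt(head_dim),
        # matching the eager path's division by math.sqrt(self.head_dim)
        return scaled_dot_product_attention(query_states, key_states, value_states,
                                            is_causal=True)

Dropout is omitted because the gate only fires when not training, so attention_dropout is inactive on this path anyway.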