From fad1dbaf6008595283c3b6ced96fa85335da828d Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 15 May 2024 10:22:35 +0800
Subject: [PATCH] use sdp fp8 causal kernel (#11023)

---
 .../src/ipex_llm/transformers/models/phi3.py | 47 ++++++++++---------
 1 file changed, 24 insertions(+), 23 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 98fba35e..d1451666 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -137,38 +137,39 @@ def attention_forward(
     key_states, value_states = past_key_value.update(key_states, value_states,
                                                      self.layer_idx, None)

-    if (isinstance(past_key_value, DynamicFp8Cache) and
-            use_sdp_fp8(q_len, kv_seq_len, query_states)):
+    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, attention_mask)
-    elif (isinstance(past_key_value, DynamicNormalCache) and
-            use_sdp(q_len, kv_seq_len, self.head_dim, query_states)):
+        if isinstance(past_key_value, DynamicFp8Cache):
+            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                              attention_mask)
+        else:
+            attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
+    elif use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
+        if isinstance(past_key_value, DynamicFp8Cache):
+            attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
+        else:
+            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
     else:
         if isinstance(past_key_value, DynamicFp8Cache):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
-        if use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
-            import linear_q4_0
-            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
-        else:
-            # repeat k/v heads if n_kv_heads < n_heads
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)

-            attn_weights = torch.matmul(query_states,
-                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-            if attention_mask is not None:
-                attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask

-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                       dtype=torch.float32).to(value_states.dtype)
-            attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                       training=self.training)
-            attn_output = torch.matmul(attn_weights, value_states)
+        # upcast attention to fp32
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                                   training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
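
Note: the hunk above reorders the kernel dispatch in attention_forward() so that the
FP8 and non-FP8 KV-cache paths share the same use_sdp()/use_sdp_causal() checks, and
the causal prefill case with an FP8 cache can now hit the new sdp_fp8_causal kernel
instead of restoring the cache and falling back to eager attention. The sketch below
is only an illustration of that branch order; sdp_*_kernel, eager_attention and the
is_fp8_cache flag are hypothetical stand-ins, not the linear_q4_0 API.

    # sdp_dispatch_sketch.py -- minimal sketch of the dispatch order after the patch.
    # All *_kernel functions are stubs standing in for the native linear_q4_0 kernels.
    import torch

    def sdp_fp8_kernel(q, k, v, mask):       # stand-in for linear_q4_0.sdp_fp8
        return torch.zeros_like(q)

    def sdp_kernel(q, k, v, mask):           # stand-in for linear_q4_0.sdp
        return torch.zeros_like(q)

    def sdp_fp8_causal_kernel(q, k, v):      # stand-in for linear_q4_0.sdp_fp8_causal
        return torch.zeros_like(q)

    def sdp_causal_kernel(q, k, v):          # stand-in for linear_q4_0.sdp_causal
        return torch.zeros_like(q)

    def eager_attention(q, k, v, mask, head_dim):
        # plain matmul/softmax fallback, mirroring the final else branch of the hunk
        weights = q @ k.transpose(-2, -1) / head_dim ** 0.5
        if mask is not None:
            weights = weights + mask
        weights = torch.softmax(weights.float(), dim=-1).to(v.dtype)
        return weights @ v

    def dispatch(q, k, v, mask, head_dim, *, can_sdp, can_sdp_causal, is_fp8_cache):
        # 1) decode-style SDP, 2) causal prefill SDP, 3) eager fallback;
        # the cache type only selects the fp8 vs. non-fp8 variant of the same kernel
        if can_sdp:
            return sdp_fp8_kernel(q, k, v, mask) if is_fp8_cache else sdp_kernel(q, k, v, mask)
        if can_sdp_causal:
            return sdp_fp8_causal_kernel(q, k, v) if is_fp8_cache else sdp_causal_kernel(q, k, v)
        return eager_attention(q, k, v, mask, head_dim)

    if __name__ == "__main__":
        q = k = v = torch.randn(1, 32, 8, 96)
        out = dispatch(q, k, v, None, 96, can_sdp=False, can_sdp_causal=False, is_fp8_cache=False)
        print(out.shape)  # torch.Size([1, 32, 8, 96])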