From 6454655dcc9dfa2629211c7aebcfccb895268c5a Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 4 Jun 2024 15:39:00 +0800
Subject: [PATCH] use sdp in baichuan2 13b (#11198)

---
 .../ipex_llm/transformers/models/baichuan2.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
index 31d982a7..4b450b5f 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
@@ -204,24 +204,25 @@ def baichuan_attention_forward_13b(
     else:
         attention_mask = attention_mask[:, None, -q_len:, :]
-    if use_quantize_kv and q_len == 1:
+    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
-        attn_weights = xe_addons.query_key_fp8_matmul(query_states, key_states)
+        if use_quantize_kv:
+            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
+                                            attention_mask)
+        else:
+            attn_output = xe_addons.sdp(query_states, key_states, value_states,
+                                        attention_mask)
+        attn_weights = None
     else:
         if use_quantize_kv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
         attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3))
-    attn_weights = attn_weights / math.sqrt(self.head_dim)
-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
-    attn_weights = attn_weights.to(query_states.dtype)
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-    if use_quantize_kv and q_len == 1:
-        import xe_addons
-        attn_output = xe_addons.attn_value_fp8_matmul(attn_weights, value_states)
-    else:
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+        attn_weights = attn_weights.to(query_states.dtype)
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype),
                                    value_states)
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
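
Note: xe_addons.sdp / xe_addons.sdp_fp8 above are IPEX-LLM's fused scaled-dot-product-attention kernels for Intel XPUs; the else branch keeps the explicit fallback math. As a rough functional reference only (not the XPU kernel itself), the sketch below checks that fallback math against stock PyTorch's F.scaled_dot_product_attention. The helper name manual_attention, the decode-step shapes (40 heads, head_dim 128, 32-token KV cache), and the plain additive float mask are illustrative assumptions, not values taken from the patch.

import math
import torch
import torch.nn.functional as F

def manual_attention(query, key, value, mask=None):
    # Mirrors the patch's fallback branch: QK^T / sqrt(head_dim),
    # additive mask, softmax, then weights @ V.
    head_dim = query.size(-1)
    weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if mask is not None:
        weights = weights + mask
    weights = torch.nn.functional.softmax(weights, dim=-1)
    return torch.matmul(weights, value)

# Illustrative decode-step shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 40, 1, 128)
k = torch.randn(1, 40, 32, 128)
v = torch.randn(1, 40, 32, 128)
mask = torch.zeros(1, 1, 1, 32)  # additive mask: 0 keeps a position, -inf drops it

fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
manual = manual_attention(q, k, v, mask)
print(torch.allclose(fused, manual, atol=1e-5))  # expected: True

The use_sdp() gate in the patch plays the same role as picking the fused call here: when the query length, KV length, and head size fit the kernel's supported shapes, one fused call replaces the matmul / softmax / matmul sequence and never materializes the full attn_weights tensor, which is why the SDP branch sets attn_weights = None.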