fix first token sdp with batch (#11153)

Yishuo Wang 2024-05-28 15:03:06 +08:00 committed by GitHub
parent 3464440839
commit d307622797
5 changed files with 23 additions and 16 deletions
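
Across all five files the change is the same: the fused causal SDP kernels (xe_addons.sdp_causal and xe_addons.sdp_fp8_causal) now take the attention mask as a fourth argument, so that padded positions in a batched first-token (prefill) pass are masked out rather than attended to. As a rough orientation only, here is a plain-PyTorch sketch of the semantics such a call is presumably expected to have; it is not the fused kernel, and it assumes the usual Hugging Face additive-mask convention (shape [batch, 1, q_len, kv_len], 0 to keep, a large negative value at padding):

import torch
import torch.nn.functional as F

def sdp_causal_reference(query, key, value, attention_mask=None):
    # Non-fused reference for causal SDP that also respects a padding mask.
    # query/key/value: [batch, heads, seq, head_dim]; the fused xe_addons
    # kernels are assumed to compute something equivalent to this.
    q_len, kv_len = query.size(-2), key.size(-2)
    # Causal mask, offset so the query block lines up with the end of kv.
    causal = torch.ones(q_len, kv_len, dtype=torch.bool,
                        device=query.device).tril(diagonal=kv_len - q_len)
    bias = torch.zeros(q_len, kv_len, dtype=query.dtype, device=query.device)
    bias.masked_fill_(~causal, float("-inf"))
    if attention_mask is not None:
        # Additive padding mask (hypothetical shape [batch, 1, q_len, kv_len]).
        # Without it, padded tokens in a batched prompt leak into the scores,
        # which is what goes wrong for the first token with batch size > 1.
        bias = bias + attention_mask
    return F.scaled_dot_product_attention(query, key, value, attn_mask=bias)

Note that the two qwen hunks keep passing None for the new argument; only the updated signature is threaded through there.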


@@ -137,9 +137,10 @@ def baichuan_attention_forward_7b(
         import xe_addons
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states)
+                                                   value_states, attention_mask)
         else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states,
+                                               value_states, attention_mask)
     else:
         if use_quantize_kv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,


@@ -398,18 +398,20 @@ def internlm_xcomposser2_attention_forward(
 
     # IPEX-LLM OPT: sdp
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_linear
+        import xe_addons
         if use_quantize_kv:
-            attn_output = xe_linear.sdp_fp8(query_states, key_states, value_states,
+            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
-            attn_output = xe_linear.sdp(query_states, key_states, value_states, attention_mask)
+            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_linear
+        import xe_addons
         if use_quantize_kv:
-            attn_output = xe_linear.sdp_fp8_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
         else:
-            attn_output = xe_linear.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states,
+                                               value_states, attention_mask)
     else:
         if use_quantize_kv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,


@@ -142,9 +142,11 @@ def attention_forward(
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import xe_addons
         if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
         else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states,
+                                               value_states, attention_mask)
     else:
         if isinstance(past_key_value, DynamicFp8Cache):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,


@@ -127,9 +127,9 @@ def qwen_attention_forward(
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import xe_addons
         if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, None)
         else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, None)
     else:
         if q_len > 1:
             causal_mask = torch.tril(
@@ -256,9 +256,9 @@ def qwen_attention_forward_registered(
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import xe_addons
         if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, None)
         else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, None)
     else:
         if q_len > 1:
             causal_mask = registered_causal_mask[


@@ -347,9 +347,11 @@ def qwen2_attention_forward(
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import xe_addons
         if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
         else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states,
+                                               value_states, attention_mask)
     else:
         if isinstance(past_key_value, DynamicFp8Cache):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
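
For context, the path touched in these hunks is only reached on the first forward pass of a padded batch. A usage sketch that exercises it with ipex-llm on an Intel GPU is given below; the model id and prompts are placeholders, and an XPU-enabled ipex-llm install is assumed:

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_id = "Qwen/Qwen1.5-7B-Chat"  # placeholder; any supported model works
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.padding_side = "left"    # left-pad so generation continues from the end
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True,
                                             trust_remote_code=True).to("xpu")

# Two prompts of different lengths -> padding -> a real attention mask is
# passed into the first-token SDP path that this commit fixes.
prompts = ["What is AI?", "Write a short poem about the sea."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("xpu")

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(output, skip_special_tokens=True))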