fix first token sdp with batch (#11153)
This commit is contained in:
parent
3464440839
commit
d307622797
5 changed files with 23 additions and 16 deletions
|
|
@ -137,9 +137,10 @@ def baichuan_attention_forward_7b(
|
||||||
import xe_addons
|
import xe_addons
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||||
value_states)
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
|
|
||||||
|
|
@ -398,18 +398,20 @@ def internlm_xcomposser2_attention_forward(
|
||||||
|
|
||||||
# IPEX-LLM OPT: sdp
|
# IPEX-LLM OPT: sdp
|
||||||
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||||
import xe_linear
|
import xe_addons
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
attn_output = xe_linear.sdp_fp8(query_states, key_states, value_states,
|
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||||
attention_mask)
|
attention_mask)
|
||||||
else:
|
else:
|
||||||
attn_output = xe_linear.sdp(query_states, key_states, value_states, attention_mask)
|
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
|
||||||
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
import xe_linear
|
import xe_addons
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
attn_output = xe_linear.sdp_fp8_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
attn_output = xe_linear.sdp_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
|
|
||||||
|
|
@ -142,9 +142,11 @@ def attention_forward(
|
||||||
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
if isinstance(past_key_value, DynamicFp8Cache):
|
||||||
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
if isinstance(past_key_value, DynamicFp8Cache):
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
|
|
||||||
|
|
@ -127,9 +127,9 @@ def qwen_attention_forward(
|
||||||
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, None)
|
||||||
else:
|
else:
|
||||||
attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, None)
|
||||||
else:
|
else:
|
||||||
if q_len > 1:
|
if q_len > 1:
|
||||||
causal_mask = torch.tril(
|
causal_mask = torch.tril(
|
||||||
|
|
@ -256,9 +256,9 @@ def qwen_attention_forward_registered(
|
||||||
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, None)
|
||||||
else:
|
else:
|
||||||
attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, None)
|
||||||
else:
|
else:
|
||||||
if q_len > 1:
|
if q_len > 1:
|
||||||
causal_mask = registered_causal_mask[
|
causal_mask = registered_causal_mask[
|
||||||
|
|
|
||||||
|
|
@ -347,9 +347,11 @@ def qwen2_attention_forward(
|
||||||
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
if isinstance(past_key_value, DynamicFp8Cache):
|
||||||
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
|
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
if isinstance(past_key_value, DynamicFp8Cache):
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue