Fix several models based on sdp api change (#13075)

* fix baichuan based on sdp api change

* fix several models based on api change

* fix style
Ruonan Wang 2025-04-15 11:13:12 +08:00 committed by GitHub
parent 7826152f5a
commit e08c6bd018
3 changed files with 11 additions and 4 deletions


@@ -326,14 +326,17 @@ def baichuan_attention_forward_13b(
         else:
             attention_mask = attention_mask[None, :, -q_len:, :]
+    head_dim = query_states.shape[-1]
+    scale = 1 / math.sqrt(head_dim)
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
+                                            attention_mask, scale)
         else:
             attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
+                                        attention_mask, scale)
         attn_weights = None
     else:
         if use_quantize_kv:
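The same pattern repeats in each touched file: the updated xe_addons SDP entry points now take an explicit softmax scale instead of deriving it internally, so each caller computes 1/sqrt(head_dim) and passes it as the last argument. A minimal sketch of the new calling convention, assuming only the argument order visible in the hunks (the helper name run_sdp and its parameters are hypothetical, not part of this commit):

import math

def run_sdp(query_states, key_states, value_states, attention_mask,
            use_quantize_kv=False):
    # The caller now supplies the softmax scale explicitly; 1/sqrt(head_dim)
    # is the standard scaled dot-product attention factor.
    head_dim = query_states.shape[-1]
    scale = 1 / math.sqrt(head_dim)
    import xe_addons  # IPEX-LLM XPU extension; only available on supported builds
    if use_quantize_kv:
        # FP8-quantized KV cache path
        return xe_addons.sdp_fp8(query_states, key_states, value_states,
                                 attention_mask, scale)
    # Regular (non-quantized) path
    return xe_addons.sdp(query_states, key_states, value_states,
                         attention_mask, scale)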


@@ -68,7 +68,9 @@ def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
     if use_sdp(query.shape[2], key.shape[2],
                query.shape[-1], query):
         import xe_addons
-        attn_output = xe_addons.sdp(query, key, value, attn_bias)
+        head_dim = query.shape[-1]
+        scale = 1 / math.sqrt(head_dim)
+        attn_output = xe_addons.sdp(query, key, value, attn_bias, scale)
         context_layer = attn_output.view(query.shape)
     else:
         head_dim = query.size(-1)


@@ -164,7 +164,9 @@ def qwen_attention_forward_vl(
     if not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key.shape[2], self.head_dim, query):
         import xe_addons
-        attn_output = xe_addons.sdp(query, key, value, attention_mask)
+        head_dim = query.shape[-1]
+        scale = 1 / math.sqrt(head_dim)
+        attn_output = xe_addons.sdp(query, key, value, attention_mask, scale)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None
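For context, the explicit scale these hunks pass is the standard 1/sqrt(head_dim) factor, so the same math can be expressed with stock PyTorch on a non-XPU path. The following reference is an assumption about the intended semantics of xe_addons.sdp, not code from this commit, and the scale= keyword requires PyTorch 2.1+:

import math
import torch.nn.functional as F

def sdp_reference(query, key, value, attn_mask=None):
    # Same scaling the diff computes explicitly before calling xe_addons.sdp
    scale = 1 / math.sqrt(query.shape[-1])
    return F.scaled_dot_product_attention(query, key, value,
                                          attn_mask=attn_mask,
                                          scale=scale)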