Fix several models based on sdp api change (#13075)
* fix baichuan based on sdp api change
* fix several models based on api change
* fix style
This commit is contained in:
parent
7826152f5a
commit
e08c6bd018
3 changed files with 11 additions and 4 deletions
|
|
@@ -326,14 +326,17 @@ def baichuan_attention_forward_13b(
|
|||
else:
|
||||
attention_mask = attention_mask[None, :, -q_len:, :]
|
||||
|
||||
head_dim = query_states.shape[-1]
|
||||
scale = 1 / math.sqrt(head_dim)
|
||||
|
||||
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||
import xe_addons
|
||||
if use_quantize_kv:
|
||||
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||
attention_mask)
|
||||
attention_mask, scale)
|
||||
else:
|
||||
attn_output = xe_addons.sdp(query_states, key_states, value_states,
|
||||
attention_mask)
|
||||
attention_mask, scale)
|
||||
attn_weights = None
|
||||
else:
|
||||
if use_quantize_kv:
|
||||
|
|
|
|||
|
|
@@ -68,7 +68,9 @@ def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
|
|||
if use_sdp(query.shape[2], key.shape[2],
|
||||
query.shape[-1], query):
|
||||
import xe_addons
|
||||
attn_output = xe_addons.sdp(query, key, value, attn_bias)
|
||||
head_dim = query.shape[-1]
|
||||
scale = 1 / math.sqrt(head_dim)
|
||||
attn_output = xe_addons.sdp(query, key, value, attn_bias, scale)
|
||||
context_layer = attn_output.view(query.shape)
|
||||
else:
|
||||
head_dim = query.size(-1)
|
||||
|
|
|
|||
|
|
@@ -164,7 +164,9 @@ def qwen_attention_forward_vl(
|
|||
if not self.training and not hidden_states.requires_grad and \
|
||||
use_sdp(q_len, key.shape[2], self.head_dim, query):
|
||||
import xe_addons
|
||||
attn_output = xe_addons.sdp(query, key, value, attention_mask)
|
||||
head_dim = query.shape[-1]
|
||||
scale = 1 / math.sqrt(head_dim)
|
||||
attn_output = xe_addons.sdp(query, key, value, attention_mask, scale)
|
||||
attn_output = attn_output.view(query.shape)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_weight = None
|
||||
|
|
|
|||
Loading…
Reference in a new issue