fix qwen import error (#11154)

The Qwen attention path still imported xe_linear for the fused-rope and SDP kernels; those bindings now live in xe_addons, so the stale imports raised an import error.

parent 25b6402315
commit 3464440839

1 changed file with 8 additions and 8 deletions
@@ -222,8 +222,8 @@ def qwen_attention_forward_registered(
     rotary_pos_emb = rotary_pos_emb_list[0]
     if use_fuse_rope:
         rot_dim = rotary_pos_emb[0].size(-1)
-        import xe_linear
-        xe_linear.rotary_half_inplaced(inv_freq, position_ids,
+        import xe_addons
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                        query_states[..., :rot_dim], key_states[..., :rot_dim])
     else:
         rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb]

@@ -254,11 +254,11 @@ def qwen_attention_forward_registered(
                                          value_states.to(dtype=torch.float16),
                                          is_causal=True).to(hidden_states.dtype)
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_linear
+        import xe_addons
         if use_quantize_kv:
-            attn_output = xe_linear.sdp_fp8_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
         else:
-            attn_output = xe_linear.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
     else:
         if q_len > 1:
             causal_mask = registered_causal_mask[

@@ -273,12 +273,12 @@ def qwen_attention_forward_registered(
             attention_mask = None
 
         if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-            import xe_linear
+            import xe_addons
             if use_quantize_kv:
-                attn_output = xe_linear.sdp_fp8(query_states, key_states, value_states,
+                attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                                 attention_mask)
             else:
-                attn_output = xe_linear.sdp(query_states, key_states, value_states,
+                attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                             attention_mask)
         else:
             if use_quantize_kv:
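
For context, a minimal sketch (not part of the patch) of how calling code could tolerate both module names, since older builds ship these kernels in xe_linear while newer ones ship them in xe_addons. The alias xe_kernels and the helper fused_causal_sdp are hypothetical; the kernel entry points are exactly the ones the diff itself calls:

# Hedged sketch, assuming only the entry points visible in this diff.
try:
    import xe_addons as xe_kernels   # newer builds (what this commit switches to)
except ImportError:
    import xe_linear as xe_kernels   # older builds that still expose the kernels here

def fused_causal_sdp(query_states, key_states, value_states, use_quantize_kv):
    # Mirrors the patched branch: FP8 kernel when the KV cache is quantized,
    # plain causal SDP otherwise.
    if use_quantize_kv:
        return xe_kernels.sdp_fp8_causal(query_states, key_states, value_states)
    return xe_kernels.sdp_causal(query_states, key_states, value_states)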