fix qwen import error (#11154)

This commit is contained in:
Yina Chen 2024-05-28 14:50:12 +08:00 committed by GitHub
parent 25b6402315
commit 3464440839
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -222,8 +222,8 @@ def qwen_attention_forward_registered(
     rotary_pos_emb = rotary_pos_emb_list[0]
     if use_fuse_rope:
         rot_dim = rotary_pos_emb[0].size(-1)
-        import xe_linear
-        xe_linear.rotary_half_inplaced(inv_freq, position_ids,
+        import xe_addons
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                        query_states[..., :rot_dim], key_states[..., :rot_dim])
     else:
         rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb]
@@ -254,11 +254,11 @@ def qwen_attention_forward_registered(
                 value_states.to(dtype=torch.float16),
                 is_causal=True).to(hidden_states.dtype)
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_linear
+        import xe_addons
         if use_quantize_kv:
-            attn_output = xe_linear.sdp_fp8_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
         else:
-            attn_output = xe_linear.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
     else:
         if q_len > 1:
             causal_mask = registered_causal_mask[
@@ -273,12 +273,12 @@ def qwen_attention_forward_registered(
             attention_mask = None
         if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-            import xe_linear
+            import xe_addons
             if use_quantize_kv:
-                attn_output = xe_linear.sdp_fp8(query_states, key_states, value_states,
-                                                attention_mask)
+                attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
+                                                attention_mask)
             else:
-                attn_output = xe_linear.sdp(query_states, key_states, value_states,
-                                            attention_mask)
+                attn_output = xe_addons.sdp(query_states, key_states, value_states,
+                                            attention_mask)
         else:
             if use_quantize_kv: