diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py
index 99f1d726..bb22185c 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -222,8 +222,8 @@ def qwen_attention_forward_registered(
     rotary_pos_emb = rotary_pos_emb_list[0]
     if use_fuse_rope:
         rot_dim = rotary_pos_emb[0].size(-1)
-        import xe_linear
-        xe_linear.rotary_half_inplaced(inv_freq, position_ids,
+        import xe_addons
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                        query_states[..., :rot_dim], key_states[..., :rot_dim])
     else:
         rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb]
@@ -254,11 +254,11 @@ def qwen_attention_forward_registered(
                                                      value_states.to(dtype=torch.float16),
                                                      is_causal=True).to(hidden_states.dtype)
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_linear
+        import xe_addons
         if use_quantize_kv:
-            attn_output = xe_linear.sdp_fp8_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
         else:
-            attn_output = xe_linear.sdp_causal(query_states, key_states, value_states)
+            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states)
     else:
         if q_len > 1:
             causal_mask = registered_causal_mask[
@@ -273,12 +273,12 @@ def qwen_attention_forward_registered(
             attention_mask = None
 
         if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-            import xe_linear
+            import xe_addons
             if use_quantize_kv:
-                attn_output = xe_linear.sdp_fp8(query_states, key_states, value_states,
+                attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                                 attention_mask)
             else:
-                attn_output = xe_linear.sdp(query_states, key_states, value_states,
+                attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                             attention_mask)
         else:
             if use_quantize_kv:
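
For context, this patch only renames the native-kernel extension module from `xe_linear` to `xe_addons`; the surrounding kernel-selection logic is unchanged. Below is a minimal sketch of that selection as it appears in the hunks above. The call shapes (`sdp_fp8_causal`/`sdp_causal` taking query/key/value, `sdp_fp8`/`sdp` additionally taking the attention mask) are read directly from this diff; the standalone helper name and its flag arguments are hypothetical and only meant to illustrate the branching, not the library's actual API surface.

```python
def fused_sdp_dispatch(query_states, key_states, value_states,
                       attention_mask, use_quantize_kv, is_causal):
    """Illustrative sketch of the kernel selection visible in the diff above.

    `xe_addons` is IPEX-LLM's native XPU extension; it is imported lazily,
    as in the patched forward, so CPU-only installs never touch it.
    """
    import xe_addons

    if is_causal:
        # Causal path: no explicit mask is passed to the kernel.
        if use_quantize_kv:
            return xe_addons.sdp_fp8_causal(query_states, key_states, value_states)
        return xe_addons.sdp_causal(query_states, key_states, value_states)

    # Masked path: the (possibly None) attention mask is forwarded to the kernel.
    if use_quantize_kv:
        return xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
    return xe_addons.sdp(query_states, key_states, value_states, attention_mask)
```

The FP8 variants correspond to the quantized KV-cache case (`use_quantize_kv`), while the causal variants cover the path guarded by `use_sdp_causal` in the forward; everything else falls through to the masked `sdp`/`sdp_fp8` calls guarded by `use_sdp`.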