diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index c0d642e2..cf0eafce 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -97,6 +97,8 @@ def qwen_attention_forward(
         rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
         if use_fuse_rope:
             cos, sin = rotary_pos_emb
+            cos = cos.to(query.dtype)
+            sin = sin.to(query.dtype)
             query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
         else:
             rotary_pos_emb = (rotary_pos_emb,) * 2
@@ -111,6 +113,8 @@ def qwen_attention_forward(
             rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
             if use_fuse_rope:
                 cos, sin = rotary_pos_emb
+                cos = cos.to(query.dtype)
+                sin = sin.to(query.dtype)
                 query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
                 query_list += [query]
                 key_list += [key]
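
The patch casts the cached `cos`/`sin` rotary tables to `query.dtype` before calling the fused `apply_rotary_pos_emb_cache_freq_xpu` kernel. Below is a minimal sketch of the dtype issue being addressed, assuming the rotary cache is kept in float32 while query/key run in half precision; `rotate_half` and `apply_rope` are generic stand-ins for illustration, not the BigDL fused implementation.

```python
import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # Standard rotary embedding; assumes cos/sin broadcast against q/k.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bsz, seq, heads, head_dim = 1, 4, 2, 8
q = torch.randn(bsz, seq, heads, head_dim, dtype=torch.float16)
k = torch.randn_like(q)
cos = torch.randn(bsz, seq, 1, head_dim)  # float32 cache (assumed)
sin = torch.randn(bsz, seq, 1, head_dim)

# Without the cast, fp16 * fp32 type-promotes the outputs to float32,
# and a fused kernel that expects uniform input dtypes can reject or
# miscompute the call. Aligning dtypes up front avoids both.
cos, sin = cos.to(q.dtype), sin.to(q.dtype)  # the cast this patch adds
q_rot, k_rot = apply_rope(q, k, cos, sin)
assert q_rot.dtype == torch.float16 and k_rot.dtype == torch.float16
```

The same two-line cast is applied in both branches that reach the fused kernel, so the fix covers the single-rope path and the per-split loop alike.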