From da4687c9179a70db1ef16548d2d4d0cac8051414 Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Tue, 23 Jan 2024 15:53:32 +0800 Subject: [PATCH] fix fp16 (#9970) --- python/llm/src/bigdl/llm/transformers/models/qwen.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py index c0d642e2..cf0eafce 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py @@ -97,6 +97,8 @@ def qwen_attention_forward( rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] if use_fuse_rope: cos, sin = rotary_pos_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen") else: rotary_pos_emb = (rotary_pos_emb,) * 2 @@ -111,6 +113,8 @@ def qwen_attention_forward( rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] if use_fuse_rope: cos, sin = rotary_pos_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen") query_list += [query] key_list += [key]