This commit is contained in:
Xin Qiu 2024-01-23 15:53:32 +08:00 committed by GitHub
parent 052962dfa5
commit da4687c917

View file

@ -97,6 +97,8 @@ def qwen_attention_forward(
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
if use_fuse_rope:
cos, sin = rotary_pos_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
else:
rotary_pos_emb = (rotary_pos_emb,) * 2
@ -111,6 +113,8 @@ def qwen_attention_forward(
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
if use_fuse_rope:
cos, sin = rotary_pos_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
query_list += [query]
key_list += [key]