fix fp16 (#9970)
This commit is contained in:
parent
052962dfa5
commit
da4687c917
1 changed files with 4 additions and 0 deletions
|
|
@ -97,6 +97,8 @@ def qwen_attention_forward(
|
|||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
||||
if use_fuse_rope:
|
||||
cos, sin = rotary_pos_emb
|
||||
cos = cos.to(query.dtype)
|
||||
sin = sin.to(query.dtype)
|
||||
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
|
||||
else:
|
||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||
|
|
@ -111,6 +113,8 @@ def qwen_attention_forward(
|
|||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
||||
if use_fuse_rope:
|
||||
cos, sin = rotary_pos_emb
|
||||
cos = cos.to(query.dtype)
|
||||
sin = sin.to(query.dtype)
|
||||
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
|
||||
query_list += [query]
|
||||
key_list += [key]
|
||||
|
|
|
|||
Loading…
Reference in a new issue