Fix llama when rope scaling is not None (#9086)
* Fix llama when rope scaling is not None
* fix style
* fix style
This commit is contained in:
parent
fcb1c618a0
commit
36dd4afd61
1 changed file with 5 additions and 1 deletion
@@ -118,7 +118,11 @@ def llama_attention_forward_4_31(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
 
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+    use_fuse_rope = query_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
+
+    if use_fuse_rope:
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
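For context, the change replaces one long condition with a single use_fuse_rope flag and additionally disables the fused rotary-embedding path whenever rope scaling is configured. The sketch below reproduces that gating logic in isolation; LlamaLikeConfig and should_use_fused_rope are hypothetical stand-ins introduced for illustration, not helpers from the repository, and only the three conditions mirror the actual diff.

# Minimal, self-contained sketch of the gating introduced by this commit.
# LlamaLikeConfig and should_use_fused_rope are illustrative assumptions,
# not helpers from the repository; only the three conditions mirror the diff.
from dataclasses import dataclass
from typing import Optional

@dataclass
class LlamaLikeConfig:
    # Hugging Face style rope_scaling, e.g. {"type": "linear", "factor": 2.0}
    rope_scaling: Optional[dict] = None

def should_use_fused_rope(device_type: str,
                          training: bool,
                          requires_grad: bool,
                          config: LlamaLikeConfig) -> bool:
    # The fused rotary embedding is only taken on XPU, outside of training
    # with gradients, and only when no rope scaling is configured, presumably
    # because the fused kernel does not account for scaled positions.
    use_fuse_rope = device_type == "xpu"
    use_fuse_rope = use_fuse_rope and not (training and requires_grad)
    use_fuse_rope = use_fuse_rope and config.rope_scaling is None
    return use_fuse_rope

# A model with rope_scaling set now falls back to the non-fused path:
print(should_use_fused_rope("xpu", False, False, LlamaLikeConfig()))        # True
print(should_use_fused_rope("xpu", False, False,
                            LlamaLikeConfig(rope_scaling={"type": "linear",
                                                          "factor": 2.0})))  # False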