Fix llama when rope scaling is not None (#9086)
* Fix llama when rope scaling is not None
* fix style
* fix style
parent fcb1c618a0
commit 36dd4afd61
1 changed file with 5 additions and 1 deletion
@@ -118,7 +118,11 @@ def llama_attention_forward_4_31(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
 
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+    use_fuse_rope = query_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
+
+    if use_fuse_rope:
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
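For context, the effect of the change is that the fused XPU rotary-embedding path is now skipped whenever `rope_scaling` is configured (e.g. linear or dynamic NTK scaling), presumably because the fused kernel only handles the default RoPE. Below is a minimal, standalone sketch of the new predicate; `should_use_fused_rope`, the `SimpleNamespace` config, and the demo tensor are illustrative stand-ins, not part of the library's API.

```python
import torch
from types import SimpleNamespace


def should_use_fused_rope(query_states: torch.Tensor, training: bool, config) -> bool:
    # Fused rotary embedding is only taken on Intel XPU devices.
    use_fuse_rope = query_states.device.type == "xpu"
    # Skip the fused path when a backward pass may be needed.
    use_fuse_rope = use_fuse_rope and not (training and query_states.requires_grad)
    # The check added by this commit: any rope_scaling configuration
    # falls back to the unfused path.
    use_fuse_rope = use_fuse_rope and getattr(config, "rope_scaling", None) is None
    return use_fuse_rope


# Example: a model configured with linear rope scaling never takes the fused path.
cfg = SimpleNamespace(rope_scaling={"type": "linear", "factor": 2.0})
q = torch.randn(1, 32, 8, 128)  # CPU tensor in this demo
print(should_use_fused_rope(q, training=False, config=cfg))  # -> False
```

Splitting the condition across several `use_fuse_rope = use_fuse_rope and ...` lines keeps each requirement on its own line, which also keeps the diff small when further conditions are added later.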