Fix llama when rope scaling is not None (#9086)

* Fix llama when rope scaling is not None

* fix style

* fix style
Yang Wang 2023-10-07 04:27:37 +08:00 committed by GitHub
parent fcb1c618a0
commit 36dd4afd61

@@ -118,7 +118,11 @@ def llama_attention_forward_4_31(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+    use_fuse_rope = query_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
+    if use_fuse_rope:
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
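
The change gates the fused XPU rotary-embedding path behind an extra check so that models configured with rope_scaling fall back to the regular (non-fused) path. Below is a minimal, self-contained sketch of that gating condition; the function name should_use_fused_rope and the standalone signature are illustrative only, since the real check is written inline inside llama_attention_forward_4_31.

import torch


def should_use_fused_rope(query_states: torch.Tensor,
                          training: bool,
                          rope_scaling) -> bool:
    """Decide whether the fused XPU rotary-embedding kernel can be used."""
    # Fused kernel is only available on XPU devices.
    use_fuse_rope = query_states.device.type == "xpu"
    # Skip the fused path when gradients are needed (training forward pass).
    use_fuse_rope = use_fuse_rope and not (training and query_states.requires_grad)
    # Skip the fused path when rope_scaling is configured -- this is the
    # condition added by this commit.
    use_fuse_rope = use_fuse_rope and rope_scaling is None
    return use_fuse_rope


# Example: a CPU tensor with rope scaling configured never takes the fused path.
q = torch.zeros(1, 32, 16, 128)
print(should_use_fused_rope(q, training=False,
                            rope_scaling={"type": "linear", "factor": 2.0}))  # False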