diff --git a/python/llm/src/bigdl/llm/transformers/models/gptneox.py b/python/llm/src/bigdl/llm/transformers/models/gptneox.py
index 1f70491f..2a47d6e9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptneox.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptneox.py
@@ -35,6 +35,7 @@ import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
@@ -80,8 +81,19 @@ def gptneox_attention_forward(
     seq_len = key.shape[-2]
     if has_layer_past:
         seq_len += layer_past[0].shape[-2]
-    cos, sin = self.rotary_emb(value, seq_len=seq_len)
-    query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
+
+    use_fuse_rope = query.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)
+
+    if use_fuse_rope:
+        query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
+                                                       key_rot,
+                                                       position_ids,
+                                                       "gpt_neox")
+    else:
+        cos, sin = self.rotary_emb(value, seq_len=seq_len)
+        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
+
     query = torch.cat((query, query_pass), dim=-1)
     key = torch.cat((key, key_pass), dim=-1)
 
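
Note: the fused path (apply_rotary_pos_emb_no_cache_xpu) is taken only when the tensors live on an XPU device and no gradient is needed; otherwise the original cos/sin-cached path runs unchanged. The sketch below is illustrative only and not part of the patch: it reimplements the standard, non-fused rotary application that the else branch relies on, so it is clear what the fused XPU kernel computes in a single call. The helper names (rotate_half, apply_rotary_reference) and the assumed cos/sin table shape [seq_len, rotary_dim] are local to this example, not bigdl.llm APIs.

# Illustrative sketch only -- not the bigdl.llm implementation.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(query: torch.Tensor,
                           key: torch.Tensor,
                           cos: torch.Tensor,
                           sin: torch.Tensor,
                           position_ids: torch.Tensor):
    # Gather the cos/sin rows for the requested positions and broadcast over heads.
    cos = cos[position_ids].unsqueeze(1)   # [batch, 1, seq, rotary_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (query * cos) + (rotate_half(query) * sin)
    k_embed = (key * cos) + (rotate_half(key) * sin)
    return q_embed, k_embed


if __name__ == "__main__":
    # Tiny shape check: batch=1, heads=2, seq=4, rotary_dim=8.
    batch, heads, seq, dim = 1, 2, 4, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(seq).float(), inv_freq)   # [seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                    # [seq, dim]
    q = torch.randn(batch, heads, seq, dim)
    k = torch.randn(batch, heads, seq, dim)
    pos = torch.arange(seq).unsqueeze(0)                       # [batch, seq]
    q_rot, k_rot = apply_rotary_reference(q, k, emb.cos(), emb.sin(), pos)
    print(q_rot.shape, k_rot.shape)                            # both [1, 2, 4, 8]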