From 49e1381c7f5ee9445d9ec018fbf4e594a0ce8281 Mon Sep 17 00:00:00 2001
From: Jiao Wang
Date: Sun, 15 Oct 2023 21:51:45 -0700
Subject: [PATCH] update rope (#9155)

---
 .../src/bigdl/llm/transformers/models/gptneox.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/gptneox.py b/python/llm/src/bigdl/llm/transformers/models/gptneox.py
index 1f70491f..2a47d6e9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptneox.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptneox.py
@@ -35,6 +35,7 @@ import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
@@ -80,8 +81,19 @@ def gptneox_attention_forward(
     seq_len = key.shape[-2]
     if has_layer_past:
         seq_len += layer_past[0].shape[-2]
-    cos, sin = self.rotary_emb(value, seq_len=seq_len)
-    query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
+
+    use_fuse_rope = query.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)
+
+    if use_fuse_rope:
+        query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
+                                                       key_rot,
+                                                       position_ids,
+                                                       "gpt_neox")
+    else:
+        cos, sin = self.rotary_emb(value, seq_len=seq_len)
+        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
+
     query = torch.cat((query, query_pass), dim=-1)
     key = torch.cat((key, key_pass), dim=-1)
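
A minimal sketch of the dispatch condition this patch introduces, for illustration only. It assumes nothing beyond stock PyTorch; _should_use_fused_rope is a hypothetical stand-in, not part of the bigdl-llm API. The fused XPU rotary path is taken only for inference-style calls on XPU tensors; everything else falls back to the existing cos/sin code path.

import torch

def _should_use_fused_rope(query: torch.Tensor, training: bool) -> bool:
    # Mirrors the gate in the patch: fuse only on an XPU device, and only when
    # this is not a training pass that needs gradients through the rotary op.
    return query.device.type == "xpu" and not (training and query.requires_grad)

if __name__ == "__main__":
    q = torch.randn(1, 8, 16, 64)                          # CPU tensor in this example
    print(_should_use_fused_rope(q, training=False))       # False unless q lives on an XPU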