update rope (#9155)

Author: Jiao Wang, 2023-10-15 21:51:45 -07:00 (committed by GitHub)
Parent: b192a8032c
Commit: 49e1381c7f

@@ -35,6 +35,7 @@ import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -80,8 +81,19 @@ def gptneox_attention_forward(
     seq_len = key.shape[-2]
     if has_layer_past:
         seq_len += layer_past[0].shape[-2]
-    cos, sin = self.rotary_emb(value, seq_len=seq_len)
-    query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
+    use_fuse_rope = query.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)
+    if use_fuse_rope:
+        query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
+                                                       key_rot,
+                                                       position_ids,
+                                                       "gpt_neox")
+    else:
+        cos, sin = self.rotary_emb(value, seq_len=seq_len)
+        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
     query = torch.cat((query, query_pass), dim=-1)
     key = torch.cat((key, key_pass), dim=-1)
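
For context, the else branch above is the standard GPT-NeoX rotary position embedding: a cos/sin cache is built from the sequence length, then the rotate-half formulation is applied to the query and key. The sketch below is a minimal, self-contained illustration of that math; build_rope_cache and the dispatch flag mirror the diff but are illustrative assumptions, not BigDL's actual implementation (which also takes a model-family string such as "gpt_neox").

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
        # Gather the cos/sin rows for each position, broadcast over heads.
        cos = cos[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]
        sin = sin[position_ids].unsqueeze(1)
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

    def build_rope_cache(seq_len, head_dim, base=10000.0):
        # Hypothetical stand-in for what self.rotary_emb(value, seq_len=...)
        # precomputes: one rotation angle per (position, frequency) pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, head_dim]
        return emb.cos(), emb.sin()

    batch, heads, seq_len, head_dim = 1, 4, 8, 16
    query = torch.randn(batch, heads, seq_len, head_dim)
    key = torch.randn(batch, heads, seq_len, head_dim)
    position_ids = torch.arange(seq_len).unsqueeze(0)  # [batch, seq_len]

    # Dispatch mirroring the diff: the fused XPU kernel (not shown here)
    # skips the cos/sin cache entirely; everything else takes the slow path.
    use_fuse_rope = query.device.type == "xpu"
    if not use_fuse_rope:
        cos, sin = build_rope_cache(seq_len, head_dim)
        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)

The training guard in the diff (not use_fuse_rope when self.training and query.requires_grad) suggests the fused apply_rotary_pos_emb_no_cache_xpu kernel is an inference-only op, so gradient-tracking forward passes fall back to the cached cos/sin path.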