update rope (#9155)
parent b192a8032c
commit 49e1381c7f

1 changed file with 14 additions and 2 deletions
```diff
@@ -35,6 +35,7 @@ import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
```
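For reference, the pre-existing `apply_rotary_pos_emb` import covers the unfused path: rotary embeddings applied with the standard rotate-half formulation against a precomputed cos/sin cache. Below is a minimal plain-PyTorch sketch of that computation, modeled on the upstream GPT-NeoX code; the exact cache layout and the gather by `position_ids` are assumptions, not BigDL's verbatim implementation.

```python
import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_ref(q, k, cos, sin, position_ids):
    # Cached path (sketch): cos/sin are precomputed tables, assumed here to
    # have shape [1, 1, max_seq_len, rotary_dim]; gather the rows for the
    # current positions and broadcast them over the head dimension.
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bsz, 1, seq, dim]
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
```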
```diff
@@ -80,8 +81,19 @@ def gptneox_attention_forward(
     seq_len = key.shape[-2]
     if has_layer_past:
         seq_len += layer_past[0].shape[-2]
-    cos, sin = self.rotary_emb(value, seq_len=seq_len)
-    query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
+
+    use_fuse_rope = query.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)
+
+    if use_fuse_rope:
+        query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
+                                                       key_rot,
+                                                       position_ids,
+                                                       "gpt_neox")
+    else:
+        cos, sin = self.rotary_emb(value, seq_len=seq_len)
+        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")

     query = torch.cat((query, query_pass), dim=-1)
     key = torch.cat((key, key_pass), dim=-1)
```
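The new branch only fires on XPU devices and outside training, since a fused inference kernel presumably does not carry the autograd support that `query.requires_grad` would demand. Conceptually, the "no cache" variant derives cos/sin on the fly from `position_ids` instead of indexing a precomputed table, which is what lets the whole rotation fuse into a single op. A rough plain-PyTorch sketch of those semantics follows; the shapes and the `base=10000.0` default are assumptions, and the real `apply_rotary_pos_emb_no_cache_xpu` is a fused device kernel, not this Python.

```python
import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) over the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_no_cache_ref(q, k, position_ids, base=10000.0):
    # Sketch only: compute the rotary angles directly from position_ids
    # rather than reading a cached cos/sin table.
    # q, k: [bsz, num_heads, seq_len, rotary_dim]; position_ids: [bsz, seq_len]
    dim = q.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=q.device, dtype=torch.float32) / dim))
    freqs = position_ids.to(torch.float32).unsqueeze(-1) * inv_freq   # [bsz, seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                           # [bsz, seq, dim]
    cos = emb.cos().unsqueeze(1).to(q.dtype)                          # broadcast over heads
    sin = emb.sin().unsqueeze(1).to(q.dtype)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```

During decoding this avoids materializing and re-reading the cos/sin cache at every step, which is presumably where the fused XPU path gains over the generic one.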