update rope (#9155)
parent b192a8032c
commit 49e1381c7f
1 changed file with 14 additions and 2 deletions
@@ -35,6 +35,7 @@ import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -80,8 +81,19 @@ def gptneox_attention_forward(
     seq_len = key.shape[-2]
     if has_layer_past:
         seq_len += layer_past[0].shape[-2]
-    cos, sin = self.rotary_emb(value, seq_len=seq_len)
-    query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
+    use_fuse_rope = query.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)
+
+    if use_fuse_rope:
+        query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
+                                                       key_rot,
+                                                       position_ids,
+                                                       "gpt_neox")
+    else:
+        cos, sin = self.rotary_emb(value, seq_len=seq_len)
+        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")

     query = torch.cat((query, query_pass), dim=-1)
     key = torch.cat((key, key_pass), dim=-1)
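For reference, below is a minimal sketch of the GPT-NeoX-style rotation that both branches above compute: the else branch materializes a cos/sin cache via self.rotary_emb and applies the rotate-half formulation, while apply_rotary_pos_emb_no_cache_xpu presumably performs the same math in a fused XPU kernel without building the cache first. The helper names here (build_cos_sin_cache, rotate_half, apply_rope) are illustrative assumptions, not the bigdl.llm implementations.

# Hedged sketch: reference math for GPT-NeoX-style rotary position embeddings.
# Helper names are illustrative assumptions, not the actual bigdl.llm utilities.
import torch


def build_cos_sin_cache(head_dim: int, seq_len: int, base: float = 10000.0):
    # Frequencies over half the rotary dimension, as in GPT-NeoX.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)               # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)        # [seq_len, head_dim]
    return emb.cos(), emb.sin()


def rotate_half(x):
    # GPT-NeoX style: rotate the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q_rot, k_rot, cos, sin, position_ids):
    # Gather cos/sin rows for the requested positions and broadcast over heads.
    cos = cos[position_ids].unsqueeze(1)           # [batch, 1, seq, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return q_embed, k_embed


if __name__ == "__main__":
    batch, heads, seq, rot_dim = 1, 8, 16, 64
    q = torch.randn(batch, heads, seq, rot_dim)
    k = torch.randn(batch, heads, seq, rot_dim)
    position_ids = torch.arange(seq).unsqueeze(0)  # [batch, seq]
    cos, sin = build_cos_sin_cache(rot_dim, seq)
    q_rope, k_rope = apply_rope(q, k, cos, sin, position_ids)
    print(q_rope.shape, k_rope.shape)              # torch.Size([1, 8, 16, 64]) each

Note that the use_fuse_rope guard in the hunk falls back to the cached cos/sin path whenever self.training and query.requires_grad are both set, so the fused XPU kernel is only used for inference-style forward passes.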