fix baichuan (#11606)
This commit is contained in:
parent bfcdc35b04
commit e5c0058c0e
1 changed file with 5 additions and 4 deletions
@@ -37,6 +37,7 @@ from torch.nn import functional as F
 import importlib
 from typing import Optional, Tuple
 from ipex_llm.transformers.npu_models.common import merge_linear
+from ipex_llm.transformers.models.utils import update_past_key_value


 def merge_mlp(module: torch.nn.Module):
@@ -85,10 +86,10 @@ def baichuan_attention_fwd(
                                             cos, sin, position_ids)
     # [bsz, nh, t, hd]

-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, False, "cpu"
+    )

     past_key_value = (key_states, value_states) if use_cache else None
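For reference, below is a minimal, self-contained sketch of the behaviour the patched attention path now delegates to. It assumes update_past_key_value simply concatenates the cached keys/values with the new ones along the sequence dimension when the quantized KV cache is disabled; the parameter names use_quantize_kv and device are guesses for the positional False and "cpu" arguments seen in the diff, and the real implementation in ipex_llm.transformers.models.utils may differ.

import torch
from typing import Optional, Tuple


def update_past_key_value(
    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    kv_seq_len: int,
    use_quantize_kv: bool,
    device: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # use_quantize_kv and device are accepted but unused in this sketch; the real
    # helper presumably uses them for quantized-cache and device-specific handling.
    if past_key_value is None:
        # No cache yet: return the freshly computed key/value states unchanged.
        return key_states, value_states
    # Cache present: append the new states along the sequence dimension (dim=2),
    # which is what the removed inline torch.cat calls did.
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states


# Usage mirroring the patched baichuan_attention_fwd; shapes are [bsz, n_heads, seq, head_dim].
bsz, n_heads, head_dim = 1, 2, 4
past = (torch.zeros(bsz, n_heads, 3, head_dim),
        torch.zeros(bsz, n_heads, 3, head_dim))
k_new = torch.randn(bsz, n_heads, 1, head_dim)
v_new = torch.randn(bsz, n_heads, 1, head_dim)
key_states, value_states = update_past_key_value(past, k_new, v_new, 4, False, "cpu")
assert key_states.shape == (bsz, n_heads, 4, head_dim)
past_key_value = (key_states, value_states)  # stored only when use_cache is True

Net effect of the commit: the hand-written cache concatenation in baichuan_attention_fwd is replaced by the shared update_past_key_value helper, keeping the Baichuan path consistent with the other model implementations.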