fix baichuan (#11606)
This commit is contained in:
parent bfcdc35b04
commit e5c0058c0e
1 changed file with 5 additions and 4 deletions
@@ -37,6 +37,7 @@ from torch.nn import functional as F
 import importlib
 from typing import Optional, Tuple
 from ipex_llm.transformers.npu_models.common import merge_linear
+from ipex_llm.transformers.models.utils import update_past_key_value


 def merge_mlp(module: torch.nn.Module):
@@ -85,10 +86,10 @@ def baichuan_attention_fwd(
         cos, sin, position_ids)
     # [bsz, nh, t, hd]

-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, False, "cpu"
+    )

     past_key_value = (key_states, value_states) if use_cache else None

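In effect, the commit swaps the hand-rolled torch.cat cache growth (the removed lines) for the shared update_past_key_value helper, called positionally with what read as a quantize flag (False) and a device ("cpu"); those argument roles are my reading of the call, not confirmed by this diff. A minimal standalone sketch of the behavior being replaced, using toy shapes and a hypothetical naive_update_kv name:

import torch

def naive_update_kv(past_key_value, key_states, value_states):
    # Equivalent of the removed branch: on every decode step, grow the
    # KV cache by concatenating along the sequence dimension (dim=2).
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states

# Toy tensors shaped [bsz, nh, t, hd], matching the comment in the diff.
bsz, nh, hd = 1, 2, 4
cache = (torch.randn(bsz, nh, 3, hd), torch.randn(bsz, nh, 3, hd))  # 3 cached tokens
new_k = torch.randn(bsz, nh, 1, hd)                                 # 1 new token
new_v = torch.randn(bsz, nh, 1, hd)

k, v = naive_update_kv(cache, new_k, new_v)
print(k.shape, v.shape)  # torch.Size([1, 2, 4, 4]) twice: cache grew 3 -> 4

Plain concatenation reallocates the cache on every step; routing it through the shared update_past_key_value instead keeps Baichuan on the same cache-update code path as the other models in the library, which I assume is the point of the fix.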