fix baichuan (#11606)

This commit is contained in:
Zhao Changmin 2024-07-18 09:43:36 +08:00 committed by GitHub
parent bfcdc35b04
commit e5c0058c0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -37,6 +37,7 @@ from torch.nn import functional as F
import importlib
from typing import Optional, Tuple
from ipex_llm.transformers.npu_models.common import merge_linear
from ipex_llm.transformers.models.utils import update_past_key_value
def merge_mlp(module: torch.nn.Module):
@@ -85,10 +86,10 @@ def baichuan_attention_fwd(
cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, False, "cpu"
)
past_key_value = (key_states, value_states) if use_cache else None