diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan.py
index f76c6798..7a2fbf02 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan.py
@@ -37,6 +37,7 @@ from torch.nn import functional as F
 import importlib
 from typing import Optional, Tuple
 from ipex_llm.transformers.npu_models.common import merge_linear
+from ipex_llm.transformers.models.utils import update_past_key_value
 
 
 def merge_mlp(module: torch.nn.Module):
@@ -85,10 +86,10 @@ def baichuan_attention_fwd(
                                                     cos, sin, position_ids)  # [bsz, nh, t, hd]
 
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, False, "cpu"
+    )
 
     past_key_value = (key_states, value_states) if use_cache else None
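
Note: this change replaces the hand-rolled torch.cat KV-cache concatenation with the shared update_past_key_value helper from ipex_llm.transformers.models.utils, called here with the last two positional arguments False and "cpu" (presumably a quantize-KV flag and a device argument). Below is a minimal sketch of the non-quantized update path the call relies on, reconstructed from the lines this diff removes; the function and argument names are assumptions for illustration, and the real helper may manage its cache buffers differently.

import torch
from typing import Optional, Tuple

def update_past_key_value_sketch(
    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    kv_seq_len: int,
    use_quantize_kv: bool,
    device: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Sketch of the non-quantized path only (use_quantize_kv=False):
    # append the new k/v to the cached k/v along the sequence dimension
    # (dim=2 for [bsz, nh, seq, hd]), matching the removed torch.cat calls.
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states

The caller then rebuilds the cache exactly as before: past_key_value = (key_states, value_states) if use_cache else None.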