diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py
index 847f43b9..9a9618bf 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -114,6 +114,11 @@ def mistral_attention_forward(
                                                        dtype=cache_k.dtype,
                                                        device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+
         key_states, value_states = append_kv_cache(cache_k, cache_v,
                                                    key_states, value_states)
     elif use_cache:
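
For context on what the hunk does: the surrounding code grows the KV cache by allocating a fresh buffer with spare capacity along the sequence dimension, and the fresh buffer from the allocation above is uninitialized, so without the added copy the keys/values of earlier tokens would be lost when `cache_k`/`cache_v` are rebound to it. The standalone sketch below illustrates that pattern under stated assumptions; `extend_kv` and `append_kv` are hypothetical stand-ins modeled on the helpers named in the diff, not BigDL's actual implementation.

    import torch

    # Hypothetical stand-ins for the extend_kv_cache / append_kv_cache helpers,
    # assuming a pre-allocate-and-narrow caching scheme.

    def extend_kv(cache: torch.Tensor, reserve: int) -> torch.Tensor:
        bsz, heads, seq, hd = cache.shape
        # New, UNINITIALIZED storage with spare room along the sequence dim (2);
        # return a view shaped like the current cache.
        big = torch.empty(bsz, heads, seq + reserve, hd,
                          dtype=cache.dtype, device=cache.device)
        return big.narrow(2, 0, seq)

    def append_kv(cache: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        seq, add = cache.size(2), new.size(2)
        # Re-widen the view into the spare storage and write new tokens in place.
        out = cache.as_strided((cache.size(0), cache.size(1), seq + add,
                                cache.size(3)), cache.stride())
        out[:, :, seq:] = new
        return out

    cache_k = torch.randn(1, 8, 16, 64)           # existing cache: 16 tokens
    new_cache_k = extend_kv(cache_k, reserve=256)

    # The fix in the hunk above: copy the old contents into the fresh buffer
    # before rebinding; otherwise positions 0..15 would hold garbage memory.
    new_cache_k[:] = cache_k
    cache_k = new_cache_k

    cache_k = append_kv(cache_k, torch.randn(1, 8, 1, 64))  # append 1 token
    assert cache_k.size(2) == 17

Growing in blocks and widening views keeps appends allocation-free until the reserve is exhausted; when a reallocation does happen, the explicit copy added by this patch is the price of moving to the new buffer.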