diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 674db7e0..8499843d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -130,12 +130,11 @@ def llama_attention_forward_4_31(
                 new_cache_key = torch.empty(bsz, self.num_heads,
                                             kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
                 new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
-                self.kv_cache[0] = new_cache_key
 
                 new_cache_value = torch.empty(bsz, self.num_heads,
                                               kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
                 new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
-                self.kv_cache[1] = new_cache_value
+                self.kv_cache = (new_cache_key, new_cache_value)
                 self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 
             self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states
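
For context (not part of the patch): the old code assigned the re-allocated tensors into self.kv_cache element by element, while the new code rebuilds the pair in a single tuple assignment. A minimal sketch of why that matters, assuming self.kv_cache is held as a tuple rather than a list (the variable names and tensor shapes below are illustrative, not taken from the file):

    import torch

    # Hypothetical stand-in for self.kv_cache: a (key_cache, value_cache) pair.
    kv_cache = (torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8))

    new_cache_key = torch.empty(1, 2, 8, 8)
    try:
        kv_cache[0] = new_cache_key      # old pattern: item assignment on a tuple
    except TypeError as e:
        print(e)                         # "'tuple' object does not support item assignment"

    # new pattern: build both tensors first, then replace the whole pair at once
    new_cache_value = torch.empty(1, 2, 8, 8)
    kv_cache = (new_cache_key, new_cache_value)

Replacing the whole tuple also keeps the key and value buffers in sync: the cache is never left in a state where one element has the new capacity and the other still has the old one.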