Fix llama kv cache bug (#8674)

Yang Wang 2023-08-04 08:54:55 +08:00 committed by GitHub
parent 59903ea668
commit 3407f87075


@@ -130,12 +130,11 @@ def llama_attention_forward_4_31(
         new_cache_key = torch.empty(bsz, self.num_heads,
                                     kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
         new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
-        self.kv_cache[0] = new_cache_key
         new_cache_value = torch.empty(bsz, self.num_heads,
                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
         new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
-        self.kv_cache[1] = new_cache_value
+        self.kv_cache = (new_cache_key, new_cache_value)
         self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states
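
Why the one-line change matters, as a minimal sketch (assuming self.kv_cache is held as a Python tuple of (key_cache, value_cache) tensors, which the new line suggests but the hunk does not show): item assignment on a tuple raises TypeError, so the enlarged buffers cannot be written back index by index; the pair has to be rebuilt in one step. The sizes below are illustrative only, not taken from the model.

import torch

# Illustrative sizes only; the real values come from the model config.
bsz, num_heads, head_dim = 1, 2, 4
kv_seq_len, KV_CACHE_ALLOC_BLOCK_LENGTH = 3, 8

# Assumption: the cache is a tuple of (key_cache, value_cache) tensors.
kv_cache = (torch.zeros(bsz, num_heads, kv_seq_len - 1, head_dim),
            torch.zeros(bsz, num_heads, kv_seq_len - 1, head_dim))

new_cache_key = torch.empty(bsz, num_heads,
                            kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, head_dim)
new_cache_key[:, :, :kv_seq_len-1, :] = kv_cache[0]

try:
    kv_cache[0] = new_cache_key      # the pattern the old code used
except TypeError as e:
    print(e)                         # 'tuple' object does not support item assignment

new_cache_value = torch.empty(bsz, num_heads,
                              kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, head_dim)
new_cache_value[:, :, :kv_seq_len-1, :] = kv_cache[1]

# The fix: build both enlarged buffers first, then rebind the tuple once.
kv_cache = (new_cache_key, new_cache_value)

In the patched forward the same rebinding happens on self.kv_cache, which is why the two per-index writes collapse into the single tuple assignment in the hunk above.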