Fix llama kv cache bug (#8674)
parent 59903ea668
commit 3407f87075
1 changed file with 1 addition and 2 deletions
@@ -130,12 +130,11 @@ def llama_attention_forward_4_31(
     new_cache_key = torch.empty(bsz, self.num_heads,
                                 kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
     new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
-    self.kv_cache[0] = new_cache_key
 
     new_cache_value = torch.empty(bsz, self.num_heads,
                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
     new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
-    self.kv_cache[1] = new_cache_value
+    self.kv_cache = (new_cache_key, new_cache_value)
     self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 
     self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states
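Sketching the failure the diff appears to address (an inference from the added line, not stated in the commit message): if `self.kv_cache` is held as a plain Python tuple of key/value tensors, the removed per-element assignments raise `TypeError: 'tuple' object does not support item assignment`, so the grown buffers have to be rebound as a fresh tuple instead. A minimal standalone illustration with made-up shapes and hypothetical values for `bsz`, `num_heads`, `head_dim`, and `KV_CACHE_ALLOC_BLOCK_LENGTH`:

```python
import torch

# Hypothetical stand-in shapes; names mirror the diff, not the real module state.
bsz, num_heads, head_dim = 1, 32, 128
kv_seq_len = 16
KV_CACHE_ALLOC_BLOCK_LENGTH = 256

# A (key, value) cache held as a plain tuple, as the added line implies.
kv_cache = (
    torch.zeros(bsz, num_heads, kv_seq_len, head_dim),
    torch.zeros(bsz, num_heads, kv_seq_len, head_dim),
)

# Grow the cache: allocate larger buffers and copy the previously cached steps over.
new_cache_key = torch.empty(bsz, num_heads,
                            kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, head_dim)
new_cache_key[:, :, :kv_seq_len - 1, :] = kv_cache[0][:, :, :kv_seq_len - 1, :]

new_cache_value = torch.empty(bsz, num_heads,
                              kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, head_dim)
new_cache_value[:, :, :kv_seq_len - 1, :] = kv_cache[1][:, :, :kv_seq_len - 1, :]

# What the removed lines did: item assignment on a tuple raises TypeError.
try:
    kv_cache[0] = new_cache_key
except TypeError as err:
    print("old code fails:", err)

# What the added line does: rebind the whole tuple to the grown buffers.
kv_cache = (new_cache_key, new_cache_value)
print(kv_cache[0].shape)  # torch.Size([1, 32, 272, 128])
```

Rebinding the tuple also keeps the key and value buffers in step, so the later slice write `self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states` lands in the newly allocated storage.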