Fix llama kv cache bug (#8674)
This commit is contained in:
parent
59903ea668
commit
3407f87075
1 changed file with 1 addition and 2 deletions
|
|
@@ -130,12 +130,11 @@ def llama_attention_forward_4_31(
|
||||||
new_cache_key = torch.empty(bsz, self.num_heads,
|
new_cache_key = torch.empty(bsz, self.num_heads,
|
||||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
|
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
|
||||||
new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
|
new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
|
||||||
self.kv_cache[0] = new_cache_key
|
|
||||||
|
|
||||||
new_cache_value = torch.empty(bsz, self.num_heads,
|
new_cache_value = torch.empty(bsz, self.num_heads,
|
||||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
|
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
|
||||||
new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
|
new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
|
||||||
self.kv_cache[1] = new_cache_value
|
self.kv_cache = (new_cache_key, new_cache_value)
|
||||||
self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
|
|
||||||
self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states
|
self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue