Fix llama kv cache bug (#8674)
This commit is contained in:

parent 59903ea668
commit 3407f87075

1 changed file with 1 addition and 2 deletions
@@ -130,12 +130,11 @@ def llama_attention_forward_4_31(
             new_cache_key = torch.empty(bsz, self.num_heads,
                                         kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
             new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
-            self.kv_cache[0] = new_cache_key
 
             new_cache_value = torch.empty(bsz, self.num_heads,
                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
             new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
-            self.kv_cache[1] = new_cache_value
+            self.kv_cache = (new_cache_key, new_cache_value)
             self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 
         self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states
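Reading of the hunk (consistent with the "1 addition and 2 deletions" stat): the old code updated the cache with per-index assignments (`self.kv_cache[0] = new_cache_key`, `self.kv_cache[1] = new_cache_value`), while the added line rebinds `self.kv_cache` to a fresh tuple. Since a Python tuple does not support item assignment, rebinding the pair as a whole is presumably what the fix is about. Below is a minimal, hedged sketch of the block-allocated KV-cache pattern this hunk touches; the class `KVCacheDemo`, the method `append_kv`, and the block-size constant value are illustrative assumptions, not the repository's actual API:

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed over-allocation block size


class KVCacheDemo:
    def __init__(self, bsz, num_heads, head_dim):
        self.bsz = bsz
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.kv_cache = None          # tuple (key_cache, value_cache)
        self.max_cache_length = 0     # capacity of the pre-allocated buffers

    def append_kv(self, key_states, value_states, kv_seq_len):
        """Store the newest token's key/value (decode step) at position kv_seq_len - 1."""
        if self.kv_cache is None or kv_seq_len > self.max_cache_length:
            # Grow: allocate larger buffers, copy the existing prefix, then
            # rebind the whole tuple (item assignment on a tuple would raise TypeError).
            new_len = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
            new_key = torch.empty(self.bsz, self.num_heads, new_len, self.head_dim)
            new_value = torch.empty(self.bsz, self.num_heads, new_len, self.head_dim)
            if self.kv_cache is not None:
                new_key[:, :, :kv_seq_len - 1, :] = self.kv_cache[0][:, :, :kv_seq_len - 1, :]
                new_value[:, :, :kv_seq_len - 1, :] = self.kv_cache[1][:, :, :kv_seq_len - 1, :]
            self.kv_cache = (new_key, new_value)
            self.max_cache_length = new_len
        # Write the current step's states into the reserved slot.
        self.kv_cache[0][:, :, kv_seq_len - 1:kv_seq_len, :] = key_states
        self.kv_cache[1][:, :, kv_seq_len - 1:kv_seq_len, :] = value_states
        return self.kv_cache[0][:, :, :kv_seq_len, :], self.kv_cache[1][:, :, :kv_seq_len, :]


# Example (hypothetical shapes, one token appended per decode step):
# cache = KVCacheDemo(bsz=1, num_heads=32, head_dim=128)
# k, v = cache.append_kv(key_states, value_states, kv_seq_len=1)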