fix qwen kv cache length (#9998)
parent 762adc4f9d
commit aae1870096
1 changed file with 1 addition and 2 deletions
@@ -192,8 +192,7 @@ def qwen_attention_forward(
             cache_k, cache_v = layer_past[0], layer_past[1]
             cache_k = cache_k.transpose(1, 2)
             cache_v = cache_v.transpose(1, 2)
             kv_seq_len += cache_k.shape[2]
-            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
                 # allocate new
                 new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                            self.num_heads,
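
Why the condition changed, in brief: with a preallocated KV cache, the live cache_k is a narrowed view into a larger buffer, so cache_k.stride(1) encodes the buffer's full per-head capacity (allocated positions times head_dim). The old check compared that capacity to the cache's current length, which is only correct when exactly one new token is appended per step; when several tokens arrive at once, it could skip a reallocation that was actually needed. The new check compares the capacity to kv_seq_len, the length required after appending. Below is a minimal sketch of the difference, assuming this narrowed-view layout; the toy shapes and variable names are illustrative, not taken from the patch.

import torch

bsz, num_heads, head_dim = 1, 2, 4

# Preallocate room for `capacity` positions; the live cache is a narrowed view.
capacity, cur_len = 8, 7
buf = torch.empty(bsz, num_heads, capacity, head_dim)
cache_k = buf[:, :, :cur_len, :]

# stride(1) is the element distance between heads: capacity * head_dim,
# no matter how many positions the view currently exposes.
assert cache_k.stride(1) == capacity * head_dim  # 8 * 4 = 32

# Suppose 2 new tokens arrive at once, so the cache must grow to 9 positions.
kv_seq_len = 2
kv_seq_len += cache_k.shape[2]  # 9, mirroring the patched function

# Old check: compares capacity to the current length (7 * 4 = 28), which is
# only correct when a single token is appended. 32 <= 28 is False, so no
# reallocation happens even though 9 positions cannot fit in the 8-slot buffer.
old_needs_new = cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)

# Fixed check: compares capacity to the length required after appending.
# 32 < 9 * 4 = 36 is True, so a larger cache is correctly allocated.
new_needs_new = cache_k.stride(1) < kv_seq_len * cache_k.size(3)

print(old_needs_new, new_needs_new)  # False True

Reading the capacity off the stride works because narrowing the sequence dimension keeps the parent buffer's strides, so no separate capacity field has to be threaded through the cache.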