[LLM] Fix Qwen causal_mask and attention_mask size mismatching (#9600)
* Fix #9582, caused by the Qwen modeling_qwen.py change in 7f62181c94 (d2h-049182)
parent b721138132
commit 65934c9f4f

1 changed file with 6 additions and 4 deletions
@@ -151,7 +151,8 @@ def qwen_attention_forward(
     else:
         present = None
 
-    if self.use_logn_attn and not self.training:
+    key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+    if key_size > self.seq_length and self.use_logn_attn and not self.training:
         if self.use_cache_quantization:
             seq_start = key[0].size(2) - query.size(1)
             seq_end = key[0].size(2)
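The first hunk derives the cached sequence length (key_size) before the log-n attention check instead of relying on it implicitly. A minimal sketch of why the two code paths read different dimensions, assuming the tensor layouts implied by the diff (a quantized cache stores key as a tuple whose first element is laid out (batch, heads, seq, head_dim), while the unquantized key is (batch, seq, heads, head_dim)); the shapes and the get_key_size helper are illustrative and not part of the patch:

import torch

# Illustrative shapes only (not taken from modeling_qwen.py): the patch reads
# the cached sequence length from dim 2 of key[0] when the KV cache is
# quantized, and from dim 1 of key otherwise.
batch, heads, head_dim, seq_len = 1, 2, 4, 16

# Unquantized layout implied by `key.size(1)`: (batch, seq, heads, head_dim).
key_plain = torch.zeros(batch, seq_len, heads, head_dim)

# Quantized layout implied by `key[0].size(2)`: a tuple whose first element
# (the packed cache data) is (batch, heads, seq, head_dim); the remaining
# entries stand in for scale / zero-point metadata.
key_quant = (torch.zeros(batch, heads, seq_len, head_dim, dtype=torch.int8),
             None, None)

def get_key_size(key, use_cache_quantization):
    # Mirrors the line added in the first hunk.
    return key[0].size(2) if use_cache_quantization else key.size(1)

assert get_key_size(key_plain, False) == seq_len
assert get_key_size(key_quant, True) == seq_len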
@@ -174,9 +175,10 @@ def qwen_attention_forward(
         context_layer = context_layer.flatten(2, 3).contiguous()
 
     else:
-        registered_causal_mask = torch.tril(
-            torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
-        ).view(1, 1, key.size(1), key.size(1))
+        if query.size(1) == key_size:
+            registered_causal_mask = torch.tril(
+                torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
+            ).view(1, 1, key_size, key_size)
         query = query.permute(0, 2, 1, 3)
         if not self.use_cache_quantization:
             key = key.permute(0, 2, 1, 3)
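The second hunk stops rebuilding registered_causal_mask unconditionally from key.size(1) and only rebuilds it when the query covers the full cached length, i.e. query.size(1) == key_size. A minimal sketch of the size mismatch this guards against, assuming the usual (batch, heads, q_len, k_len) attention-score layout; build_causal_mask and the shapes are illustrative only, and how the mask is consumed downstream is not shown in this diff:

import torch

# Assumed attention-score layout: (batch, heads, q_len, k_len). The old code
# always rebuilt a square (k_len, k_len) mask, which only lines up with the
# scores when q_len == k_len, i.e. during the prefill step.
def build_causal_mask(key_size):
    # Mirrors the mask construction kept in the second hunk.
    return torch.tril(
        torch.ones((key_size, key_size), dtype=torch.bool)
    ).view(1, 1, key_size, key_size)

batch, heads, key_size = 1, 2, 16

# Prefill: q_len == key_size, so the mask matches the score shape.
prefill_scores = torch.zeros(batch, heads, key_size, key_size)
mask = build_causal_mask(key_size)
assert mask.shape[-2:] == prefill_scores.shape[-2:]

# Decode with a KV cache: q_len == 1, so a square (key_size, key_size) mask no
# longer matches the (1, key_size) score rows -- the kind of size mismatch this
# commit avoids by rebuilding the mask only when query.size(1) == key_size.
decode_scores = torch.zeros(batch, heads, 1, key_size)
assert mask.shape[-2:] != decode_scores.shape[-2:]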