fix qwen2 vl (#12798)
This commit is contained in:
		
							parent
							
								
									3fee838b14
								
							
						
					
					
						commit
						e4ceb722b6
					
				
					 1 changed files with 7 additions and 6 deletions
				
			
		| 
						 | 
				
			
			@ -200,15 +200,16 @@ def qwen2_vision_attention_forward(
 | 
			
		|||
    invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length,
 | 
			
		||||
                      "unexpected input")
 | 
			
		||||
 | 
			
		||||
    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
 | 
			
		||||
    head_dim = q.size(-1)
 | 
			
		||||
    if use_sdp_non_causal(head_dim, q.device, q.dtype):
 | 
			
		||||
        image_num = len(seq_lens) - 1
 | 
			
		||||
        image_size = seq_lens[1] - seq_lens[0]
 | 
			
		||||
        guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
 | 
			
		||||
                                        dtype=cu_seqlens.dtype, device=cu_seqlens.device)
 | 
			
		||||
        if (guessed_seq_lens == cu_seqlens).all():
 | 
			
		||||
            q = q.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 | 
			
		||||
            k = k.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 | 
			
		||||
            v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 | 
			
		||||
            q = q.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
 | 
			
		||||
            k = k.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
 | 
			
		||||
            v = v.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
 | 
			
		||||
            # q, k, v: [image_num, num_heads, image_size, head_dim]
 | 
			
		||||
 | 
			
		||||
            attn_output = scaled_dot_product_attention(
 | 
			
		||||
| 
						 | 
				
			
			@ -216,7 +217,7 @@ def qwen2_vision_attention_forward(
 | 
			
		|||
                None, False
 | 
			
		||||
            )
 | 
			
		||||
            attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
 | 
			
		||||
            attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
 | 
			
		||||
            attn_output = attn_output.view(seq_length, self.num_heads, head_dim)
 | 
			
		||||
            # attn_output: [seq_length, num_heads, head_dim]
 | 
			
		||||
        else:
 | 
			
		||||
            q = q.transpose(0, 1).unsqueeze(0)
 | 
			
		||||
| 
						 | 
				
			
			@ -252,7 +253,7 @@ def qwen2_vision_attention_forward(
 | 
			
		|||
        v = v.transpose(0, 1)
 | 
			
		||||
        # q, k, v: [num_heads, seq_length, head_dim]
 | 
			
		||||
 | 
			
		||||
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
 | 
			
		||||
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim)
 | 
			
		||||
        attn_weights = attn_weights + attention_mask
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, v)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue