optimize qwen2_vl multiple image input or video input (#12487)
parent c59284418c
commit 5629fdd518
1 changed file with 50 additions and 21 deletions
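In short: `qwen2_vision_attention_forward` previously built a block-diagonal seq_length × seq_length attention mask whenever more than one image was packed into the sequence, and passed that mask into the fused SDP kernel. After this change, the SDP path never builds a mask: if `cu_seqlens` shows that all images have the same token count (several same-sized images, or the frames of one video), q/k/v are reshaped so the images form a batch for a single `xe_addons.sdp_non_causal` call; otherwise the kernel is invoked once per image on its slice of the sequence. Only the non-SDP fallback still materializes the full mask.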
@@ -191,37 +191,66 @@ def qwen2_vision_attention_forward(
                                               ).permute(1, 0, 2, 3).unbind(0)
     q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
     k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
     # q, k, v: [seq_length, num_heads, head_dim]
 
-    q = q.transpose(0, 1)
-    k = k.transpose(0, 1)
-    v = v.transpose(0, 1)
-    if len(cu_seqlens) == 2 and cu_seqlens.tolist() == [0, seq_length]:
-        attention_mask = None
+    seq_lens = cu_seqlens.tolist()
+    invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length,
+                      "unexpected input")
+
+    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
+        import xe_addons
+        image_num = len(seq_lens) - 1
+        image_size = seq_lens[1] - seq_lens[0]
+        guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
+                                        dtype=cu_seqlens.dtype, device=cu_seqlens.device)
+        if (guessed_seq_lens == cu_seqlens).all():
+            q = q.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+            k = k.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+            v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+            # q, k, v: [image_num, num_heads, image_size, head_dim]
+
+            attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
+            attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
+            attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
+            # attn_output: [seq_length, num_heads, head_dim]
+        else:
+            q = q.transpose(0, 1).unsqueeze(0)
+            k = k.transpose(0, 1).unsqueeze(0).contiguous()
+            v = v.transpose(0, 1).unsqueeze(0).contiguous()
+            # q, k, v: [1, num_heads, seq_length, head_dim]
+
+            attn_outputs = []
+            for i in range(image_num):
+                start_idx = seq_lens[i]
+                end_idx = seq_lens[i + 1]
+                tmp_q = q[:, :, start_idx:end_idx, :]
+                tmp_k = k[:, :, start_idx:end_idx, :]
+                tmp_v = v[:, :, start_idx:end_idx, :]
+                attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
+                attn_output = attn_output.permute(0, 2, 1, 3)
+                # attn_output: [1, seq_length, num_heads, head_dim]
+                attn_outputs.append(attn_output)
+            attn_output = torch.cat(attn_outputs, dim=1).squeeze(0)
+            # attn_output: [seq_length, num_heads, head_dim]
     else:
         attention_mask = torch.full(
             [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
         )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                           cu_seqlens[i - 1]:cu_seqlens[i]] = 0
-    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
-        import xe_addons
-        q = q.unsqueeze(0)
-        k = k.unsqueeze(0)
-        v = v.unsqueeze(0)
-        if attention_mask is not None:
-            attention_mask = attention_mask.unsqueeze(0)
-        attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), attention_mask)
-        attn_output = attn_output.squeeze(0)
-    else:
+        for i in range(1, len(seq_lens)):
+            attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        # q, k, v: [num_heads, seq_length, head_dim]
+
         attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        attn_weights = attn_weights + attention_mask
         attn_weights = attention_softmax(attn_weights)
         attn_output = torch.matmul(attn_weights, v)
-    attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.transpose(0, 1)
+        # attn_output: [seq_length, num_heads, head_dim]
 
     attn_output = attn_output.reshape(seq_length, -1)
     attn_output = self.proj(attn_output)
     return attn_output
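As a standalone illustration (not part of the commit): the batched fast path relies on the fact that, for image_num images of image_size tokens each, cu_seqlens is exactly the arithmetic progression [0, image_size, 2 * image_size, ...]. A minimal sketch of that check in plain PyTorch, with the helper name images_have_equal_size invented here:

import torch

# Mirrors the guessed_seq_lens comparison in the diff: reconstruct what
# cu_seqlens would look like if every image had the size of the first one,
# then test for an exact match.
def images_have_equal_size(cu_seqlens: torch.Tensor) -> bool:
    seq_lens = cu_seqlens.tolist()
    image_num = len(seq_lens) - 1
    image_size = seq_lens[1] - seq_lens[0]
    guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
                                    dtype=cu_seqlens.dtype, device=cu_seqlens.device)
    return bool((guessed_seq_lens == cu_seqlens).all())

print(images_have_equal_size(torch.tensor([0, 64, 128, 192])))  # True: 3 images, 64 tokens each
print(images_have_equal_size(torch.tensor([0, 64, 160, 192])))  # False: mixed sizes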
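When that check passes, the image count is folded into the batch dimension, so no mask is needed and one kernel launch covers every image. A rough sketch of the same reshape, assuming torch.nn.functional.scaled_dot_product_attention as a stand-in for the xe_addons.sdp_non_causal Intel-GPU kernel:

import torch
import torch.nn.functional as F

image_num, image_size, num_heads, head_dim = 3, 64, 16, 80
seq_length = image_num * image_size
# q, k, v: [seq_length, num_heads, head_dim], as in the vision block
q, k, v = (torch.randn(seq_length, num_heads, head_dim) for _ in range(3))

# Fold images into the batch: [image_num, num_heads, image_size, head_dim].
# Attention then cannot cross image boundaries, which is exactly what the
# block-diagonal mask used to enforce.
q4 = q.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)
k4 = k.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)
v4 = v.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)

out = F.scaled_dot_product_attention(q4, k4, v4)  # non-causal, no mask
out = out.permute(0, 2, 1, 3).contiguous().view(seq_length, num_heads, head_dim)
print(out.shape)  # torch.Size([192, 16, 80])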
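The non-SDP fallback keeps the mask-based formulation, now indexed with the precomputed seq_lens list. A self-contained sketch of that mask, with build_block_diagonal_mask as a hypothetical name for the loop in the diff:

import torch

# Additive block-diagonal mask: 0 inside each image's block, the dtype's
# minimum everywhere else, so softmax zeroes out cross-image attention.
def build_block_diagonal_mask(seq_lens, dtype=torch.float32, device="cpu"):
    seq_length = seq_lens[-1]
    attention_mask = torch.full([1, seq_length, seq_length], torch.finfo(dtype).min,
                                device=device, dtype=dtype)
    for i in range(1, len(seq_lens)):
        attention_mask[..., seq_lens[i - 1]:seq_lens[i],
                       seq_lens[i - 1]:seq_lens[i]] = 0
    return attention_mask

# Two segments of 2 and 3 tokens; True marks positions allowed to attend.
print(build_block_diagonal_mask([0, 2, 5])[0] == 0)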