fix qwen2 vl (#12798)
parent 3fee838b14
commit e4ceb722b6
1 changed file with 7 additions and 6 deletions
@@ -200,15 +200,16 @@ def qwen2_vision_attention_forward(
     invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length,
                       "unexpected input")
 
-    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
+    head_dim = q.size(-1)
+    if use_sdp_non_causal(head_dim, q.device, q.dtype):
         image_num = len(seq_lens) - 1
         image_size = seq_lens[1] - seq_lens[0]
         guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
                                         dtype=cu_seqlens.dtype, device=cu_seqlens.device)
         if (guessed_seq_lens == cu_seqlens).all():
-            q = q.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-            k = k.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-            v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+            q = q.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
+            k = k.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
+            v = v.view(image_num, image_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
             # q, k, v: [image_num, num_heads, image_size, head_dim]
 
             attn_output = scaled_dot_product_attention(
@@ -216,7 +217,7 @@ def qwen2_vision_attention_forward(
                 None, False
             )
             attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
-            attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
+            attn_output = attn_output.view(seq_length, self.num_heads, head_dim)
             # attn_output: [seq_length, num_heads, head_dim]
         else:
             q = q.transpose(0, 1).unsqueeze(0)
@@ -252,7 +253,7 @@ def qwen2_vision_attention_forward(
         v = v.transpose(0, 1)
         # q, k, v: [num_heads, seq_length, head_dim]
 
-        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim)
         attn_weights = attn_weights + attention_mask
         attn_weights = attention_softmax(attn_weights)
         attn_output = torch.matmul(attn_weights, v)
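For context, a minimal runnable sketch of the equal-size fast path these hunks touch, with the head_dim fix applied. The shapes are made up for illustration, and torch's stock torch.nn.functional.scaled_dot_product_attention stands in for ipex-llm's own scaled_dot_product_attention wrapper:

import torch
import torch.nn.functional as F

# Toy shapes (illustrative only): two equally sized images packed along
# the sequence axis, as Qwen2-VL's vision encoder does.
image_num, image_size, num_heads = 2, 4, 3
seq_length = image_num * image_size
q = torch.randn(seq_length, num_heads, 8)
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, 4, 8])  # cumulative image boundaries

# The fix: read head_dim off the tensor itself instead of self.head_dim,
# which the patched attention module may not expose.
head_dim = q.size(-1)

# Guard for the fast path: if the boundaries form an arithmetic sequence,
# every image has the same length and no attention mask is needed.
guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
                                dtype=cu_seqlens.dtype)
assert (guessed_seq_lens == cu_seqlens).all()

# Batch per image and run plain SDPA: [image_num, num_heads, image_size, head_dim].
q = q.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)
k = k.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)
v = v.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)
out = F.scaled_dot_product_attention(q, k, v)
out = out.permute(0, 2, 1, 3).contiguous().view(seq_length, num_heads, head_dim)
print(out.shape)  # torch.Size([8, 3, 8])

Deriving head_dim from the tensor keeps the patched forward independent of whichever attributes the upstream attention module happens to define, which appears to be the point of the fix.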