Fix qwen-vl style (#10633)

* update

* update
Jiao Wang authored 2024-04-02 18:41:38 -07:00, committed by GitHub
parent 2bbd8a1548
commit 23e33a0ca1

@@ -126,7 +126,13 @@ def qwen_attention_forward_vl(
query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+# TODO: speed up
+# mixed_x_layer = self.c_attn(hidden_states)
+# query, key, value = mixed_x_layer.split(self.split_size, dim=2)
+# query = self._split_heads(query, self.num_heads, self.head_dim)
+# key = self._split_heads(key, self.num_heads, self.head_dim)
+# value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb is not None:
cur_len = query.shape[1]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
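The hunk above keeps a commented-out fused-QKV path (`c_attn` plus `_split_heads`) as a TODO next to the separate `q_proj`/`k_proj`/`v_proj` projections. A minimal sketch of how the two layouts relate, with made-up sizes and stand-in `nn.Linear` modules; only the attribute names mirror the diff:

import torch
import torch.nn as nn

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
hidden_size = num_heads * head_dim
hidden_states = torch.randn(bsz, q_len, hidden_size)

# separate projections, as in the patched forward
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)
query = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)
key = k_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)
value = v_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)

# fused projection, as in the commented-out TODO path
c_attn = nn.Linear(hidden_size, 3 * hidden_size)
split_size = hidden_size
q2, k2, v2 = c_attn(hidden_states).split(split_size, dim=2)
q2 = q2.view(bsz, q_len, num_heads, head_dim)  # stands in for _split_heads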
@@ -183,7 +189,8 @@ def qwen_attention_forward_vl(
present = None
if decoding_fast_path:
-query = query.transpose(1, 2)  # change to (bsz, q_len, num_heads, head_dim)
+# change to (bsz, q_len, num_heads, head_dim)
+query = query.transpose(1, 2)
if self.use_logn_attn and not self.training:
if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
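The `use_logn_attn` branch in this hunk rescales the query with a precomputed `logn_tensor` when decoding beyond the training context length; the device/dtype check just keeps that buffer alongside the query. A minimal sketch of how such a scaling tensor is typically built for Qwen-style log-n attention; the `seq_length` training-context constant, `max_positions`, and the broadcast shape are assumptions, not taken from this diff:

import math
import torch

def build_logn_tensor(max_positions: int, seq_length: int) -> torch.Tensor:
    # scale is log_{seq_length}(pos) past the training length, 1.0 otherwise
    scales = [
        math.log(pos, seq_length) if pos > seq_length else 1.0
        for pos in range(1, max_positions + 1)
    ]
    # shaped (1, max_positions, 1, 1) so it broadcasts over (bsz, q_len, num_heads, head_dim)
    return torch.tensor(scales).view(1, -1, 1, 1)

logn_tensor = build_logn_tensor(max_positions=16384, seq_length=8192)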
@@ -196,7 +203,7 @@ def qwen_attention_forward_vl(
query = query.permute(0, 2, 1, 3)
if not self.training and not hidden_states.requires_grad and \
-use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
import linear_fp16_esimd
attn_output = linear_fp16_esimd.sdp_forward(query,
key,
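The re-indented condition above gates a fast path: during inference, when `use_esimd_sdp(...)` approves the shapes, attention goes through `linear_fp16_esimd.sdp_forward(...)` instead of the regular `_attn` path. A minimal sketch of the same dispatch pattern, using PyTorch's built-in fused kernel as a stand-in for the ESIMD one; the `fast_path_ok` flag and the fallback math are assumptions, not the IPEX-LLM API:

import torch
import torch.nn.functional as F

def attention_dispatch(query, key, value, fast_path_ok: bool) -> torch.Tensor:
    # expects (bsz, num_heads, seq_len, head_dim) tensors
    if fast_path_ok and not query.requires_grad:
        # fused kernel stand-in for linear_fp16_esimd.sdp_forward
        return F.scaled_dot_product_attention(query, key, value)
    # reference path: explicit matmul + softmax, analogous to self._attn
    scores = query @ key.transpose(-2, -1) / (query.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ value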
@@ -209,7 +216,6 @@ def qwen_attention_forward_vl(
query, key, value, registered_causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
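After either attention path, `self._merge_heads(attn_output, self.num_heads, self.head_dim)` folds the heads back into a single hidden dimension. A minimal sketch of that reshape, assuming `attn_output` is already laid out as (bsz, q_len, num_heads, head_dim); the free-standing helper here is only an illustration, not the module's actual method:

import torch

def merge_heads(tensor: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    # (bsz, q_len, num_heads, head_dim) -> (bsz, q_len, num_heads * head_dim)
    new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
    return tensor.contiguous().view(new_shape)

attn_output = torch.randn(2, 8, 4, 16)
context_layer = merge_heads(attn_output, num_heads=4, head_dim=16)  # -> (2, 8, 64)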