parent 2bbd8a1548
commit 23e33a0ca1
1 changed file with 9 additions and 3 deletions
@@ -126,7 +126,13 @@ def qwen_attention_forward_vl(
    query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    # TODO: speed up
    # mixed_x_layer = self.c_attn(hidden_states)
    # query, key, value = mixed_x_layer.split(self.split_size, dim=2)

    # query = self._split_heads(query, self.num_heads, self.head_dim)
    # key = self._split_heads(key, self.num_heads, self.head_dim)
    # value = self._split_heads(value, self.num_heads, self.head_dim)
    if rotary_pos_emb is not None:
        cur_len = query.shape[1]
        rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
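Not part of the diff: a minimal sketch of why the rotary tables are sliced to the last cur_len positions above. During incremental decoding the query covers only the newest token(s), while the precomputed cos/sin tables span every position seen so far, so they have to be trimmed to the same trailing length before being applied. The apply_rotary_sketch helper and its rotate-half form are assumptions for illustration, not Qwen-VL's actual rotary kernel.

import torch

def apply_rotary_sketch(x, cos, sin):
    # x: (bsz, seq_len, num_heads, head_dim); cos/sin broadcast over batch and heads.
    # Rotate-half formulation (assumed; the real Qwen-VL kernel may differ).
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

bsz, full_len, num_heads, head_dim = 1, 16, 2, 8
# Tables precomputed for every position seen so far.
cos = torch.randn(1, full_len, 1, head_dim)
sin = torch.randn(1, full_len, 1, head_dim)

# One decoding step: only the newest token's query exists, so trim the tables
# to the same trailing length, mirroring rotary_pos_emb = [i[:, -cur_len:, :, :] ...].
query = torch.randn(bsz, 1, num_heads, head_dim)
cur_len = query.shape[1]
cos, sin = cos[:, -cur_len:, :, :], sin[:, -cur_len:, :, :]
print(apply_rotary_sketch(query, cos, sin).shape)  # torch.Size([1, 1, 2, 8])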
@@ -183,7 +189,8 @@ def qwen_attention_forward_vl(
    present = None

    if decoding_fast_path:
        query = query.transpose(1, 2)  # change to (bsz, q_len, num_heads, head_dim)
        # change to (bsz, q_len, num_heads, head_dim)
        query = query.transpose(1, 2)

    if self.use_logn_attn and not self.training:
        if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
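Not part of the diff: the use_logn_attn branch above rescales queries so attention stays well-behaved past the training length, and the device/dtype check simply keeps self.logn_tensor alongside the query before that scaling. A rough stand-in for the idea, assuming the log(position) / log(train_len) form; the exact table Qwen-VL builds may differ.

import math
import torch

# Assumed LogN scaling: positions beyond the training length scale their queries
# by log(position) / log(train_len); earlier positions are clamped to 1.0.
train_len, seq_len, num_heads, head_dim = 2048, 4096, 2, 8
positions = torch.arange(1, seq_len + 1, dtype=torch.float32)
logn = torch.clamp(positions.log() / math.log(train_len), min=1.0)  # (seq_len,)

query = torch.randn(1, seq_len, num_heads, head_dim)  # (bsz, q_len, num_heads, head_dim)
query = query * logn.view(1, seq_len, 1, 1).to(query.device, query.dtype)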
@@ -209,7 +216,6 @@ def qwen_attention_forward_vl(
        query, key, value, registered_causal_mask, attention_mask, head_mask
    )


    context_layer = self._merge_heads(
        attn_output, self.num_heads, self.head_dim
    )
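Not part of the diff: _merge_heads above folds the per-head outputs back into one hidden dimension before the output projection. A minimal stand-in, assuming the attention output arrives as (bsz, q_len, num_heads, head_dim), which is the layout the surrounding reshapes suggest; the real method may be implemented differently.

import torch

def merge_heads_sketch(tensor, num_heads, head_dim):
    # (bsz, q_len, num_heads, head_dim) -> (bsz, q_len, num_heads * head_dim)
    tensor = tensor.contiguous()
    return tensor.view(*tensor.shape[:-2], num_heads * head_dim)

attn_output = torch.randn(2, 10, 4, 8)  # (bsz, q_len, num_heads, head_dim)
context_layer = merge_heads_sketch(attn_output, num_heads=4, head_dim=8)
print(context_layer.shape)  # torch.Size([2, 10, 32])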