parent 2bbd8a1548
commit 23e33a0ca1
1 changed file with 9 additions and 3 deletions
@@ -126,7 +126,13 @@ def qwen_attention_forward_vl(
     query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
     key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
     value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+    # TODO: speed up
+    # mixed_x_layer = self.c_attn(hidden_states)
+    # query, key, value = mixed_x_layer.split(self.split_size, dim=2)
+
+    # query = self._split_heads(query, self.num_heads, self.head_dim)
+    # key = self._split_heads(key, self.num_heads, self.head_dim)
+    # value = self._split_heads(value, self.num_heads, self.head_dim)
     if rotary_pos_emb is not None:
         cur_len = query.shape[1]
         rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
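Note on the first hunk: a minimal standalone sketch of the shape handling the projection path relies on. All sizes here are illustrative assumptions, not taken from the patched model; each of the three separate projections is viewed as (bsz, q_len, num_heads, head_dim), which is also the layout the commented-out c_attn + _split_heads path produced.

import torch
import torch.nn as nn

# Illustrative sizes only; the real model's hidden_size/num_heads differ.
bsz, q_len, num_heads, head_dim = 2, 8, 16, 64
hidden_size = num_heads * head_dim

hidden_states = torch.randn(bsz, q_len, hidden_size)
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)

# Three separate projections, each reshaped to (bsz, q_len, num_heads, head_dim).
query = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)
key = k_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)
value = v_proj(hidden_states).view(bsz, q_len, num_heads, head_dim)
print(query.shape)  # torch.Size([2, 8, 16, 64])

# Mirroring the rotary slice in the hunk: keep only the last cur_len positions.
rotary_pos_emb = [torch.randn(1, 32, 1, head_dim), torch.randn(1, 32, 1, head_dim)]
cur_len = query.shape[1]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
print(rotary_pos_emb[0].shape)  # torch.Size([1, 8, 1, 64])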
@@ -183,7 +189,8 @@ def qwen_attention_forward_vl(
     present = None

     if decoding_fast_path:
-        query = query.transpose(1, 2) # change to (bsz, q_len, num_heads, head_dim)
+        # change to (bsz, q_len, num_heads, head_dim)
+        query = query.transpose(1, 2)

     if self.use_logn_attn and not self.training:
         if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
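Note on the second hunk: the trailing comment on the transpose line is moved onto its own line. A small sketch of what the transpose does, assuming (as the comment itself states) that the decoding fast path holds query in (bsz, num_heads, q_len, head_dim) layout; the sizes below are made up for illustration.

import torch

bsz, num_heads, q_len, head_dim = 2, 16, 1, 64  # illustrative decode-step sizes
query = torch.randn(bsz, num_heads, q_len, head_dim)

# change to (bsz, q_len, num_heads, head_dim), as the moved comment says
query = query.transpose(1, 2)
print(query.shape)  # torch.Size([2, 1, 16, 64])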
@@ -209,7 +216,6 @@ def qwen_attention_forward_vl(
             query, key, value, registered_causal_mask, attention_mask, head_mask
         )

     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
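Note on the third hunk: for reference, a hypothetical merge_heads_sketch showing the usual effect of a _merge_heads-style helper, i.e. collapsing (bsz, q_len, num_heads, head_dim) back to (bsz, q_len, num_heads * head_dim). This is an assumption for illustration; the real self._merge_heads in the patched model may differ in layout details.

import torch

def merge_heads_sketch(tensor: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    # Hypothetical helper: fold the per-head layout (bsz, q_len, num_heads, head_dim)
    # into (bsz, q_len, num_heads * head_dim).
    new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
    return tensor.contiguous().view(new_shape)

attn_output = torch.randn(2, 8, 16, 64)   # illustrative sizes
context_layer = merge_heads_sketch(attn_output, 16, 64)
print(context_layer.shape)                # torch.Size([2, 8, 1024])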