parent 2bbd8a1548
commit 23e33a0ca1

1 changed file with 9 additions and 3 deletions
@@ -126,7 +126,13 @@ def qwen_attention_forward_vl(
+        query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        # TODO: speed up
+        # mixed_x_layer = self.c_attn(hidden_states)
+        # query, key, value = mixed_x_layer.split(self.split_size, dim=2)
 
         # query = self._split_heads(query, self.num_heads, self.head_dim)
         # key = self._split_heads(key, self.num_heads, self.head_dim)
         # value = self._split_heads(value, self.num_heads, self.head_dim)
         if rotary_pos_emb is not None:
             cur_len = query.shape[1]
             rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
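For context: in the new version of this block, query/key/value come from separate q_proj/k_proj/v_proj layers, while the fused c_attn path is left commented out behind a "# TODO: speed up" note. Below is a minimal sketch of how a fused QKV weight can be split into three equivalent Linear projections; the sizes and the splitting code are illustrative assumptions by the editor, not the repo's actual conversion logic.

    import torch

    hidden_size, num_heads = 64, 4                     # small illustrative sizes
    head_dim = hidden_size // num_heads

    c_attn = torch.nn.Linear(hidden_size, 3 * hidden_size)   # fused QKV projection
    q_proj = torch.nn.Linear(hidden_size, hidden_size)
    k_proj = torch.nn.Linear(hidden_size, hidden_size)
    v_proj = torch.nn.Linear(hidden_size, hidden_size)

    # Copy the matching slice of the fused weight/bias into each projection.
    with torch.no_grad():
        for i, proj in enumerate((q_proj, k_proj, v_proj)):
            proj.weight.copy_(c_attn.weight[i * hidden_size:(i + 1) * hidden_size])
            proj.bias.copy_(c_attn.bias[i * hidden_size:(i + 1) * hidden_size])

    x = torch.randn(1, 8, hidden_size)                 # (bsz, q_len, hidden_size)
    q_ref, k_ref, v_ref = c_attn(x).split(hidden_size, dim=2)
    assert torch.allclose(q_proj(x), q_ref, atol=1e-5)

    # As in the diff, each projection is then viewed as (bsz, q_len, num_heads, head_dim).
    query = q_proj(x).view(1, 8, num_heads, head_dim)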
@@ -183,7 +189,8 @@ def qwen_attention_forward_vl(
         present = None
 
     if decoding_fast_path:
-        query = query.transpose(1, 2) # change to (bsz, q_len, num_heads, head_dim)
+        # change to (bsz, q_len, num_heads, head_dim)
+        query = query.transpose(1, 2)
 
     if self.use_logn_attn and not self.training:
         if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
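The edit in this hunk is cosmetic: the inline comment on the decoding-fast-path transpose moves onto its own line, presumably to keep the line within the linter's length limit. As a quick illustration of what that transpose does, with hypothetical sizes:

    import torch

    # Presumed layout on the decoding fast path: (bsz, num_heads, q_len, head_dim).
    bsz, num_heads, q_len, head_dim = 1, 4, 1, 16
    query = torch.randn(bsz, num_heads, q_len, head_dim)

    # transpose(1, 2) swaps the head and sequence axes,
    # giving the (bsz, q_len, num_heads, head_dim) layout named in the comment.
    assert query.transpose(1, 2).shape == (bsz, q_len, num_heads, head_dim)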
@@ -196,7 +203,7 @@ def qwen_attention_forward_vl(
     query = query.permute(0, 2, 1, 3)
 
     if not self.training and not hidden_states.requires_grad and \
-    use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query,
                                                     key,
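The only change here is the indentation of the continuation line in the condition that gates the ESIMD fused SDP path (inference only, no gradients on hidden_states, and use_esimd_sdp accepting the current shapes). As a rough reference for the computation that fused call stands in for, here is plain scaled-dot-product attention over the (bsz, num_heads, seq_len, head_dim) layout used at this point; this is a minimal sketch, not the linear_fp16_esimd kernel's actual implementation or signature.

    import math
    import torch

    def sdp_reference(query, key, value, attention_mask=None):
        # query: (bsz, num_heads, q_len, head_dim); key/value: (bsz, num_heads, kv_len, head_dim)
        scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
        if attention_mask is not None:
            scores = scores + attention_mask      # additive mask (e.g. -inf on masked positions)
        weights = torch.softmax(scores, dim=-1)
        return weights @ value                    # (bsz, num_heads, q_len, head_dim)

    q = torch.randn(1, 4, 1, 16)    # single decoded token
    k = torch.randn(1, 4, 32, 16)   # cached keys
    v = torch.randn(1, 4, 32, 16)   # cached values
    out = sdp_reference(q, k, v)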
@@ -209,7 +216,6 @@ def qwen_attention_forward_vl(
             query, key, value, registered_causal_mask, attention_mask, head_mask
         )
 
-
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
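The final hunk only removes a stray blank line before the _merge_heads call. For readers, merging heads in Qwen-style attention amounts to flattening the last two dimensions back into the hidden size; a minimal stand-alone sketch follows (the model's own _merge_heads method may differ in detail):

    import torch

    def merge_heads(tensor, num_heads, head_dim):
        # (bsz, q_len, num_heads, head_dim) -> (bsz, q_len, num_heads * head_dim)
        tensor = tensor.contiguous()
        return tensor.view(tensor.size()[:-2] + (num_heads * head_dim,))

    attn_output = torch.randn(1, 8, 4, 16)            # (bsz, q_len, num_heads, head_dim)
    context_layer = merge_heads(attn_output, num_heads=4, head_dim=16)
    assert context_layer.shape == (1, 8, 64)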