optimize qwen rope (#9737)
This commit is contained in:
parent 4c032a433e
commit 13ea6330bd
1 changed file with 17 additions and 24 deletions
@@ -86,13 +86,23 @@ def qwen_attention_forward(
     if rotary_pos_emb_list is not None:
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
-            rotary_pos_emb = rotary_pos_emb_list[0]
-            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
+            if query.device.type == 'xpu':
+                cos, sin = rotary_pos_emb_list[0]
+                cos = cos[:, -cur_len:, :, :]
+                sin = sin[:, -cur_len:, :, :]
+                rot_dim = cos.shape[-1]
+                query_cur = query[..., :rot_dim]
+                key_cur = key[..., :rot_dim]
+                torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
+                torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
+            else:
+                rotary_pos_emb = rotary_pos_emb_list[0]
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query = apply_rotary_pos_emb(query, q_pos_emb)
+                key = apply_rotary_pos_emb(key, k_pos_emb)
         else:
             query_list = []
             key_list = []
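
The hunk above is the whole optimization: on XPU devices the per-step Python RoPE helper is replaced by IPEX's fused in-place call, torch.ops.torch_ipex.apply_rotary_embedding, applied only to the first rot_dim channels of query and key; every other device keeps the original apply_rotary_pos_emb path. As a reading aid, here is a minimal PyTorch sketch of the rotation both branches compute. It is not code from this repository: the helper names, the rotate-half convention, and the (1, seq_len, 1, rot_dim) cos/sin layout are assumptions based on the surrounding Qwen code.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the rotary channels with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch, not the repo's implementation.
    # states: (batch, seq_len, num_heads, head_dim); cos/sin: (1, seq_len, 1, rot_dim),
    # already sliced to the last cur_len positions as in the diff above.
    rot_dim = cos.shape[-1]
    x_rot, x_pass = states[..., :rot_dim], states[..., rot_dim:]
    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    return torch.cat((x_rot, x_pass), dim=-1)

The fused XPU call performs the same math in a single call and writes the rotated values back through query_cur and key_cur, which are views into query and key, so no concatenation or extra temporaries are needed.
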
@@ -195,23 +205,6 @@ def qwen_attention_forward(
                                   None,
                                   Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
 
-        # Remove for efficiency issue on Arc, maybe add later.
-        # if not self.use_cache_quantization and SUPPORT_TORCH2:
-        #     if attention_mask is not None:
-        #         attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
-        #         if causal_mask is not None:
-        #             attention_mask = attention_mask.masked_fill(~causal_mask,
-        #                                                         torch.finfo(query.dtype).min)
-        #     else:
-        #         attention_mask = causal_mask
-        #     attn_output = F.scaled_dot_product_attention(
-        #         query, key, value, attn_mask=attention_mask
-        #     ).transpose(1, 2)
-        #     attn_weight = None
-        # else:
-        #     attn_output, attn_weight = self._attn(
-        #     query, key, value, causal_mask, attention_mask, head_mask
-        # )
         attn_output, attn_weight = self._attn(
             query, key, value, causal_mask, attention_mask, head_mask
         )
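
The second hunk is pure cleanup: the long comment block describing a scaled_dot_product_attention variant (disabled, per the comment, because of an efficiency issue on Arc) is dropped, and the eager self._attn call remains the only attention path. For reference, the sketch below reconstructs what that disabled variant described; it mirrors the deleted comments rather than any live code in the repository, and the helper name sdpa_attention_sketch plus the assumed (batch, num_heads, seq_len, head_dim) query/key/value layout are illustrative.

import torch
import torch.nn.functional as F

def sdpa_attention_sketch(query, key, value, attention_mask, causal_mask):
    # Build a single mask for SDPA, as the commented-out branch did.
    if attention_mask is not None:
        # Broadcast the padding mask over the query length ...
        attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
        if causal_mask is not None:
            # ... and fold the causal mask in as a large negative bias.
            attention_mask = attention_mask.masked_fill(
                ~causal_mask, torch.finfo(query.dtype).min)
    else:
        attention_mask = causal_mask
    # Fused attention, then back to (batch, seq_len, num_heads, head_dim);
    # SDPA does not return per-head attention weights.
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    ).transpose(1, 2)
    return attn_output, None
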