refactor sd 1.5 and qwen2-vl and fix (#12590)
This commit is contained in:
		
							parent
							
								
									b050368efc
								
							
						
					
					
						commit
						098eb335b2
					
				
					 4 changed files with 23 additions and 58 deletions
				
			
		| 
						 | 
				
			
			@ -75,7 +75,7 @@ def siglip_attention_forward(
 | 
			
		|||
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
        attn_output = scaled_dot_product_attention(
 | 
			
		||||
            query_states, key_states, value_states,
 | 
			
		||||
            query_states, key_states.contiguous(), value_states.contiguous(),
 | 
			
		||||
            attention_mask, False, 1 / math.sqrt(self.head_dim)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -583,8 +583,7 @@ def qwen2_attention_forward(
 | 
			
		|||
                                                             self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if query_states.device.type == 'xpu' \
 | 
			
		||||
            and use_flash_attention(query_states, key_states, attention_mask):
 | 
			
		||||
    if use_flash_attention(query_states, key_states, attention_mask):
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,8 +43,9 @@ from typing import Optional, Tuple, Union, List
 | 
			
		|||
import torch
 | 
			
		||||
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
| 
						 | 
				
			
			@ -198,7 +199,6 @@ def qwen2_vision_attention_forward(
 | 
			
		|||
                      "unexpected input")
 | 
			
		||||
 | 
			
		||||
    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        image_num = len(seq_lens) - 1
 | 
			
		||||
        image_size = seq_lens[1] - seq_lens[0]
 | 
			
		||||
        guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
 | 
			
		||||
| 
						 | 
				
			
			@ -209,7 +209,10 @@ def qwen2_vision_attention_forward(
 | 
			
		|||
            v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 | 
			
		||||
            # q, k, v: [image_num, num_heads, image_size, head_dim]
 | 
			
		||||
 | 
			
		||||
            attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
 | 
			
		||||
            attn_output = scaled_dot_product_attention(
 | 
			
		||||
                q, k.contiguous(), v.contiguous(),
 | 
			
		||||
                None, False
 | 
			
		||||
            )
 | 
			
		||||
            attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
 | 
			
		||||
            attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
 | 
			
		||||
            # attn_output: [seq_length, num_heads, head_dim]
 | 
			
		||||
| 
						 | 
				
			
			@ -226,7 +229,10 @@ def qwen2_vision_attention_forward(
 | 
			
		|||
                tmp_q = q[:, :, start_idx:end_idx, :]
 | 
			
		||||
                tmp_k = k[:, :, start_idx:end_idx, :]
 | 
			
		||||
                tmp_v = v[:, :, start_idx:end_idx, :]
 | 
			
		||||
                attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
 | 
			
		||||
                attn_output = scaled_dot_product_attention(
 | 
			
		||||
                    tmp_q, tmp_k, tmp_v,
 | 
			
		||||
                    None, False
 | 
			
		||||
                )
 | 
			
		||||
                attn_output = attn_output.permute(0, 2, 1, 3)
 | 
			
		||||
                # attn_output: [1, seq_length, num_heads, head_dim]
 | 
			
		||||
                attn_outputs.append(attn_output)
 | 
			
		||||
| 
						 | 
				
			
			@ -293,42 +299,11 @@ def qwen2_vl_attention_forward(
 | 
			
		|||
        key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                         self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.size(2)
 | 
			
		||||
    if attention_mask is not None:  # no matter the length, we just slice it
 | 
			
		||||
        causal_mask = attention_mask[:, :, :, :kv_seq_len]
 | 
			
		||||
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, causal_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, causal_mask)
 | 
			
		||||
    else:
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
        attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
        if causal_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + causal_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == key_states.size(2)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, -1)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,8 +37,8 @@ import torch
 | 
			
		|||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from ipex_llm.transformers.utils import get_xpu_device_type
 | 
			
		||||
from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
 | 
			
		||||
from ipex_llm.transformers.models.common import padding_qkv_hd
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from diffusers.models.attention_processor import Attention
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -110,19 +110,10 @@ class AttnProcessor2_0:
 | 
			
		|||
        if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
 | 
			
		||||
            # padding head_dim 40 to 64
 | 
			
		||||
            query, key, value = padding_qkv_hd(query, key, value, 40, 64)
 | 
			
		||||
 | 
			
		||||
            if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
 | 
			
		||||
                import xe_addons
 | 
			
		||||
                hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
 | 
			
		||||
                                                         value.contiguous(), attention_mask)
 | 
			
		||||
            else:
 | 
			
		||||
                scale = 1 / math.sqrt(head_dim)
 | 
			
		||||
                attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
 | 
			
		||||
                if attention_mask is not None:
 | 
			
		||||
                    attn_weights = attn_weights + attention_mask
 | 
			
		||||
                attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
                hidden_states = torch.matmul(attn_weights, value)
 | 
			
		||||
 | 
			
		||||
            hidden_states = scaled_dot_product_attention(
 | 
			
		||||
                query, key.contiguous(), value.contiguous(),
 | 
			
		||||
                attention_mask, False, 1 / math.sqrt(head_dim)
 | 
			
		||||
            )
 | 
			
		||||
            hidden_states = hidden_states[:, :, :, :head_dim]
 | 
			
		||||
        else:
 | 
			
		||||
            hidden_states = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue