refactor sd 1.5 and qwen2-vl and fix (#12590)
This commit is contained in:
		
							parent
							
								
									b050368efc
								
							
						
					
					
						commit
						098eb335b2
					
				
					 4 changed files with 23 additions and 58 deletions
				
			
		| 
						 | 
					@ -75,7 +75,7 @@ def siglip_attention_forward(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        attn_weights = None
 | 
					        attn_weights = None
 | 
				
			||||||
        attn_output = scaled_dot_product_attention(
 | 
					        attn_output = scaled_dot_product_attention(
 | 
				
			||||||
            query_states, key_states, value_states,
 | 
					            query_states, key_states.contiguous(), value_states.contiguous(),
 | 
				
			||||||
            attention_mask, False, 1 / math.sqrt(self.head_dim)
 | 
					            attention_mask, False, 1 / math.sqrt(self.head_dim)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -583,8 +583,7 @@ def qwen2_attention_forward(
 | 
				
			||||||
                                                             self.layer_idx, None)
 | 
					                                                             self.layer_idx, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    attn_weights = None
 | 
					    attn_weights = None
 | 
				
			||||||
    if query_states.device.type == 'xpu' \
 | 
					    if use_flash_attention(query_states, key_states, attention_mask):
 | 
				
			||||||
            and use_flash_attention(query_states, key_states, attention_mask):
 | 
					 | 
				
			||||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
					        # repeat k/v heads if n_kv_heads < n_heads
 | 
				
			||||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
					        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
				
			||||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
					        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -43,8 +43,9 @@ from typing import Optional, Tuple, Union, List
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 | 
					from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 | 
				
			||||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
					from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
				
			||||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
 | 
					from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import should_use_fuse_rope
 | 
				
			||||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
 | 
					from ipex_llm.transformers.models.utils import use_sdp_non_causal
 | 
				
			||||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 | 
					from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
| 
						 | 
					@ -198,7 +199,6 @@ def qwen2_vision_attention_forward(
 | 
				
			||||||
                      "unexpected input")
 | 
					                      "unexpected input")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
 | 
					    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
 | 
				
			||||||
        import xe_addons
 | 
					 | 
				
			||||||
        image_num = len(seq_lens) - 1
 | 
					        image_num = len(seq_lens) - 1
 | 
				
			||||||
        image_size = seq_lens[1] - seq_lens[0]
 | 
					        image_size = seq_lens[1] - seq_lens[0]
 | 
				
			||||||
        guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
 | 
					        guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
 | 
				
			||||||
| 
						 | 
					@ -209,7 +209,10 @@ def qwen2_vision_attention_forward(
 | 
				
			||||||
            v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 | 
					            v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 | 
				
			||||||
            # q, k, v: [image_num, num_heads, image_size, head_dim]
 | 
					            # q, k, v: [image_num, num_heads, image_size, head_dim]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
 | 
					            attn_output = scaled_dot_product_attention(
 | 
				
			||||||
 | 
					                q, k.contiguous(), v.contiguous(),
 | 
				
			||||||
 | 
					                None, False
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
            attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
 | 
					            attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
 | 
				
			||||||
            attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
 | 
					            attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
 | 
				
			||||||
            # attn_output: [seq_length, num_heads, head_dim]
 | 
					            # attn_output: [seq_length, num_heads, head_dim]
 | 
				
			||||||
| 
						 | 
					@ -226,7 +229,10 @@ def qwen2_vision_attention_forward(
 | 
				
			||||||
                tmp_q = q[:, :, start_idx:end_idx, :]
 | 
					                tmp_q = q[:, :, start_idx:end_idx, :]
 | 
				
			||||||
                tmp_k = k[:, :, start_idx:end_idx, :]
 | 
					                tmp_k = k[:, :, start_idx:end_idx, :]
 | 
				
			||||||
                tmp_v = v[:, :, start_idx:end_idx, :]
 | 
					                tmp_v = v[:, :, start_idx:end_idx, :]
 | 
				
			||||||
                attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
 | 
					                attn_output = scaled_dot_product_attention(
 | 
				
			||||||
 | 
					                    tmp_q, tmp_k, tmp_v,
 | 
				
			||||||
 | 
					                    None, False
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
                attn_output = attn_output.permute(0, 2, 1, 3)
 | 
					                attn_output = attn_output.permute(0, 2, 1, 3)
 | 
				
			||||||
                # attn_output: [1, seq_length, num_heads, head_dim]
 | 
					                # attn_output: [1, seq_length, num_heads, head_dim]
 | 
				
			||||||
                attn_outputs.append(attn_output)
 | 
					                attn_outputs.append(attn_output)
 | 
				
			||||||
| 
						 | 
					@ -293,42 +299,11 @@ def qwen2_vl_attention_forward(
 | 
				
			||||||
        key_states, value_states = past_key_value.update(key_states, value_states,
 | 
					        key_states, value_states = past_key_value.update(key_states, value_states,
 | 
				
			||||||
                                                         self.layer_idx, None)
 | 
					                                                         self.layer_idx, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    kv_seq_len = key_states.size(2)
 | 
					 | 
				
			||||||
    if attention_mask is not None:  # no matter the length, we just slice it
 | 
					 | 
				
			||||||
        causal_mask = attention_mask[:, :, :, :kv_seq_len]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    attn_weights = None
 | 
					    attn_weights = None
 | 
				
			||||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
					    attn_output = scaled_dot_product_attention(
 | 
				
			||||||
        import xe_addons
 | 
					        query_states, key_states, value_states,
 | 
				
			||||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
					        attention_mask, q_len == key_states.size(2)
 | 
				
			||||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
 | 
					    )
 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
 | 
					 | 
				
			||||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
					 | 
				
			||||||
        import xe_addons
 | 
					 | 
				
			||||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
					 | 
				
			||||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
					 | 
				
			||||||
                                                   value_states, causal_mask)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
					 | 
				
			||||||
                                               value_states, causal_mask)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
					 | 
				
			||||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
					 | 
				
			||||||
                                                            query_states.dtype)
 | 
					 | 
				
			||||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
					 | 
				
			||||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
					 | 
				
			||||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        attn_weights = torch.matmul(query_states,
 | 
					 | 
				
			||||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if causal_mask is not None:
 | 
					 | 
				
			||||||
            attn_weights = attn_weights + causal_mask
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # upcast attention to fp32
 | 
					 | 
				
			||||||
        attn_weights = attention_softmax(attn_weights)
 | 
					 | 
				
			||||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
					    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
				
			||||||
    attn_output = attn_output.reshape(bsz, q_len, -1)
 | 
					    attn_output = attn_output.reshape(bsz, q_len, -1)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,8 +37,8 @@ import torch
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ipex_llm.transformers.utils import get_xpu_device_type
 | 
					from ipex_llm.transformers.utils import get_xpu_device_type
 | 
				
			||||||
from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
 | 
					from ipex_llm.transformers.models.common import padding_qkv_hd
 | 
				
			||||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
 | 
					from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
				
			||||||
from diffusers.models.attention_processor import Attention
 | 
					from diffusers.models.attention_processor import Attention
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -110,19 +110,10 @@ class AttnProcessor2_0:
 | 
				
			||||||
        if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
 | 
					        if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
 | 
				
			||||||
            # padding head_dim 40 to 64
 | 
					            # padding head_dim 40 to 64
 | 
				
			||||||
            query, key, value = padding_qkv_hd(query, key, value, 40, 64)
 | 
					            query, key, value = padding_qkv_hd(query, key, value, 40, 64)
 | 
				
			||||||
 | 
					            hidden_states = scaled_dot_product_attention(
 | 
				
			||||||
            if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
 | 
					                query, key.contiguous(), value.contiguous(),
 | 
				
			||||||
                import xe_addons
 | 
					                attention_mask, False, 1 / math.sqrt(head_dim)
 | 
				
			||||||
                hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
 | 
					            )
 | 
				
			||||||
                                                         value.contiguous(), attention_mask)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                scale = 1 / math.sqrt(head_dim)
 | 
					 | 
				
			||||||
                attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
 | 
					 | 
				
			||||||
                if attention_mask is not None:
 | 
					 | 
				
			||||||
                    attn_weights = attn_weights + attention_mask
 | 
					 | 
				
			||||||
                attn_weights = attention_softmax(attn_weights)
 | 
					 | 
				
			||||||
                hidden_states = torch.matmul(attn_weights, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            hidden_states = hidden_states[:, :, :, :head_dim]
 | 
					            hidden_states = hidden_states[:, :, :, :head_dim]
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            hidden_states = torch.nn.functional.scaled_dot_product_attention(
 | 
					            hidden_states = torch.nn.functional.scaled_dot_product_attention(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue