diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 8b0ba92f..3e332878 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -51,6 +51,69 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
         del module.q_proj, module.k_proj, module.v_proj
 
 
+def padding_linear_hd(linear: torch.nn.Linear,
+                      old_head_dim: int, new_head_dim: int) -> torch.nn.Linear:
+    in_features, out_features = linear.in_features, linear.out_features
+
+    weight = linear.weight.data
+    weight = weight.view(-1, old_head_dim, in_features)
+    new_weight = torch.empty([weight.size(0), new_head_dim, in_features],
+                             dtype=weight.dtype, device=weight.device)
+    new_weight[:, :old_head_dim, :] = weight
+    new_weight[:, old_head_dim:, :] = 0
+    new_weight = new_weight.view(-1, in_features)
+    if linear.bias is not None:
+        bias = linear.bias.data
+        bias = bias.view(-1, old_head_dim)
+        new_bias = torch.empty([bias.size(0), new_head_dim],
+                               dtype=bias.dtype, device=bias.device)
+        new_bias[:, :old_head_dim] = bias
+        new_bias[:, old_head_dim:] = 0
+        new_bias = new_bias.flatten()
+
+        new_linear = torch.nn.Linear(0, 0, bias=True)
+        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    else:
+        new_linear = torch.nn.Linear(0, 0, bias=False)
+    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    new_linear.in_features = new_weight.size(1)
+    new_linear.out_features = new_weight.size(0)
+    return new_linear
+
+
+def padding_attention_hd_base(module: torch.nn.Module, attention_class,
+                              old_head_dim: int, new_head_dim: int):
+    if (
+        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
+        or not isinstance(attention_class, str) and isinstance(module, attention_class)
+    ) and module.head_dim == old_head_dim:
+        module.q_proj = padding_linear_hd(module.q_proj, old_head_dim, new_head_dim)
+        module.k_proj = padding_linear_hd(module.k_proj, old_head_dim, new_head_dim)
+        module.v_proj = padding_linear_hd(module.v_proj, old_head_dim, new_head_dim)
+        module.head_dim = new_head_dim
+        module.old_head_dim = old_head_dim
+
+
+def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
+    bsz, num_heads, seq_len, head_dim = states.size()
+    if head_dim == old_head_dim and old_head_dim < new_head_dim:
+        new_states = torch.empty([bsz, num_heads, seq_len, new_head_dim],
+                                 dtype=states.dtype, device=states.device)
+        new_states[:, :, :, :old_head_dim] = states
+        new_states[:, :, :, old_head_dim:] = 0
+        return new_states
+    return states
+
+
+def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                   old_head_dim: int, new_head_dim: int):
+    return (
+        padding_states_hd(q, old_head_dim, new_head_dim),
+        padding_states_hd(k, old_head_dim, new_head_dim),
+        padding_states_hd(v, old_head_dim, new_head_dim),
+    )
+
+
 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
     from ipex_llm.transformers.models.utils import mlp_fusion_check
     x_2d = x.view(-1, x.size(-1))
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 9cad8fc8..0bb0b643 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -26,8 +26,9 @@ import torch
 from threading import Thread
 from typing import Optional, List
 from torch.nn.functional import linear
-from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
 from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from transformers import AutoProcessor, TextIteratorStreamer
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
@@ -52,14 +53,28 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
-    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
+    query_states, key_states, value_states = padding_qkv_hd(
+        query_states, key_states, value_states,
+        72, 80
+    )
 
-    attn_weights = attention_softmax(attn_weights)
+    if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
+                                               value_states.contiguous(), attention_mask)
+    else:
+        attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
 
-    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)
+        attn_weights = attention_softmax(attn_weights)
+
+        attn_weights = torch.nn.functional.dropout(attn_weights,
+                                                   p=self.dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output[:, :, :, :self.head_dim]
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
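
Note (not part of the patch): a minimal sketch of the padding scheme the helpers above implement, assuming the padding_qkv_hd semantics shown in the diff. Query/key/value states with head_dim 72 are zero-padded to 80 so a fused SDP kernel can be used, and the attention output is sliced back to the original head_dim. torch.nn.functional.scaled_dot_product_attention stands in here for the Intel-specific xe_addons.sdp_non_causal kernel, and pad_head_dim is an illustrative helper, not an ipex-llm API.

# Illustrative sketch only: zero-pad head_dim 72 -> 80, run attention on the
# padded tensors, then slice the output back to 72. PyTorch's
# scaled_dot_product_attention (with its explicit `scale` kwarg, PyTorch >= 2.1)
# substitutes for the xe_addons.sdp_non_causal kernel used in the patch.
import torch
import torch.nn.functional as F


def pad_head_dim(states: torch.Tensor, old_head_dim: int, new_head_dim: int) -> torch.Tensor:
    # states: [bsz, num_heads, seq_len, head_dim]; extra columns are zero-filled,
    # mirroring the behaviour of padding_states_hd in the diff.
    if states.size(-1) == old_head_dim and old_head_dim < new_head_dim:
        return F.pad(states, (0, new_head_dim - old_head_dim))
    return states


bsz, num_heads, seq_len, old_hd, new_hd = 1, 16, 1024, 72, 80
q = torch.randn(bsz, num_heads, seq_len, old_hd)
k = torch.randn(bsz, num_heads, seq_len, old_hd)
v = torch.randn(bsz, num_heads, seq_len, old_hd)

q, k, v = (pad_head_dim(t, old_hd, new_hd) for t in (q, k, v))

# Zero padding leaves the q @ k^T logits unchanged, so keep the original scale;
# the padded value columns only produce zero output columns, sliced off below.
out = F.scaled_dot_product_attention(q, k, v, scale=old_hd ** -0.5)
out = out[..., :old_hd]  # back to the original head_dim of 72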