fix phi3 and minicpmv cpu (#11818)
parent 4e178f0c5d
commit 828ab16537
3 changed files with 14 additions and 4 deletions
@@ -67,3 +67,13 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
         )
     else:
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
+
+
+def attention_softmax(attn_weights: torch.Tensor, training: bool):
+    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
+        import xe_addons
+        xe_addons.attn_softmax_inplaced(attn_weights)
+    else:
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(attn_weights.dtype)
+    return attn_weights
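The new attention_softmax helper only uses the in-place xe_addons kernel for contiguous tensors on an XPU device outside of training; on CPU it falls back to a plain float32 softmax, which is the path this commit fixes. A minimal usage sketch, assuming ipex-llm is installed and the code runs on CPU (the tensor shape is a placeholder, not a real model's):

import torch
from ipex_llm.transformers.models.common import attention_softmax

# Fake attention scores on CPU: the device check fails, so no xe_addons import
# is attempted and the torch.nn.functional.softmax fallback runs instead.
scores = torch.randn(1, 8, 16, 16)
probs = attention_softmax(scores, training=False)
print(probs.sum(dim=-1))  # each row sums to ~1.0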
@@ -19,6 +19,7 @@ import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import attention_softmax
 from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
@@ -47,8 +48,7 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
 
-    import xe_addons
-    xe_addons.attn_softmax_inplaced(attn_weights)
+    attn_weights = attention_softmax(attn_weights, self.training)
 
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
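Both call sites (the SigLIP vision tower of MiniCPM-V above, and Phi-3 below) now delegate to the shared helper instead of importing xe_addons unconditionally, which is what broke CPU inference. A hedged sketch of the resulting call-site pattern, with placeholder shapes rather than the models' real head dimensions:

import torch
from ipex_llm.transformers.models.common import attention_softmax

query_states = torch.randn(1, 12, 32, 64)
key_states = torch.randn(1, 12, 32, 64)
value_states = torch.randn(1, 12, 32, 64)

# Same flow as the patched forwards: scores -> shared softmax helper -> weighted sum.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / (64 ** 0.5)
attn_weights = attention_softmax(attn_weights, training=False)  # the real code passes self.training
attn_output = torch.matmul(attn_weights, value_states)

Passing self.training through means the in-place kernel is skipped during training and the standard out-of-place softmax is used instead.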
@@ -37,6 +37,7 @@ import torch
 import warnings
 from torch import nn
 
+from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -184,8 +185,7 @@ def attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
 
-    import xe_addons
-    xe_addons.attn_softmax_inplaced(attn_weights)
+    attn_weights = attention_softmax(attn_weights, self.training)
 
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
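The helper's guard is not only about the device: a non-contiguous tensor, or training mode, also routes to the PyTorch fallback. A hypothetical illustration, assuming ipex-llm is installed (the strided view is just a convenient way to obtain a non-contiguous tensor):

import torch
from ipex_llm.transformers.models.common import attention_softmax

w = torch.randn(2, 4, 9, 18)[:, :, :, ::2]   # strided view -> not contiguous
assert not w.is_contiguous()
probs = attention_softmax(w, training=True)   # contiguity/training checks fail -> plain softmax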