fix phi3 and minicpmv cpu (#11818)

Yishuo Wang 2024-08-15 17:43:29 +08:00 committed by GitHub
parent 4e178f0c5d
commit 828ab16537
3 changed files with 14 additions and 4 deletions

@@ -67,3 +67,13 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
         )
     else:
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
+
+
+def attention_softmax(attn_weights: torch.Tensor, training: bool):
+    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
+        import xe_addons
+        xe_addons.attn_softmax_inplaced(attn_weights)
+    else:
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(attn_weights.dtype)
+    return attn_weights
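
The new attention_softmax helper above is the actual CPU fix: it only imports xe_addons (Intel's XPU-only extension) when the attention scores live on an XPU device and the module is not training, and otherwise falls back to a plain float32 torch softmax. A minimal usage sketch on a CPU-only machine, assuming ipex-llm at this commit is installed (the tensor shape is illustrative):

import torch
from ipex_llm.transformers.models.common import attention_softmax

# (batch, heads, q_len, kv_len) attention scores on a CPU tensor
scores = torch.randn(1, 8, 16, 16)
# The call sites used to import xe_addons unconditionally and failed on CPU;
# the helper now takes the torch.nn.functional.softmax fallback instead.
probs = attention_softmax(scores, training=False)
print(probs.sum(dim=-1))  # each row sums to 1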

@@ -19,6 +19,7 @@ import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import attention_softmax
 from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
@@ -47,8 +48,7 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    import xe_addons
-    xe_addons.attn_softmax_inplaced(attn_weights)
+    attn_weights = attention_softmax(attn_weights, self.training)

     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)

@@ -37,6 +37,7 @@ import torch
 import warnings
 from torch import nn
+from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -184,8 +185,7 @@ def attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    import xe_addons
-    xe_addons.attn_softmax_inplaced(attn_weights)
+    attn_weights = attention_softmax(attn_weights, self.training)

     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
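
Both the SigLIP attention in MiniCPM-V and the Phi-3 attention forward now route their softmax through the shared attention_softmax helper, so xe_addons is only touched on XPU devices. A sketch of the common tail the two forwards now follow, with made-up shapes and dropout probability rather than the modules' real configuration:

import torch
from ipex_llm.transformers.models.common import attention_softmax

query = torch.randn(1, 8, 16, 64)   # (batch, heads, q_len, head_dim)
key = torch.randn(1, 8, 16, 64)
value = torch.randn(1, 8, 16, 64)

attn_weights = torch.matmul(query, key.transpose(-1, -2)) / 64 ** 0.5
attn_weights = attention_softmax(attn_weights, False)      # CPU-safe softmax
attn_weights = torch.nn.functional.dropout(attn_weights, p=0.0, training=False)
attn_output = torch.matmul(attn_weights, value)             # (1, 8, 16, 64)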