fix phi3 and minicpmv cpu (#11818)
This commit is contained in: parent 4e178f0c5d, commit 828ab16537

3 changed files with 14 additions and 4 deletions
@@ -67,3 +67,13 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
         )
     else:
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
+
+
+def attention_softmax(attn_weights: torch.Tensor, training: bool):
+    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
+        import xe_addons
+        xe_addons.attn_softmax_inplaced(attn_weights)
+    else:
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(attn_weights.dtype)
+    return attn_weights
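The new attention_softmax helper only dispatches to the xe_addons in-place kernel for contiguous tensors on XPU outside of training; everywhere else, including plain CPU runs, it falls back to the regular PyTorch softmax computed in float32 and cast back to the input dtype. A minimal usage sketch of the CPU path (the tensor shape and values are illustrative, not from the patch):

import torch
from ipex_llm.transformers.models.common import attention_softmax

# Illustrative attention scores on a CPU tensor: (batch, heads, q_len, kv_len).
attn_weights = torch.randn(1, 8, 16, 16)

# On CPU (or during training) this takes the torch.nn.functional.softmax branch;
# only contiguous XPU tensors at inference time use xe_addons.attn_softmax_inplaced.
attn_weights = attention_softmax(attn_weights, training=False)

# Each softmax row sums to one (up to float tolerance).
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(1, 8, 16))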
@@ -19,6 +19,7 @@ import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import attention_softmax
 from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
@@ -47,8 +48,7 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
 
-    import xe_addons
-    xe_addons.attn_softmax_inplaced(attn_weights)
+    attn_weights = attention_softmax(attn_weights, self.training)
 
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
@@ -37,6 +37,7 @@ import torch
 import warnings
 from torch import nn
 
+from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -184,8 +185,7 @@ def attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
 
-    import xe_addons
-    xe_addons.attn_softmax_inplaced(attn_weights)
+    attn_weights = attention_softmax(attn_weights, self.training)
 
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
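Both patched call sites now share the same post-score flow: add the attention mask, run the device-aware softmax through the common helper, apply dropout, and multiply by the value states. A condensed sketch of that shared flow using the names from the diff (the wrapper function itself is hypothetical, introduced only to show the pattern in one place):

import torch
from ipex_llm.transformers.models.common import attention_softmax

def softmax_dropout_matmul(attn_weights, attention_mask, value_states,
                           dropout_p, training):
    # Condensed from the two patched forward functions: optional additive mask,
    # device-aware softmax via the shared helper, dropout, then the weighted
    # sum over the value states.
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = attention_softmax(attn_weights, training)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p,
                                               training=training)
    return torch.matmul(attn_weights, value_states)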