From 828ab1653751ac0c4d76f9b89bc5d1fe4b9d5db9 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 15 Aug 2024 17:43:29 +0800
Subject: [PATCH] fix phi3 and minicpmv cpu (#11818)

---
 python/llm/src/ipex_llm/transformers/models/common.py | 10 ++++++++++
 .../llm/src/ipex_llm/transformers/models/minicpmv.py  |  4 ++--
 python/llm/src/ipex_llm/transformers/models/phi3.py   |  4 ++--
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index f3dab652..215232e4 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -67,3 +67,13 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
         )
     else:
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
+
+
+def attention_softmax(attn_weights: torch.Tensor, training: bool):
+    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
+        import xe_addons
+        xe_addons.attn_softmax_inplaced(attn_weights)
+    else:
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(attn_weights.dtype)
+    return attn_weights
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 15bf61d4..f25801be 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -19,6 +19,7 @@ import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import attention_softmax
 from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
@@ -47,8 +48,7 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
 
-    import xe_addons
-    xe_addons.attn_softmax_inplaced(attn_weights)
+    attn_weights = attention_softmax(attn_weights, self.training)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index c3e73b60..04ed59af 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -37,6 +37,7 @@ import torch
 import warnings
 from torch import nn
 
+from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
@@ -184,8 +185,7 @@ def attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
 
-    import xe_addons
-    xe_addons.attn_softmax_inplaced(attn_weights)
+    attn_weights = attention_softmax(attn_weights, self.training)
 
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
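
Note (not part of the patch): the new attention_softmax helper only takes the fused
xe_addons in-place path for contiguous XPU tensors outside of training; on any other
device it falls back to the stock PyTorch softmax, which is what restores the phi3 and
minicpmv CPU path. A minimal sketch of the expected CPU behaviour, assuming only torch
and ipex_llm are installed (tensor shapes are illustrative, not taken from the models):

    import torch
    from ipex_llm.transformers.models.common import attention_softmax

    # illustrative attention scores: (batch, heads, q_len, kv_len) on CPU
    attn_weights = torch.randn(1, 32, 16, 16, dtype=torch.float32)

    out = attention_softmax(attn_weights, training=False)

    # on CPU the helper should reduce to a float32 softmax cast back to the input dtype
    ref = torch.nn.functional.softmax(attn_weights, dim=-1,
                                      dtype=torch.float32).to(attn_weights.dtype)
    assert torch.allclose(out, ref)
    # rows of a softmax sum to one
    assert torch.allclose(out.sum(dim=-1), torch.ones(1, 32, 16))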