small refactor and fix (#13101)

Yishuo Wang 2025-04-22 14:45:31 +08:00, committed by GitHub
parent 14cd613fe1
commit 908fdb982e
3 changed files with 21 additions and 13 deletions


@@ -775,12 +775,15 @@ class _BaseAutoModelClass:
                                               model)
             torch.distributed.barrier()
 
         try:
             # add lookup_generate to loaded model
             from .lookup import lookup_generate
             import types
             model.lookup_generate = types.MethodType(lookup_generate, model)
+            if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
+                model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
         except ImportError as e:
             pass
 
         return model
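
Note on the fix above: MiniCPM-V wraps its language model in a model.llm attribute, and generation calls are delegated to that inner model, so binding lookup_generate only on the outer wrapper left the inner LLM without it. A minimal usage sketch (the checkpoint path is a placeholder and the loading arguments are assumptions, not taken from this diff):

    # Hypothetical sketch: after this commit, a low-bit MiniCPM-V model
    # exposes lookup_generate on both the wrapper and its inner LLM.
    from ipex_llm.transformers import AutoModel

    model = AutoModel.load_low_bit("path/to/minicpmv-low-bit",  # placeholder path
                                   trust_remote_code=True)
    assert hasattr(model, "lookup_generate")
    assert hasattr(model.llm, "lookup_generate")  # newly bound by this commit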


@@ -32,7 +32,7 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.transformers.models.common import padding_mla_v_hd_base
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
-from ipex_llm.transformers.models.utils import rotate_half
+from ipex_llm.transformers.models.utils import rotate_half, use_fuse_moe
 
 
 def padding_mla_v_hd(module: torch.nn.Module):
@@ -291,11 +291,8 @@ def fuse_gate_forward(self, x: torch.Tensor):
 def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
-    if (
-        x.device.type == "xpu"
-        and x.dtype in [torch.float, torch.half]
-        and self.experts[0].down_proj.qtype == 2
-    ):
+    qtype = self.experts[0].down_proj.qtype
+    if use_fuse_moe(x, qtype):
         if getattr(self, "gates", None) is None:
             gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
             up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
@@ -310,7 +307,7 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
             import xe_linear
             final_out = xe_linear.moe_forward_vec(
                 x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
-                x.size(-1), self.experts[0].intermediate_size, 2
+                x.size(-1), self.experts[0].intermediate_size, qtype
             )
         else:
             idxs = topk_ids.flatten().tolist()
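
Note on the refactor above: the inline eligibility check hardcoded qtype == 2 and passed the literal 2 down to the kernel; the new use_fuse_moe helper centralizes the check, widens it to both int4 variants, and the real qtype is now threaded through to xe_linear.moe_forward_vec. An illustrative comparison (assuming, as the new helper implies, that 2 is the ggml qtype id for sym_int4):

    # Illustrative only: the old inline condition is the sym_int4 special
    # case of the new helper; woq_int4 is newly accepted.
    old_ok = (x.device.type == "xpu"
              and x.dtype in [torch.float, torch.half]
              and qtype == 2)            # 2 == ggml_tensor_qtype["sym_int4"]
    new_ok = use_fuse_moe(x, qtype)      # sym_int4 or woq_int4, on xpu only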


@@ -394,3 +394,11 @@ def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
     new_sin = sin.contiguous()
     cos.set_(new_cos)
     sin.set_(new_sin)
+
+
+def use_fuse_moe(hidden_states: torch.Tensor, qtype: int):
+    return (
+        hidden_states.device.type == "xpu"
+        and hidden_states.dtype in [torch.float, torch.half]
+        and qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["woq_int4"]]
+    )
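
A quick sanity sketch of the new helper (the ggml_tensor_qtype import path is assumed from the ipex-llm layout, and it must already be in scope inside utils.py for the added function to run):

    import torch
    from ipex_llm.transformers.models.utils import use_fuse_moe
    from ipex_llm.ggml.quantize import ggml_tensor_qtype  # assumed import path

    x = torch.randn(2, 8, dtype=torch.half)                # CPU tensor
    print(use_fuse_moe(x, ggml_tensor_qtype["sym_int4"]))  # False: not on xpu
    # On an Intel GPU build, an xpu tensor with sym_int4 or woq_int4 qualifies:
    # print(use_fuse_moe(x.to("xpu"), ggml_tensor_qtype["woq_int4"]))  # True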