diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 971e4349..aa7d89c3 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -775,12 +775,15 @@ class _BaseAutoModelClass:
                                       model)
             torch.distributed.barrier()

-        # add lookup_generate to loaded model
-        from .lookup import lookup_generate
-        import types
-        model.lookup_generate = types.MethodType(lookup_generate, model)
-        if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
-            model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
+        try:
+            # add lookup_generate to loaded model
+            from .lookup import lookup_generate
+            import types
+            model.lookup_generate = types.MethodType(lookup_generate, model)
+            if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
+                model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
+        except ImportError as e:
+            pass

         return model
diff --git a/python/llm/src/ipex_llm/transformers/models/deepseek.py b/python/llm/src/ipex_llm/transformers/models/deepseek.py
index e4d1a033..fc61a78e 100644
--- a/python/llm/src/ipex_llm/transformers/models/deepseek.py
+++ b/python/llm/src/ipex_llm/transformers/models/deepseek.py
@@ -32,7 +32,7 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.transformers.models.common import padding_mla_v_hd_base
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
-from ipex_llm.transformers.models.utils import rotate_half
+from ipex_llm.transformers.models.utils import rotate_half, use_fuse_moe


 def padding_mla_v_hd(module: torch.nn.Module):
@@ -291,11 +291,8 @@ def fuse_gate_forward(self, x: torch.Tensor):


 def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
-    if (
-        x.device.type == "xpu"
-        and x.dtype in [torch.float, torch.half]
-        and self.experts[0].down_proj.qtype == 2
-    ):
+    qtype = self.experts[0].down_proj.qtype
+    if use_fuse_moe(x, qtype):
         if getattr(self, "gates", None) is None:
             gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
             up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
@@ -310,7 +307,7 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
         import xe_linear
         final_out = xe_linear.moe_forward_vec(
             x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
-            x.size(-1), self.experts[0].intermediate_size, 2
+            x.size(-1), self.experts[0].intermediate_size, qtype
         )
     else:
         idxs = topk_ids.flatten().tolist()
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index d43fc51f..30bb5a29 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -394,3 +394,11 @@ def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
     new_sin = sin.contiguous()
     cos.set_(new_cos)
     sin.set_(new_sin)
+
+
+def use_fuse_moe(hidden_states: torch.Tensor, qtype: int):
+    return (
+        hidden_states.device.type == "xpu"
+        and hidden_states.dtype in [torch.float, torch.half]
+        and qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["woq_int4"]]
+    )
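
Note on the model.py change: the method binding is now wrapped in try/except so that a failed import of the optional lookup module no longer aborts model loading. Below is a minimal, self-contained sketch of that pattern; DummyModel and the stand-in lookup_generate are illustrative only, not ipex-llm code.

import types


def lookup_generate(self, *args, **kwargs):
    # Stand-in for the function imported from .lookup in the diff above.
    return f"lookup_generate bound to {type(self).__name__}"


class DummyModel:
    pass


model = DummyModel()

try:
    # Bind the function as an instance method, as the diff does on the loaded model.
    model.lookup_generate = types.MethodType(lookup_generate, model)
except ImportError:
    # In the real code the import of .lookup can fail; loading then proceeds
    # without the lookup_generate attribute instead of raising.
    pass

print(model.lookup_generate())  # lookup_generate bound to DummyModel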
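
Note on the new use_fuse_moe helper: it replaces the hard-coded qtype == 2 comparison in deepseek.py and also admits the woq_int4 format. The sketch below is runnable on its own; the ggml_tensor_qtype mapping is a placeholder (the real constants come from ipex-llm's ggml quantize table, and the woq_int4 code shown here is invented for the example).

import torch

# Placeholder mapping; 2 is the code the old check compared against, 29 is a
# made-up value standing in for the real woq_int4 constant.
ggml_tensor_qtype = {"sym_int4": 2, "woq_int4": 29}


def use_fuse_moe(hidden_states: torch.Tensor, qtype: int) -> bool:
    # Fused MoE is only attempted for fp32/fp16 activations on an XPU device
    # and for the two supported int4 weight formats.
    return (
        hidden_states.device.type == "xpu"
        and hidden_states.dtype in [torch.float, torch.half]
        and qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["woq_int4"]]
    )


x = torch.randn(1, 16, dtype=torch.half)               # CPU tensor
print(use_fuse_moe(x, ggml_tensor_qtype["sym_int4"]))  # False: not on XPU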
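
Note on moe_infer_decode: the refactor reads the expert quantization type once, uses it for the dispatch decision, and passes it through as the last argument of the fused kernel instead of the literal 2. A rough sketch of that dispatch shape with toy experts and no real XPU kernel; everything below is illustrative, not the ipex-llm implementation.

import torch


def use_fuse_moe(hidden_states: torch.Tensor, qtype: int) -> bool:
    # Simplified stand-in for the helper sketched above.
    return hidden_states.device.type == "xpu" and qtype == 2


class ToyExpert(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.down_proj = torch.nn.Linear(dim, dim, bias=False)
        self.down_proj.qtype = 2  # mimic the qtype attribute read in the diff

    def forward(self, x):
        return self.down_proj(x)


def moe_infer_decode(experts, x, topk_ids, topk_weight):
    qtype = experts[0].down_proj.qtype  # read once, reused for dispatch and kernel
    if use_fuse_moe(x, qtype):
        # The real code builds gate/up/down weight address lists and calls
        # xe_linear.moe_forward_vec(..., qtype); omitted here.
        raise RuntimeError("fused path requires an XPU device")
    # Fallback: run each selected expert and mix by its routing weight.
    out = torch.zeros_like(x)
    for slot, expert_id in enumerate(topk_ids.flatten().tolist()):
        out += topk_weight.flatten()[slot] * experts[expert_id](x)
    return out


experts = [ToyExpert(8) for _ in range(4)]
x = torch.randn(1, 8)
topk_ids = torch.tensor([[0, 2]])
topk_weight = torch.tensor([[0.7, 0.3]])
print(moe_infer_decode(experts, x, topk_ids, topk_weight).shape)  # torch.Size([1, 8])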