small refactor and fix (#13101)
This commit is contained in:
parent
14cd613fe1
commit
908fdb982e
3 changed files with 21 additions and 13 deletions
|
|
@ -775,12 +775,15 @@ class _BaseAutoModelClass:
|
|||
model)
|
||||
torch.distributed.barrier()
|
||||
|
||||
# add lookup_generate to loaded model
|
||||
from .lookup import lookup_generate
|
||||
import types
|
||||
model.lookup_generate = types.MethodType(lookup_generate, model)
|
||||
if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
|
||||
model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
|
||||
try:
|
||||
# add lookup_generate to loaded model
|
||||
from .lookup import lookup_generate
|
||||
import types
|
||||
model.lookup_generate = types.MethodType(lookup_generate, model)
|
||||
if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
|
||||
model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from ipex_llm.utils.common.log4Error import invalidInputError
|
|||
from ipex_llm.transformers.kv import DynamicNormalCache
|
||||
from ipex_llm.transformers.models.common import padding_mla_v_hd_base
|
||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||
from ipex_llm.transformers.models.utils import rotate_half
|
||||
from ipex_llm.transformers.models.utils import rotate_half, use_fuse_moe
|
||||
|
||||
|
||||
def padding_mla_v_hd(module: torch.nn.Module):
|
||||
|
|
@ -291,11 +291,8 @@ def fuse_gate_forward(self, x: torch.Tensor):
|
|||
|
||||
|
||||
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
|
||||
if (
|
||||
x.device.type == "xpu"
|
||||
and x.dtype in [torch.float, torch.half]
|
||||
and self.experts[0].down_proj.qtype == 2
|
||||
):
|
||||
qtype = self.experts[0].down_proj.qtype
|
||||
if use_fuse_moe(x, qtype):
|
||||
if getattr(self, "gates", None) is None:
|
||||
gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
|
||||
up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
|
||||
|
|
@ -310,7 +307,7 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
|
|||
import xe_linear
|
||||
final_out = xe_linear.moe_forward_vec(
|
||||
x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
|
||||
x.size(-1), self.experts[0].intermediate_size, 2
|
||||
x.size(-1), self.experts[0].intermediate_size, qtype
|
||||
)
|
||||
else:
|
||||
idxs = topk_ids.flatten().tolist()
|
||||
|
|
|
|||
|
|
@ -394,3 +394,11 @@ def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
|
|||
new_sin = sin.contiguous()
|
||||
cos.set_(new_cos)
|
||||
sin.set_(new_sin)
|
||||
|
||||
|
||||
def use_fuse_moe(hidden_states: torch.Tensor, qtype: int):
|
||||
return (
|
||||
hidden_states.device.type == "xpu"
|
||||
and hidden_states.dtype in [torch.float, torch.half]
|
||||
and qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["woq_int4"]]
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue