small refactor and fix (#13101)

parent 14cd613fe1
commit 908fdb982e

3 changed files with 21 additions and 13 deletions
@@ -775,12 +775,15 @@ class _BaseAutoModelClass:
                                                  model)
             torch.distributed.barrier()
 
-        # add lookup_generate to loaded model
-        from .lookup import lookup_generate
-        import types
-        model.lookup_generate = types.MethodType(lookup_generate, model)
-        if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
-            model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
+        try:
+            # add lookup_generate to loaded model
+            from .lookup import lookup_generate
+            import types
+            model.lookup_generate = types.MethodType(lookup_generate, model)
+            if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
+                model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
+        except ImportError as e:
+            pass
 
         return model
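A note on the pattern in this hunk: types.MethodType binds a plain function to an existing object so it can be called like a bound method, which is how lookup_generate is attached to an already-constructed model (and to model.llm for minicpmv). A minimal standalone sketch with toy names, not taken from the repo:

import types

def greet(self, name):
    # 'self' is whatever instance the function gets bound to
    return f"{self.prefix} {name}"

class Greeter:
    prefix = "hello"

g = Greeter()
g.greet = types.MethodType(greet, g)   # bind to this one instance
print(g.greet("world"))                # -> hello world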
@@ -32,7 +32,7 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.transformers.models.common import padding_mla_v_hd_base
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
-from ipex_llm.transformers.models.utils import rotate_half
+from ipex_llm.transformers.models.utils import rotate_half, use_fuse_moe
 
 
 def padding_mla_v_hd(module: torch.nn.Module):
@@ -291,11 +291,8 @@ def fuse_gate_forward(self, x: torch.Tensor):
 
 
 def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
-    if (
-        x.device.type == "xpu"
-        and x.dtype in [torch.float, torch.half]
-        and self.experts[0].down_proj.qtype == 2
-    ):
+    qtype = self.experts[0].down_proj.qtype
+    if use_fuse_moe(x, qtype):
         if getattr(self, "gates", None) is None:
             gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
             up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
@@ -310,7 +307,7 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
         import xe_linear
         final_out = xe_linear.moe_forward_vec(
             x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
-            x.size(-1), self.experts[0].intermediate_size, 2
+            x.size(-1), self.experts[0].intermediate_size, qtype
         )
     else:
         idxs = topk_ids.flatten().tolist()
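For orientation, the else branch beginning with idxs = topk_ids.flatten().tolist() is the unfused fallback that routes each token through its selected experts one at a time. A generic illustration of that kind of per-expert fallback follows; it is a sketch with assumed shapes and names, not the repo's actual implementation:

import torch

def moe_decode_fallback(x, topk_ids, topk_weight, experts):
    # x: (num_tokens, hidden); topk_ids / topk_weight: (num_tokens, top_k)
    out = torch.zeros_like(x)
    top_k = topk_ids.size(-1)
    idxs = topk_ids.flatten().tolist()
    weights = topk_weight.flatten().tolist()
    for slot, (expert_id, w) in enumerate(zip(idxs, weights)):
        token = slot // top_k                        # token this slot belongs to
        out[token] += w * experts[expert_id](x[token])
    return out

experts = [torch.nn.Linear(8, 8) for _ in range(4)]
x = torch.randn(2, 8)
topk_weight, topk_ids = torch.randn(2, 4).softmax(-1).topk(2, dim=-1)
with torch.no_grad():
    print(moe_decode_fallback(x, topk_ids, topk_weight, experts).shape)   # torch.Size([2, 8])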
@@ -394,3 +394,11 @@ def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
     new_sin = sin.contiguous()
     cos.set_(new_cos)
     sin.set_(new_sin)
+
+
+def use_fuse_moe(hidden_states: torch.Tensor, qtype: int):
+    return (
+        hidden_states.device.type == "xpu"
+        and hidden_states.dtype in [torch.float, torch.half]
+        and qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["woq_int4"]]
+    )
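The new helper centralizes the device/dtype/qtype check that moe_infer_decode previously inlined, and the call site now passes the experts' actual qtype to xe_linear.moe_forward_vec instead of the hard-coded 2. A self-contained sketch of the predicate; the ggml_tensor_qtype values below are placeholders for illustration, the real ids come from ipex-llm's quantization table:

import torch

# Placeholder mapping for illustration only; sym_int4 is 2 (the value the old
# inline check compared against), the woq_int4 id here is assumed.
ggml_tensor_qtype = {"sym_int4": 2, "woq_int4": -1}

def use_fuse_moe(hidden_states: torch.Tensor, qtype: int):
    return (
        hidden_states.device.type == "xpu"
        and hidden_states.dtype in [torch.float, torch.half]
        and qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["woq_int4"]]
    )

x = torch.randn(4, 16)                                 # CPU tensor, so the fused path is skipped
print(use_fuse_moe(x, ggml_tensor_qtype["sym_int4"]))  # False off-XPU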