small refactor and fix (#13101)

parent 14cd613fe1
commit 908fdb982e

3 changed files with 21 additions and 13 deletions
@@ -775,12 +775,15 @@ class _BaseAutoModelClass:
                                                                 model)
             torch.distributed.barrier()
 
         try:
             # add lookup_generate to loaded model
             from .lookup import lookup_generate
             import types
             model.lookup_generate = types.MethodType(lookup_generate, model)
+            if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
+                model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
         except ImportError as e:
             pass
 
         return model
 
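In the hunk above, `types.MethodType` binds a plain function to a specific instance so it can be called like a regular method; the fix applies the same binding to the inner `model.llm` of minicpmv checkpoints, which the wrapper-only binding previously missed. A minimal self-contained sketch of the pattern, with toy classes standing in for the real ipex-llm objects:

```python
import types

class Submodel:
    name = "llm"

class Model:
    name = "wrapper"
    llm = Submodel()

def lookup_generate(self, prompt):
    # 'self' is whatever instance the function was bound to
    return f"{self.name} generating for: {prompt}"

model = Model()
# Bind the same function to the wrapper and to the inner model,
# mirroring what the diff does for minicpmv checkpoints.
model.lookup_generate = types.MethodType(lookup_generate, model)
model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)

print(model.lookup_generate("hi"))      # wrapper generating for: hi
print(model.llm.lookup_generate("hi"))  # llm generating for: hi
```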
@@ -32,7 +32,7 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.kv import DynamicNormalCache
 from ipex_llm.transformers.models.common import padding_mla_v_hd_base
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
-from ipex_llm.transformers.models.utils import rotate_half
+from ipex_llm.transformers.models.utils import rotate_half, use_fuse_moe
 
 
 def padding_mla_v_hd(module: torch.nn.Module):
@@ -291,11 +291,8 @@ def fuse_gate_forward(self, x: torch.Tensor):
 
 
 def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
-    if (
-        x.device.type == "xpu"
-        and x.dtype in [torch.float, torch.half]
-        and self.experts[0].down_proj.qtype == 2
-    ):
+    qtype = self.experts[0].down_proj.qtype
+    if use_fuse_moe(x, qtype):
         if getattr(self, "gates", None) is None:
             gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
             up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
@@ -310,7 +307,7 @@ def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight:
         import xe_linear
         final_out = xe_linear.moe_forward_vec(
             x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
-            x.size(-1), self.experts[0].intermediate_size, 2
+            x.size(-1), self.experts[0].intermediate_size, qtype
         )
     else:
         idxs = topk_ids.flatten().tolist()
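These two hunks are the refactor: the inline device/dtype check with the GGML qtype hardcoded to `2` becomes a call to the shared `use_fuse_moe` helper, and the layer's actual `qtype` is forwarded to `xe_linear.moe_forward_vec` in place of the literal. When the fused XPU kernel does not apply, `moe_infer_decode` falls back to dispatching each token to its selected experts and combining the outputs with the routing weights. A toy re-implementation of that weighted-combine fallback in plain PyTorch (illustrative names and shapes, not the ipex-llm code):

```python
import torch

def moe_combine_fallback(x, topk_ids, topk_weight, experts):
    # x: [tokens, hidden]; topk_ids / topk_weight: [tokens, k]
    out = torch.zeros_like(x)
    for t in range(x.size(0)):
        for i, w in zip(topk_ids[t].tolist(), topk_weight[t].tolist()):
            # route token t through expert i, scaled by its router weight
            out[t] += w * experts[i](x[t])
    return out

experts = [torch.nn.Linear(8, 8) for _ in range(4)]
x = torch.randn(2, 8)
topk_ids = torch.tensor([[0, 2], [1, 3]])
topk_weight = torch.tensor([[0.7, 0.3], [0.5, 0.5]])
print(moe_combine_fallback(x, topk_ids, topk_weight, experts).shape)  # torch.Size([2, 8])
```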
@@ -394,3 +394,11 @@ def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
         new_sin = sin.contiguous()
         cos.set_(new_cos)
         sin.set_(new_sin)
+
+
+def use_fuse_moe(hidden_states: torch.Tensor, qtype: int):
+    return (
+        hidden_states.device.type == "xpu"
+        and hidden_states.dtype in [torch.float, torch.half]
+        and qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["woq_int4"]]
+    )
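A hedged usage sketch of the new predicate. The `ggml_tensor_qtype` mapping is stubbed locally so the snippet runs without ipex-llm installed: `sym_int4 = 2` matches the literal the old check compared against, while the `woq_int4` value is a placeholder, not the real constant.

```python
import torch

# Local stand-in for ipex-llm's ggml_tensor_qtype table (values illustrative).
ggml_tensor_qtype = {"sym_int4": 2, "woq_int4": 28}

def use_fuse_moe(hidden_states: torch.Tensor, qtype: int):
    # Fused MoE kernels apply only to fp32/fp16 activations on an XPU
    # device whose expert weights use one of the supported int4 formats.
    return (
        hidden_states.device.type == "xpu"
        and hidden_states.dtype in [torch.float, torch.half]
        and qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["woq_int4"]]
    )

x_cpu = torch.randn(1, 8)
print(use_fuse_moe(x_cpu, ggml_tensor_qtype["sym_int4"]))  # False: CPU tensor, not XPU
```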