diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index 23e6fd97..030c3cd2 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -87,8 +87,9 @@ def get_load_function(low_bit):
             scheduler_config=self.scheduler_config,
             cache_config=self.cache_config,
         )
-        if "qwen" in self.model_config.model.lower() and \
-                self.model.model.layers[0].mlp.down_proj.input_size_per_partition % 256 != 0:
+        if "qwen" in self.model_config.model.lower() or \
+                "baichuan" in self.model_config.model.lower() or \
+                "glm" in self.model_config.model.lower():
             self.model.apply(padding_mlp)
         from ipex_llm import optimize_model
         import os
@@ -114,27 +115,40 @@ def get_load_function(low_bit):
 
 
 def padding_mlp(module: torch.nn.Module):
+    mlp_gate_up_name = None
+    mlp_down_name = None
     if isinstance(module, Qwen2MLP):
-        hidden_size = module.down_proj.output_size
-        # devide by rank
-        intermediate_size = module.down_proj.input_size_per_partition
-        padding_size = 256
-        padding_intermediate_size = \
-            (intermediate_size + padding_size - 1) // padding_size * padding_size
-        if intermediate_size % padding_size == 0:
-            return
-        gate_up_weight = module.gate_up_proj.weight.data
-        new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
-                                         dtype=gate_up_weight.dtype, device=gate_up_weight.device)
-        # merge_gate_up_weight
-        new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
-        new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
-        module.gate_up_proj.output_size_per_partition = padding_intermediate_size * 2
-        module.gate_up_proj.weight = torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
+        mlp_gate_up_name = "gate_up_proj"
+        mlp_down_name = "down_proj"
+    elif isinstance(module, GLMMLP):
+        mlp_gate_up_name = "dense_h_to_4h"
+        mlp_down_name = "dense_4h_to_h"
+    elif isinstance(module, BaiChuanMLP):
+        mlp_gate_up_name = "gate_up_proj"
+        mlp_down_name = "down_proj"
+    else:
+        return
+    hidden_size = getattr(module, mlp_down_name).output_size
+    # devide by rank
+    intermediate_size = getattr(module, mlp_down_name).input_size_per_partition
+    padding_size = 256
+    padding_intermediate_size = \
+        (intermediate_size + padding_size - 1) // padding_size * padding_size
+    if intermediate_size % padding_size == 0:
+        return
+    gate_up_weight = getattr(module, mlp_gate_up_name).weight.data
+    new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
+                                     dtype=gate_up_weight.dtype, device=gate_up_weight.device)
+    # merge_gate_up_weight
+    new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
+    new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
+    getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2
+    getattr(module, mlp_gate_up_name).weight = \
+        torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
 
-        down_weight = module.down_proj.weight.data
-        new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
-                                      dtype=down_weight.dtype, device=down_weight.device)
-        new_down_weight[:, :intermediate_size] = down_weight
-        module.down_proj.input_size_per_partition = padding_intermediate_size
-        module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
+    down_weight = getattr(module, mlp_down_name).weight.data
+    new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
+                                  dtype=down_weight.dtype, device=down_weight.device)
+    new_down_weight[:, :intermediate_size] = down_weight
+    getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size
+    getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
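
For reference, a minimal standalone sketch (not part of the patch) of the padding scheme the reworked padding_mlp applies: the per-rank intermediate size is rounded up to the next multiple of 256, and both the merged gate/up projection and the down projection weights are zero-padded to the new size. This should leave the MLP output unchanged, since the padded slots only ever carry zeros. The helper name pad_intermediate and the toy sizes below are hypothetical and only for illustration.

import torch

def pad_intermediate(gate_up_weight: torch.Tensor,
                     down_weight: torch.Tensor,
                     padding_size: int = 256):
    # down projection weight has shape [hidden_size, intermediate_size]
    hidden_size, intermediate_size = down_weight.shape
    padded = (intermediate_size + padding_size - 1) // padding_size * padding_size
    if intermediate_size % padding_size == 0:
        return gate_up_weight, down_weight
    # gate_up stacks gate and up along dim 0: [2 * intermediate_size, hidden_size];
    # each half is copied into the start of its padded slot, the rest stays zero
    new_gate_up = torch.zeros(padded * 2, hidden_size,
                              dtype=gate_up_weight.dtype, device=gate_up_weight.device)
    new_gate_up[:intermediate_size] = gate_up_weight[:intermediate_size]
    new_gate_up[padded:padded + intermediate_size] = gate_up_weight[intermediate_size:]
    # down projection gets zero columns for the padded intermediate positions
    new_down = torch.zeros(hidden_size, padded,
                           dtype=down_weight.dtype, device=down_weight.device)
    new_down[:, :intermediate_size] = down_weight
    return new_gate_up, new_down

# Toy shapes: intermediate size 300 is padded up to 512.
gate_up = torch.randn(2 * 300, 64)
down = torch.randn(64, 300)
padded_gate_up, padded_down = pad_intermediate(gate_up, down)
print(padded_gate_up.shape, padded_down.shape)  # torch.Size([1024, 64]) torch.Size([64, 512])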