diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py index 030c3cd2..89a8fa67 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py +++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py @@ -143,6 +143,7 @@ def padding_mlp(module: torch.nn.Module): new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :] new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :] # noqa getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2 + getattr(module, mlp_gate_up_name).output_size = padding_intermediate_size * 2 getattr(module, mlp_gate_up_name).weight = \ torch.nn.Parameter(new_gate_up_weight, requires_grad=False) @@ -151,4 +152,5 @@ def padding_mlp(module: torch.nn.Module): dtype=down_weight.dtype, device=down_weight.device) new_down_weight[:, :intermediate_size] = down_weight getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size + getattr(module, mlp_down_name).input_size = padding_intermediate_size getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)