From 30a8680645f022b7f4dc068388dd063166286801 Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:52:55 +0800 Subject: [PATCH] Update for vllm one card padding (#12058) --- python/llm/src/ipex_llm/vllm/xpu/model_convert.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py index 030c3cd2..89a8fa67 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py +++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py @@ -143,6 +143,7 @@ def padding_mlp(module: torch.nn.Module): new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :] new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :] # noqa getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2 + getattr(module, mlp_gate_up_name).output_size = padding_intermediate_size * 2 getattr(module, mlp_gate_up_name).weight = \ torch.nn.Parameter(new_gate_up_weight, requires_grad=False) @@ -151,4 +152,5 @@ def padding_mlp(module: torch.nn.Module): dtype=down_weight.dtype, device=down_weight.device) new_down_weight[:, :intermediate_size] = down_weight getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size + getattr(module, mlp_down_name).input_size = padding_intermediate_size getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)