Add vllm glm and baichuan padding (#12053)

Wang, Jian4 2024-09-10 15:57:28 +08:00 committed by GitHub
parent 69c8d36f16
commit 5d3ab16a80

@@ -87,8 +87,9 @@ def get_load_function(low_bit):
             scheduler_config=self.scheduler_config,
             cache_config=self.cache_config,
         )
-        if "qwen" in self.model_config.model.lower() and \
-                self.model.model.layers[0].mlp.down_proj.input_size_per_partition % 256 != 0:
+        if "qwen" in self.model_config.model.lower() or \
+                "baichuan" in self.model_config.model.lower() or \
+                "glm" in self.model_config.model.lower():
             self.model.apply(padding_mlp)
         from ipex_llm import optimize_model
         import os
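
The new condition only checks the model family; the old `% 256` alignment test on the first layer is no longer needed here because padding_mlp itself returns early when the per-rank intermediate size is already a multiple of 256. A minimal, self-contained sketch of the dispatch (toy module and a hypothetical checkpoint name, not the ipex-llm code): torch.nn.Module.apply calls the given function on every submodule, so padding_mlp only has to recognise the MLP classes it cares about and ignore everything else.

import torch

def visit(module: torch.nn.Module):
    # Stand-in for padding_mlp: the real function pads Qwen2MLP / GLMMLP /
    # BaiChuanMLP instances and simply returns for every other module type.
    print(type(module).__name__)

model_name = "THUDM/glm-4-9b-chat"  # hypothetical model id, for illustration only
if "qwen" in model_name.lower() or \
        "baichuan" in model_name.lower() or \
        "glm" in model_name.lower():
    toy = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 8))
    toy.apply(visit)  # prints: Linear, Linear, Sequential (children first, then the root)
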
@@ -114,27 +115,40 @@ def get_load_function(low_bit):
 def padding_mlp(module: torch.nn.Module):
+    mlp_gate_up_name = None
+    mlp_down_name = None
     if isinstance(module, Qwen2MLP):
-        hidden_size = module.down_proj.output_size
+        mlp_gate_up_name = "gate_up_proj"
+        mlp_down_name = "down_proj"
+    elif isinstance(module, GLMMLP):
+        mlp_gate_up_name = "dense_h_to_4h"
+        mlp_down_name = "dense_4h_to_h"
+    elif isinstance(module, BaiChuanMLP):
+        mlp_gate_up_name = "gate_up_proj"
+        mlp_down_name = "down_proj"
+    else:
+        return
+    hidden_size = getattr(module, mlp_down_name).output_size
     # devide by rank
-    intermediate_size = module.down_proj.input_size_per_partition
+    intermediate_size = getattr(module, mlp_down_name).input_size_per_partition
     padding_size = 256
     padding_intermediate_size = \
         (intermediate_size + padding_size - 1) // padding_size * padding_size
     if intermediate_size % padding_size == 0:
         return
-    gate_up_weight = module.gate_up_proj.weight.data
+    gate_up_weight = getattr(module, mlp_gate_up_name).weight.data
     new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
                                      dtype=gate_up_weight.dtype, device=gate_up_weight.device)
     # merge_gate_up_weight
     new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
     new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
-    module.gate_up_proj.output_size_per_partition = padding_intermediate_size * 2
-    module.gate_up_proj.weight = torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
+    getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2
+    getattr(module, mlp_gate_up_name).weight = \
+        torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
 
-    down_weight = module.down_proj.weight.data
+    down_weight = getattr(module, mlp_down_name).weight.data
     new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
                                   dtype=down_weight.dtype, device=down_weight.device)
     new_down_weight[:, :intermediate_size] = down_weight
-    module.down_proj.input_size_per_partition = padding_intermediate_size
-    module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
+    getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size
+    getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
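
For reference, a minimal standalone sketch of the padding scheme above, using plain torch tensors and a SwiGLU-style MLP in place of vLLM's parallel linear layers (the shapes, the silu activation, and all names here are assumptions for illustration, not the ipex-llm implementation). The fused gate_up weight holds the gate rows followed by the up rows, so each half is copied into its zero-padded slot, and the down projection is zero-padded along its input (column) dimension. The padded gate/up rows produce silu(0) * 0 = 0 activations and the padded input columns of the down projection contribute nothing, so the padded MLP reproduces the original output while its per-rank intermediate size becomes a multiple of 256.

import torch
import torch.nn.functional as F

hidden_size, intermediate_size, padding_size = 64, 300, 256
# Round the per-rank intermediate size up to the next multiple of 256.
padded_size = (intermediate_size + padding_size - 1) // padding_size * padding_size  # 512

# Fused [gate; up] weight and down-projection weight, Linear-style [out, in] layout.
gate_up = torch.randn(intermediate_size * 2, hidden_size)
down = torch.randn(hidden_size, intermediate_size)

# Zero-pad exactly as padding_mlp does: each half of gate_up goes into its padded
# slot, and the down weight is padded along its input dimension.
new_gate_up = torch.zeros(padded_size * 2, hidden_size)
new_gate_up[:intermediate_size] = gate_up[:intermediate_size]
new_gate_up[padded_size:padded_size + intermediate_size] = gate_up[intermediate_size:]
new_down = torch.zeros(hidden_size, padded_size)
new_down[:, :intermediate_size] = down

def swiglu_mlp(x, gate_up_w, down_w, inter):
    # Split the fused projection into gate and up halves, apply SwiGLU, project down.
    gate, up = (x @ gate_up_w.t()).split(inter, dim=-1)
    return (F.silu(gate) * up) @ down_w.t()

x = torch.randn(2, hidden_size)
ref = swiglu_mlp(x, gate_up, down, intermediate_size)
pad = swiglu_mlp(x, new_gate_up, new_down, padded_size)
# True: the padding is numerically a no-op, up to float32 rounding.
print(torch.allclose(ref, pad, rtol=1e-4, atol=1e-3))
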