Add vllm glm and baichuan padding (#12053)
This commit is contained in:
parent 69c8d36f16
commit 5d3ab16a80

1 changed file with 38 additions and 24 deletions
@@ -87,8 +87,9 @@ def get_load_function(low_bit):
             scheduler_config=self.scheduler_config,
             cache_config=self.cache_config,
         )
-        if "qwen" in self.model_config.model.lower() and \
-                self.model.model.layers[0].mlp.down_proj.input_size_per_partition % 256 != 0:
+        if "qwen" in self.model_config.model.lower() or \
+                "baichuan" in self.model_config.model.lower() or \
+                "glm" in self.model_config.model.lower():
             self.model.apply(padding_mlp)
         from ipex_llm import optimize_model
         import os
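For context on the hunk above: self.model.apply(padding_mlp) relies on torch.nn.Module.apply, which calls the given function on every submodule recursively, so padding_mlp only has to recognise the MLP classes it supports and return early for everything else. A minimal, self-contained sketch of that mechanism (TinyMLP, TinyModel and visit are hypothetical names used only for illustration, not part of this patch):

import torch

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([TinyMLP() for _ in range(2)])

def visit(module: torch.nn.Module):
    # Stand-in for padding_mlp: act only on module types we recognise
    # and silently skip everything else that Module.apply hands us.
    if isinstance(module, TinyMLP):
        print("would pad:", type(module).__name__)

TinyModel().apply(visit)  # visit() runs once per submodule, so it prints once per TinyMLP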
@@ -114,27 +115,40 @@ def get_load_function(low_bit):
 
 
 def padding_mlp(module: torch.nn.Module):
+    mlp_gate_up_name = None
+    mlp_down_name = None
     if isinstance(module, Qwen2MLP):
-        hidden_size = module.down_proj.output_size
+        mlp_gate_up_name = "gate_up_proj"
+        mlp_down_name = "down_proj"
+    elif isinstance(module, GLMMLP):
+        mlp_gate_up_name = "dense_h_to_4h"
+        mlp_down_name = "dense_4h_to_h"
+    elif isinstance(module, BaiChuanMLP):
+        mlp_gate_up_name = "gate_up_proj"
+        mlp_down_name = "down_proj"
+    else:
+        return
+    hidden_size = getattr(module, mlp_down_name).output_size
     # divide by rank
-    intermediate_size = module.down_proj.input_size_per_partition
+    intermediate_size = getattr(module, mlp_down_name).input_size_per_partition
     padding_size = 256
     padding_intermediate_size = \
         (intermediate_size + padding_size - 1) // padding_size * padding_size
     if intermediate_size % padding_size == 0:
         return
-    gate_up_weight = module.gate_up_proj.weight.data
+    gate_up_weight = getattr(module, mlp_gate_up_name).weight.data
     new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
                                      dtype=gate_up_weight.dtype, device=gate_up_weight.device)
     # merge_gate_up_weight
     new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
     new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
-    module.gate_up_proj.output_size_per_partition = padding_intermediate_size * 2
-    module.gate_up_proj.weight = torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
+    getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2
+    getattr(module, mlp_gate_up_name).weight = \
+        torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
 
-    down_weight = module.down_proj.weight.data
+    down_weight = getattr(module, mlp_down_name).weight.data
     new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
                                   dtype=down_weight.dtype, device=down_weight.device)
     new_down_weight[:, :intermediate_size] = down_weight
-    module.down_proj.input_size_per_partition = padding_intermediate_size
-    module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
+    getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size
+    getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
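The padding in padding_mlp above is ceiling arithmetic on the per-rank intermediate size: round it up to the next multiple of 256, copy the gate and up halves of the merged gate_up weight into the first intermediate_size rows of their padded slots, and leave the rest zero; because the down projection's extra input columns are also zero, the padded channels contribute nothing to the layer's output. A standalone illustration with made-up sizes (300 and 16 are arbitrary values for this sketch, not the shapes of any real model):

import torch

padding_size = 256
intermediate_size = 300   # made-up per-rank size that is not a multiple of 256
hidden_size = 16

# Round up to the next multiple of padding_size: 300 -> 512.
padding_intermediate_size = \
    (intermediate_size + padding_size - 1) // padding_size * padding_size

# Merged [gate; up] weight of shape [2 * intermediate_size, hidden_size].
gate_up_weight = torch.randn(intermediate_size * 2, hidden_size)

# Copy each half into its own padded slot; the tail rows stay zero.
new_gate_up_weight = torch.zeros(padding_intermediate_size * 2, hidden_size)
new_gate_up_weight[:intermediate_size] = gate_up_weight[:intermediate_size]
new_gate_up_weight[padding_intermediate_size:
                   padding_intermediate_size + intermediate_size] = gate_up_weight[intermediate_size:]

# Pad the down projection on its input dimension so the shapes still chain.
down_weight = torch.randn(hidden_size, intermediate_size)
new_down_weight = torch.zeros(hidden_size, padding_intermediate_size)
new_down_weight[:, :intermediate_size] = down_weight

print(padding_intermediate_size)                        # 512
print(new_gate_up_weight.shape, new_down_weight.shape)  # [1024, 16] and [16, 512]

The patch also updates output_size_per_partition and input_size_per_partition on the padded layers, presumably so that vLLM's parallel-linear bookkeeping stays consistent with the new weight shapes.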