Add vllm glm and baichuan padding (#12053)

parent 69c8d36f16
commit 5d3ab16a80

1 changed file with 38 additions and 24 deletions
@@ -87,8 +87,9 @@ def get_load_function(low_bit):
             scheduler_config=self.scheduler_config,
             cache_config=self.cache_config,
         )
-        if "qwen" in self.model_config.model.lower() and \
-                self.model.model.layers[0].mlp.down_proj.input_size_per_partition % 256 != 0:
+        if "qwen" in self.model_config.model.lower() or \
+                "baichuan" in self.model_config.model.lower() or \
+                "glm" in self.model_config.model.lower():
             self.model.apply(padding_mlp)
         from ipex_llm import optimize_model
         import os
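For context, the reworked condition only decides whether padding_mlp is applied at all; self.model.apply() still hands the callback every submodule of the model, so padding_mlp has to filter by MLP type itself. Below is a minimal sketch of that pattern using hypothetical TinyMLP / TinyBlock / maybe_pad stand-ins, not the real Qwen2MLP, GLMMLP, or BaiChuanMLP classes from the diff:

import torch.nn as nn


class TinyMLP(nn.Module):
    """Hypothetical stand-in for Qwen2MLP / GLMMLP / BaiChuanMLP."""
    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(8, 2 * 20, bias=False)
        self.down_proj = nn.Linear(20, 8, bias=False)


class TinyBlock(nn.Module):
    """Stand-in for a decoder layer that also has non-MLP children."""
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(8)
        self.mlp = TinyMLP()


def maybe_pad(module: nn.Module):
    # Only MLP instances are touched; every other submodule falls through,
    # mirroring the isinstance dispatch inside padding_mlp.
    if isinstance(module, TinyMLP):
        print("would pad:", type(module).__name__)


model = TinyBlock()
# apply() calls maybe_pad on every submodule (the LayerNorm, both Linears,
# the TinyMLP) and on TinyBlock itself, which is why the callback must do
# its own type filtering.
model.apply(maybe_pad)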
@@ -114,27 +115,40 @@ def get_load_function(low_bit):
 
 
 def padding_mlp(module: torch.nn.Module):
+    mlp_gate_up_name = None
+    mlp_down_name = None
     if isinstance(module, Qwen2MLP):
-        hidden_size = module.down_proj.output_size
+        mlp_gate_up_name = "gate_up_proj"
+        mlp_down_name = "down_proj"
+    elif isinstance(module, GLMMLP):
+        mlp_gate_up_name = "dense_h_to_4h"
+        mlp_down_name = "dense_4h_to_h"
+    elif isinstance(module, BaiChuanMLP):
+        mlp_gate_up_name = "gate_up_proj"
+        mlp_down_name = "down_proj"
+    else:
+        return
+    hidden_size = getattr(module, mlp_down_name).output_size
     # devide by rank
-    intermediate_size = module.down_proj.input_size_per_partition
+    intermediate_size = getattr(module, mlp_down_name).input_size_per_partition
     padding_size = 256
     padding_intermediate_size = \
         (intermediate_size + padding_size - 1) // padding_size * padding_size
     if intermediate_size % padding_size == 0:
         return
-    gate_up_weight = module.gate_up_proj.weight.data
+    gate_up_weight = getattr(module, mlp_gate_up_name).weight.data
     new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
                                      dtype=gate_up_weight.dtype, device=gate_up_weight.device)
     # merge_gate_up_weight
     new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
     new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
-    module.gate_up_proj.output_size_per_partition = padding_intermediate_size * 2
-    module.gate_up_proj.weight = torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
+    getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2
+    getattr(module, mlp_gate_up_name).weight = \
+        torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
 
-    down_weight = module.down_proj.weight.data
+    down_weight = getattr(module, mlp_down_name).weight.data
     new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
                                   dtype=down_weight.dtype, device=down_weight.device)
     new_down_weight[:, :intermediate_size] = down_weight
-    module.down_proj.input_size_per_partition = padding_intermediate_size
-    module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
+    getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size
+    getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
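For reference, here is a standalone sketch of the rounding and weight-placement logic used by padding_mlp above, with made-up shapes (hidden_size=8, intermediate_size=300) instead of real model dimensions; it assumes, as the slicing in the diff does, that the merged gate_up weight stacks the gate rows first and the up rows second:

import torch

hidden_size = 8
intermediate_size = 300        # per-rank slice that is not a multiple of 256
padding_size = 256
padding_intermediate_size = \
    (intermediate_size + padding_size - 1) // padding_size * padding_size
assert padding_intermediate_size == 512   # 300 rounded up to the next multiple of 256

# Merged gate_up weight: gate rows first, then up rows.
gate_up_weight = torch.randn(intermediate_size * 2, hidden_size)
new_gate_up_weight = torch.zeros(padding_intermediate_size * 2, hidden_size)

# Copy each half into the start of its padded slot; the remaining rows stay zero.
new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
new_gate_up_weight[padding_intermediate_size:padding_intermediate_size + intermediate_size, :] = \
    gate_up_weight[intermediate_size:, :]

# The down projection consumes the intermediate activations, so it is padded along dim 1.
down_weight = torch.randn(hidden_size, intermediate_size)
new_down_weight = torch.zeros(hidden_size, padding_intermediate_size)
new_down_weight[:, :intermediate_size] = down_weight

print(new_gate_up_weight.shape, new_down_weight.shape)
# torch.Size([1024, 8]) torch.Size([8, 512])

Because the added down-projection columns are zero, the padded intermediate channels are multiplied by zero weights and do not change the layer output; the padding only rounds each rank's intermediate size up to a multiple of 256, which is exactly the condition the code checks before returning early.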