Update for vllm one card padding (#12058)
parent c5fdfde1bd
commit 30a8680645
1 changed file with 2 additions and 0 deletions
@@ -143,6 +143,7 @@ def padding_mlp(module: torch.nn.Module):
     new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
     new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
     getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2
+    getattr(module, mlp_gate_up_name).output_size = padding_intermediate_size * 2
     getattr(module, mlp_gate_up_name).weight = \
         torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
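For orientation, here is a minimal, self-contained sketch of the gate/up padding this hunk sits in. Only the tensor names and the two copy statements come from the diff; the buffer allocation, the concrete sizes, and the fused layout (gate rows stacked on top of up rows) are assumptions reconstructed from context, not the exact vLLM source.

import torch

# Hypothetical sizes; the real values come from the model config.
hidden_size = 4096
intermediate_size = 11008           # original MLP width
padding_intermediate_size = 11264   # width padded up, e.g. to a multiple of 256

# Assumed fused layout: gate weights stacked on top of up weights,
# giving 2 * intermediate_size output rows.
gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)

# Zero-filled buffer whose two halves are padded independently.
new_gate_up_weight = torch.zeros(2 * padding_intermediate_size, hidden_size,
                                 dtype=gate_up_weight.dtype)

# Gate half into the first padded slot, up half into the second,
# exactly as in the two copy statements in the hunk above.
new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
new_gate_up_weight[padding_intermediate_size:padding_intermediate_size + intermediate_size, :] = \
    gate_up_weight[intermediate_size:, :]

The added line keeps output_size in step with output_size_per_partition. Under tensor parallelism only the per-partition size describes the local shard, but on a single card the two describe the same dimension, so leaving output_size at the unpadded value would let the layer's metadata disagree with its actual weight shape; that appears to be the one-card case the commit title refers to.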
@@ -151,4 +152,5 @@ def padding_mlp(module: torch.nn.Module):
                                   dtype=down_weight.dtype, device=down_weight.device)
     new_down_weight[:, :intermediate_size] = down_weight
     getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size
+    getattr(module, mlp_down_name).input_size = padding_intermediate_size
     getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
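The second hunk mirrors the first for the down projection, and the new line likewise keeps input_size consistent with input_size_per_partition on a single card. Below is a sketch under small hypothetical sizes that also checks the padding is numerically a no-op, assuming the usual SiLU-gated MLP (silu(gate(x)) * up(x) fed into the down projection); the helper mlp() and all sizes are illustrative, not vLLM's API.

import torch
import torch.nn.functional as F

hidden_size, intermediate_size, padding_intermediate_size = 64, 100, 128

gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)
down_weight = torch.randn(hidden_size, intermediate_size)

# Padded copies, built exactly as in the two hunks.
new_gate_up = torch.zeros(2 * padding_intermediate_size, hidden_size)
new_gate_up[:intermediate_size] = gate_up_weight[:intermediate_size]
new_gate_up[padding_intermediate_size:padding_intermediate_size + intermediate_size] = \
    gate_up_weight[intermediate_size:]
new_down = torch.zeros(hidden_size, padding_intermediate_size)
new_down[:, :intermediate_size] = down_weight

def mlp(x, gate_up, down, width):
    gu = x @ gate_up.t()                      # fused gate/up projection
    gate, up = gu[:, :width], gu[:, width:]   # split the fused output
    return (F.silu(gate) * up) @ down.t()     # assumed SiLU-gated activation

x = torch.randn(2, hidden_size)
ref = mlp(x, gate_up_weight, down_weight, intermediate_size)
padded = mlp(x, new_gate_up, new_down, padding_intermediate_size)
torch.testing.assert_close(ref, padded)  # padding changes no outputs

The padded slots stay zero all the way through: silu(0) * 0 is 0 in the extra activation columns, and the zero input columns of the padded down weight ignore them, so the output matches the unpadded layer (up to floating-point reduction order).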