Update for vllm one card padding (#12058)

Wang, Jian4 2024-09-11 10:52:55 +08:00 committed by GitHub
parent c5fdfde1bd
commit 30a8680645

@@ -143,6 +143,7 @@ def padding_mlp(module: torch.nn.Module):
     new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
     new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
     getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2
+    getattr(module, mlp_gate_up_name).output_size = padding_intermediate_size * 2
     getattr(module, mlp_gate_up_name).weight = \
         torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
@@ -151,4 +152,5 @@ def padding_mlp(module: torch.nn.Module):
         dtype=down_weight.dtype, device=down_weight.device)
     new_down_weight[:, :intermediate_size] = down_weight
     getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size
+    getattr(module, mlp_down_name).input_size = padding_intermediate_size
     getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)