fix stablelm2 1.6b (#10656)

* fix stablelm2 1.6b

* meet code review
This commit is contained in:
Xin Qiu 2024-04-03 22:15:32 +08:00 committed by GitHub
parent 22f09f618a
commit 4c3e493b2d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -69,7 +69,16 @@ def merge_qkv(module: torch.nn.Module):
module.v_proj.weight.data,
], dim=0)
qkv_proj = torch.nn.Linear(0, 0, bias=False)
if module.q_proj.bias is not None:
qkv_proj = torch.nn.Linear(0, 0, bias=True)
new_bias = torch.cat([
module.q_proj.bias.data,
module.k_proj.bias.data,
module.v_proj.bias.data,
], dim=0)
qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
else:
qkv_proj = torch.nn.Linear(0, 0, bias=False)
qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
qkv_proj.in_features = new_weight.size(1)
qkv_proj.out_features = new_weight.size(0)