parent
22f09f618a
commit
4c3e493b2d
1 changed files with 10 additions and 1 deletions
|
|
@ -69,6 +69,15 @@ def merge_qkv(module: torch.nn.Module):
|
||||||
module.v_proj.weight.data,
|
module.v_proj.weight.data,
|
||||||
], dim=0)
|
], dim=0)
|
||||||
|
|
||||||
|
if module.q_proj.bias is not None:
|
||||||
|
qkv_proj = torch.nn.Linear(0, 0, bias=True)
|
||||||
|
new_bias = torch.cat([
|
||||||
|
module.q_proj.bias.data,
|
||||||
|
module.k_proj.bias.data,
|
||||||
|
module.v_proj.bias.data,
|
||||||
|
], dim=0)
|
||||||
|
qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
||||||
|
else:
|
||||||
qkv_proj = torch.nn.Linear(0, 0, bias=False)
|
qkv_proj = torch.nn.Linear(0, 0, bias=False)
|
||||||
qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
||||||
qkv_proj.in_features = new_weight.size(1)
|
qkv_proj.in_features = new_weight.size(1)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue