parent
							
								
									22f09f618a
								
							
						
					
					
						commit
						4c3e493b2d
					
				
					 1 changed files with 10 additions and 1 deletions
				
			
		| 
						 | 
					@ -69,7 +69,16 @@ def merge_qkv(module: torch.nn.Module):
 | 
				
			||||||
            module.v_proj.weight.data,
 | 
					            module.v_proj.weight.data,
 | 
				
			||||||
        ], dim=0)
 | 
					        ], dim=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        qkv_proj = torch.nn.Linear(0, 0, bias=False)
 | 
					        if module.q_proj.bias is not None:
 | 
				
			||||||
 | 
					            qkv_proj = torch.nn.Linear(0, 0, bias=True)
 | 
				
			||||||
 | 
					            new_bias = torch.cat([
 | 
				
			||||||
 | 
					                module.q_proj.bias.data,
 | 
				
			||||||
 | 
					                module.k_proj.bias.data,
 | 
				
			||||||
 | 
					                module.v_proj.bias.data,
 | 
				
			||||||
 | 
					            ], dim=0)
 | 
				
			||||||
 | 
					            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            qkv_proj = torch.nn.Linear(0, 0, bias=False)
 | 
				
			||||||
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
 | 
					        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
 | 
				
			||||||
        qkv_proj.in_features = new_weight.size(1)
 | 
					        qkv_proj.in_features = new_weight.size(1)
 | 
				
			||||||
        qkv_proj.out_features = new_weight.size(0)
 | 
					        qkv_proj.out_features = new_weight.size(0)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue