parent dd46c141bd
commit 7e917d6cfb
2 changed files with 20 additions and 18 deletions
@@ -49,10 +49,7 @@ if __name__ == '__main__':
                                                  trust_remote_code=True,).to("xpu")
 
     # Load tokenizer
-    if "qwen" in model_path.lower():
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    else:
-        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():
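The branch on the model name goes away because AutoTokenizer already dispatches to the checkpoint's own tokenizer class (LlamaTokenizer included) from the checkpoint config. A minimal sketch of the simplified loading path; the checkpoint path here is only an illustrative placeholder:

from transformers import AutoTokenizer

# Illustrative checkpoint path (any local or hub checkpoint works);
# AutoTokenizer reads the checkpoint config and instantiates the
# matching tokenizer class, so no per-model if/else is needed.
model_path = "Qwen/Qwen-7B-Chat"

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_ids = tokenizer.encode("What is AI?", return_tensors="pt")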
@@ -19,17 +19,21 @@ from typing import List
 
 
 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
-    new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
-    if linears[0].bias is not None:
-        new_linear = torch.nn.Linear(0, 0, bias=True)
-        new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
-        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    if hasattr(linears[0], "weight"):
+        # For GPTQ model, it might be qweight
+        new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
+        if linears[0].bias is not None:
+            new_linear = torch.nn.Linear(0, 0, bias=True)
+            new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
+            new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+        else:
+            new_linear = torch.nn.Linear(0, 0, bias=False)
+        new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+        new_linear.in_features = new_weight.size(1)
+        new_linear.out_features = new_weight.size(0)
+        return new_linear
     else:
-        new_linear = torch.nn.Linear(0, 0, bias=False)
-    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-    new_linear.in_features = new_weight.size(1)
-    new_linear.out_features = new_weight.size(0)
-    return new_linear
+        return None
 
 
 def merge_qkv_base(module: torch.nn.Module, attention_class):
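With the hasattr guard, merge_linear only fuses layers that expose a plain .weight and returns None otherwise (GPTQ modules, for instance, store a packed qweight instead). A small sketch of the intended behavior, assuming merge_linear as defined in the hunk above; FakeQuantLinear is a made-up stand-in for a quantized layer:

import torch

q = torch.nn.Linear(16, 16, bias=True)
k = torch.nn.Linear(16, 16, bias=True)
v = torch.nn.Linear(16, 16, bias=True)

# Weights (and biases) are concatenated along dim 0, so the fused
# projection equals the three separate projections stacked.
qkv = merge_linear([q, k, v])
x = torch.randn(2, 16)
expected = torch.cat([q(x), k(x), v(x)], dim=-1)
assert torch.allclose(qkv(x), expected, atol=1e-6)

class FakeQuantLinear(torch.nn.Module):
    # Made-up GPTQ-style module: packed qweight, no plain .weight.
    def __init__(self):
        super().__init__()
        self.qweight = torch.zeros(2, 16, dtype=torch.int32)

assert merge_linear([FakeQuantLinear()]) is None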
@@ -39,8 +43,9 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
             module.k_proj,
             module.v_proj,
         ])
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+        if qkv_proj is not None:
+            module.qkv_proj = qkv_proj
+            del module.q_proj, module.k_proj, module.v_proj
 
 
 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
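Downstream, merge_qkv_base now fuses only when merge_linear actually produced a combined layer; quantized projections are left as separate q/k/v. A toy demonstration, assuming the elided head of merge_qkv_base checks isinstance(module, attention_class) before merging (DummyAttention is hypothetical):

import torch

class DummyAttention(torch.nn.Module):
    # Hypothetical attention block with the q/k/v layout merge_qkv_base expects.
    def __init__(self, hidden: int = 32):
        super().__init__()
        self.q_proj = torch.nn.Linear(hidden, hidden)
        self.k_proj = torch.nn.Linear(hidden, hidden)
        self.v_proj = torch.nn.Linear(hidden, hidden)

attn = DummyAttention()
merge_qkv_base(attn, DummyAttention)

# Plain Linear projections fuse into a single qkv_proj and the originals
# are deleted; had merge_linear returned None, the module would be untouched.
assert isinstance(attn.qkv_proj, torch.nn.Linear)
assert not hasattr(attn, "q_proj")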