parent dd46c141bd
commit 7e917d6cfb

2 changed files with 20 additions and 18 deletions
@@ -49,10 +49,7 @@ if __name__ == '__main__':
                                                  trust_remote_code=True,).to("xpu")
 
     # Load tokenizer
-    if "qwen" in model_path.lower():
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    else:
-        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():
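The first hunk removes the Qwen-specific branch in the example script: AutoTokenizer already inspects the checkpoint's config and returns the matching concrete tokenizer class, so the explicit LlamaTokenizer fallback is unnecessary. A minimal sketch of that dispatch (the model id is an illustrative example, not from the patch):

from transformers import AutoTokenizer

# AutoTokenizer resolves the concrete class from the checkpoint's config,
# e.g. Qwen2TokenizerFast for a Qwen2 checkpoint, LlamaTokenizerFast for Llama.
model_path = "Qwen/Qwen2-7B-Instruct"  # illustrative example id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print(type(tokenizer).__name__)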
@@ -19,6 +19,8 @@ from typing import List
 
 
 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
-    new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
-    if linears[0].bias is not None:
-        new_linear = torch.nn.Linear(0, 0, bias=True)
+    if hasattr(linears[0], "weight"):
+        # For GPTQ model, it might be qweight
+        new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
+        if linears[0].bias is not None:
+            new_linear = torch.nn.Linear(0, 0, bias=True)
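The new hasattr guard matters for GPTQ-quantized checkpoints: their linear layers carry packed integer qweight tensors rather than a float weight attribute, so concatenating linear.weight.data would fail. A hedged sketch of the check, using a hypothetical stand-in class for a GPTQ layer:

import torch

class GPTQLikeLinear(torch.nn.Module):
    """Hypothetical stand-in: GPTQ layers store packed 'qweight', not 'weight'."""
    def __init__(self):
        super().__init__()
        self.qweight = torch.zeros(16, 4, dtype=torch.int32)  # packed int weights

print(hasattr(torch.nn.Linear(4, 4), "weight"))  # True  -> fusion proceeds
print(hasattr(GPTQLikeLinear(), "weight"))       # False -> merge_linear returns None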
@@ -30,6 +32,8 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
-    new_linear.in_features = new_weight.size(1)
-    new_linear.out_features = new_weight.size(0)
-    return new_linear
+        new_linear.in_features = new_weight.size(1)
+        new_linear.out_features = new_weight.size(0)
+        return new_linear
+    else:
+        return None
 
 
 def merge_qkv_base(module: torch.nn.Module, attention_class):
@@ -39,6 +43,7 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
             module.k_proj,
             module.v_proj,
         ])
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+        if qkv_proj is not None:
+            module.qkv_proj = qkv_proj
+            del module.q_proj, module.k_proj, module.v_proj
 
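For context, the fusion these helpers implement concatenates the three projection weights along the output dimension so a single matmul computes q, k, and v at once; the guard above simply skips this when the weights are not plain tensors. A self-contained sketch of the technique (sizes and names are illustrative):

import torch

hidden = 8
q = torch.nn.Linear(hidden, hidden, bias=False)
k = torch.nn.Linear(hidden, hidden, bias=False)
v = torch.nn.Linear(hidden, hidden, bias=False)

# Mirror merge_linear: build an empty Linear, then attach the concatenated weight.
fused = torch.nn.Linear(0, 0, bias=False)
new_weight = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
fused.weight = torch.nn.Parameter(new_weight, requires_grad=False)
fused.in_features = new_weight.size(1)
fused.out_features = new_weight.size(0)

x = torch.randn(2, hidden)
q_out, k_out, v_out = fused(x).split(hidden, dim=-1)
assert torch.allclose(q_out, q(x))  # fused projection matches the separate one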