fix gptq of llama (#11749)

* fix gptq of llama

* small fix
Authored by Ruonan Wang on 2024-08-09 11:39:25 +03:00, committed by GitHub
parent dd46c141bd
commit 7e917d6cfb
2 changed files with 20 additions and 18 deletions

Changed file 1 of 2:

@@ -47,13 +47,10 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  torch_dtype=torch.float,
                                                  trust_remote_code=True,).to("xpu")
 
     # Load tokenizer
-    if "qwen" in model_path.lower():
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    else:
-        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():
         prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
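
For context, a minimal end-to-end sketch of the simplified tokenizer path. This is not the example script itself: the model id, prompt, and generation scaffolding are assumptions; the load_in_4bit/torch_dtype arguments mirror the diff context above.

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "TheBloke/Llama-2-7B-GPTQ"  # hypothetical GPTQ checkpoint

# ipex-llm loads the GPTQ checkpoint with 4-bit weights on an Intel GPU
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             torch_dtype=torch.float,
                                             trust_remote_code=True).to("xpu")

# One tokenizer path for every architecture: AutoTokenizer resolves the
# right tokenizer class from the checkpoint config, so the old
# qwen/LlamaTokenizer branch is no longer needed.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))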

Changed file 2 of 2:

@@ -19,17 +19,21 @@ from typing import List
 
 
 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
-    new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
-    if linears[0].bias is not None:
-        new_linear = torch.nn.Linear(0, 0, bias=True)
-        new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
-        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    if hasattr(linears[0], "weight"):
+        # For GPTQ model, it might be qweight
+        new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
+        if linears[0].bias is not None:
+            new_linear = torch.nn.Linear(0, 0, bias=True)
+            new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
+            new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+        else:
+            new_linear = torch.nn.Linear(0, 0, bias=False)
+        new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+        new_linear.in_features = new_weight.size(1)
+        new_linear.out_features = new_weight.size(0)
+        return new_linear
     else:
-        new_linear = torch.nn.Linear(0, 0, bias=False)
-    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-    new_linear.in_features = new_weight.size(1)
-    new_linear.out_features = new_weight.size(0)
-    return new_linear
+        return None
 
 
 def merge_qkv_base(module: torch.nn.Module, attention_class):
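
The reason for the hasattr guard: a GPTQ-quantized linear module does not expose a float weight parameter; it packs the quantized values into integer buffers (commonly named qweight), so concatenating .weight across q/k/v would raise an AttributeError. A minimal self-contained sketch of the distinction; FakeQuantLinear is a hypothetical stand-in for a real GPTQ QuantLinear, not a class from this repository.

import torch

class FakeQuantLinear(torch.nn.Module):
    """Hypothetical stand-in for a GPTQ QuantLinear: packed int32 weights,
    no float `weight` attribute."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.register_buffer(
            "qweight",
            torch.zeros(in_features // 8, out_features, dtype=torch.int32))

print(hasattr(FakeQuantLinear(4096, 4096), "weight"))  # False -> merge_linear now returns None
print(hasattr(torch.nn.Linear(4096, 4096), "weight"))  # True  -> normal merge path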
@@ -39,8 +43,9 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
             module.k_proj,
             module.v_proj,
         ])
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+        if qkv_proj is not None:
+            module.qkv_proj = qkv_proj
+            del module.q_proj, module.k_proj, module.v_proj
 
 
 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
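
Net effect: for ordinary float checkpoints the q/k/v projections are still fused into a single qkv_proj (one matmul instead of three), while for GPTQ checkpoints merge_linear returns None and merge_qkv_base leaves the separate projections in place. A toy demonstration of the fused path; the import path is an assumption based on ipex-llm's source layout, and ToyAttention is a hypothetical module.

import torch
from ipex_llm.transformers.models.common import merge_qkv_base  # assumed path

class ToyAttention(torch.nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.q_proj = torch.nn.Linear(dim, dim, bias=False)
        self.k_proj = torch.nn.Linear(dim, dim, bias=False)
        self.v_proj = torch.nn.Linear(dim, dim, bias=False)

attn = ToyAttention()
merge_qkv_base(attn, ToyAttention)  # float weights -> merge succeeds
assert hasattr(attn, "qkv_proj") and not hasattr(attn, "q_proj")

x = torch.randn(2, 8)
q, k, v = attn.qkv_proj(x).chunk(3, dim=-1)  # one matmul yields all three projections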