diff --git a/python/llm/example/GPU/HuggingFace/Advanced-Quantizations/GPTQ/generate.py b/python/llm/example/GPU/HuggingFace/Advanced-Quantizations/GPTQ/generate.py
index c45963f5..50041d41 100644
--- a/python/llm/example/GPU/HuggingFace/Advanced-Quantizations/GPTQ/generate.py
+++ b/python/llm/example/GPU/HuggingFace/Advanced-Quantizations/GPTQ/generate.py
@@ -47,13 +47,10 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  torch_dtype=torch.float,
                                                  trust_remote_code=True,).to("xpu")
-
+
     # Load tokenizer
-    if "qwen" in model_path.lower():
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    else:
-        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
     # Generate predicted tokens
     with torch.inference_mode():
         prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index e1522c4e..86b0d46b 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -19,17 +19,21 @@ from typing import List
 
 
 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
-    new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
-    if linears[0].bias is not None:
-        new_linear = torch.nn.Linear(0, 0, bias=True)
-        new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
-        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    if hasattr(linears[0], "weight"):
+        # For GPTQ model, it might be qweight
+        new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
+        if linears[0].bias is not None:
+            new_linear = torch.nn.Linear(0, 0, bias=True)
+            new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
+            new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+        else:
+            new_linear = torch.nn.Linear(0, 0, bias=False)
+        new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+        new_linear.in_features = new_weight.size(1)
+        new_linear.out_features = new_weight.size(0)
+        return new_linear
     else:
-        new_linear = torch.nn.Linear(0, 0, bias=False)
-    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-    new_linear.in_features = new_weight.size(1)
-    new_linear.out_features = new_weight.size(0)
-    return new_linear
+        return None
 
 
 def merge_qkv_base(module: torch.nn.Module, attention_class):
@@ -39,8 +43,9 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
             module.k_proj,
             module.v_proj,
         ])
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+        if qkv_proj is not None:
+            module.qkv_proj = qkv_proj
+            del module.q_proj, module.k_proj, module.v_proj
 
 
 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
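
Note: a minimal sketch of the new merge_linear contract under this patch, assuming ipex_llm with the change above is importable; FakeQuantLinear is a hypothetical stand-in for a GPTQ-style layer that carries a packed qweight instead of a dense weight attribute.

    import torch
    from ipex_llm.transformers.models.common import merge_linear

    # Plain nn.Linear projections still merge into a single fused layer
    # whose out_features is the sum of the inputs' output sizes.
    q = torch.nn.Linear(16, 16, bias=False)
    k = torch.nn.Linear(16, 16, bias=False)
    v = torch.nn.Linear(16, 16, bias=False)
    fused = merge_linear([q, k, v])
    assert fused.in_features == 16 and fused.out_features == 48

    # Hypothetical GPTQ-style module: packed int32 qweight, no dense weight.
    class FakeQuantLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.qweight = torch.zeros(2, 16, dtype=torch.int32)
            self.bias = None

    # Without a weight attribute, merge_linear now returns None instead of
    # raising an AttributeError.
    assert merge_linear([FakeQuantLinear(), FakeQuantLinear()]) is None

Returning None rather than raising lets merge_qkv_base degrade gracefully: a quantized attention module keeps its original q_proj/k_proj/v_proj, and only modules with dense weights get the fused qkv_proj.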