parent dd46c141bd
commit 7e917d6cfb
2 changed files with 20 additions and 18 deletions
@@ -47,13 +47,10 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  torch_dtype=torch.float,
                                                  trust_remote_code=True,).to("xpu")

     # Load tokenizer
-    if "qwen" in model_path.lower():
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    else:
-        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

     # Generate predicted tokens
     with torch.inference_mode():
         prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
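Note: the deleted branch is redundant because AutoTokenizer resolves the
concrete tokenizer class (Qwen, Llama, ...) from the checkpoint's own config.
A minimal sketch of the surviving call, reusing the example's model_path
variable:

    from transformers import AutoTokenizer

    # AutoTokenizer picks the right tokenizer class from the checkpoint
    # config, so no manual "qwen" / LlamaTokenizer dispatch is needed.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)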
@@ -19,17 +19,21 @@ from typing import List


 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
-    new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
-    if linears[0].bias is not None:
-        new_linear = torch.nn.Linear(0, 0, bias=True)
-        new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
-        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    if hasattr(linears[0], "weight"):
+        # For GPTQ model, it might be qweight
+        new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
+        if linears[0].bias is not None:
+            new_linear = torch.nn.Linear(0, 0, bias=True)
+            new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
+            new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+        else:
+            new_linear = torch.nn.Linear(0, 0, bias=False)
+        new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+        new_linear.in_features = new_weight.size(1)
+        new_linear.out_features = new_weight.size(0)
+        return new_linear
     else:
-        new_linear = torch.nn.Linear(0, 0, bias=False)
-    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-    new_linear.in_features = new_weight.size(1)
-    new_linear.out_features = new_weight.size(0)
-    return new_linear
+        return None


 def merge_qkv_base(module: torch.nn.Module, attention_class):
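Note: the new hasattr guard returns None for GPTQ-quantized layers, which
carry qweight rather than weight (per the hunk's own comment), so callers can
skip fusion for them. The dim-0 concatenation itself is sound because
nn.Linear computes x @ W.T + b, so stacking the Q/K/V weight rows stacks the
three outputs along the feature axis. A standalone sketch of that
equivalence, with toy sizes, not code from this commit:

    import torch

    # Three toy projections standing in for q_proj / k_proj / v_proj.
    q, k, v = (torch.nn.Linear(64, 64) for _ in range(3))

    # Fuse them the way merge_linear does: concatenate along out_features.
    fused = torch.nn.Linear(64, 3 * 64)
    fused.weight = torch.nn.Parameter(
        torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0),
        requires_grad=False)
    fused.bias = torch.nn.Parameter(
        torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0),
        requires_grad=False)

    # The fused output equals the three separate projections, concatenated.
    x = torch.randn(2, 64)
    assert torch.allclose(fused(x),
                          torch.cat([q(x), k(x), v(x)], dim=-1), atol=1e-6)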
@@ -39,8 +43,9 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
             module.k_proj,
             module.v_proj,
         ])
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
+        if qkv_proj is not None:
+            module.qkv_proj = qkv_proj
+            del module.q_proj, module.k_proj, module.v_proj


 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
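Note: with the guard in place, an attention module whose projections could
not be fused (merge_linear returned None, i.e. GPTQ) simply keeps its
separate q_proj / k_proj / v_proj. A self-contained sketch of how this
pattern behaves, using stub helpers and a hypothetical ToyAttention class in
place of a real transformers attention module:

    import torch
    from functools import partial
    from typing import List, Optional

    def merge_linear(linears: List[torch.nn.Linear]) -> Optional[torch.nn.Linear]:
        # Stub of the patched helper: fuse plain Linears, return None for
        # layers without a `weight` attribute (e.g. GPTQ's qweight).
        if not hasattr(linears[0], "weight"):
            return None
        fused = torch.nn.Linear(linears[0].in_features,
                                sum(l.out_features for l in linears), bias=False)
        fused.weight = torch.nn.Parameter(
            torch.cat([l.weight.data for l in linears], dim=0),
            requires_grad=False)
        return fused

    def merge_qkv_base(module: torch.nn.Module, attention_class) -> None:
        # Mirrors the hunk above: only rewrite the module when fusion succeeded.
        if isinstance(module, attention_class):
            qkv = merge_linear([module.q_proj, module.k_proj, module.v_proj])
            if qkv is not None:
                module.qkv_proj = qkv
                del module.q_proj, module.k_proj, module.v_proj

    class ToyAttention(torch.nn.Module):
        # Hypothetical stand-in for an attention class like LlamaAttention.
        def __init__(self, dim: int = 64):
            super().__init__()
            self.q_proj = torch.nn.Linear(dim, dim, bias=False)
            self.k_proj = torch.nn.Linear(dim, dim, bias=False)
            self.v_proj = torch.nn.Linear(dim, dim, bias=False)

    # Apply the fusion to every submodule, as callers typically do.
    model = torch.nn.Sequential(ToyAttention())
    model.apply(partial(merge_qkv_base, attention_class=ToyAttention))
    assert hasattr(model[0], "qkv_proj") and not hasattr(model[0], "q_proj")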