fix incompatibility between llama GW & llama pipeline (#12267)

* fix

* fix
Ruonan Wang 2024-10-25 10:31:44 +08:00 committed by GitHub
parent b5e663854b
commit ae57e23e4f
2 changed files with 18 additions and 11 deletions


@@ -59,7 +59,8 @@ if __name__ == "__main__":
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  optimize_model=True,
                                                  pipeline=True,
-                                                 max_output_len=args.max_output_len)
+                                                 max_output_len=args.max_output_len,
+                                                 attn_implementation="eager")
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -69,8 +70,8 @@ if __name__ == "__main__":
     print("-" * 80)
     print("done")
     with torch.inference_mode():
+        print("finish to load")
         for i in range(5):
-            print("finish to load")
             prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
             _input_ids = tokenizer.encode(prompt, return_tensors="pt")
             print("input length:", len(_input_ids[0]))


@@ -246,15 +246,21 @@ def convert_llm(model: torch.nn.Module,
             attn_layer = curr_layer.self_attn
             mlp_layer = curr_layer.mlp
 
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-            ]
+            weights = []
+            for q, k, v, o, g, u, d in zip(attn_layer.q_proj_dq_list,
+                                           attn_layer.k_proj_dq_list,
+                                           attn_layer.v_proj_dq_list,
+                                           attn_layer.o_proj_dq_list,
+                                           mlp_layer.gate_proj_dq_list,
+                                           mlp_layer.up_proj_dq_list,
+                                           mlp_layer.down_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+                weights.append((d.weight, d.scale))
 
             cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
             cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
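The second change adapts the pipeline converter to group-wise (GW) quantized llama layers: each projection no longer carries a single weight/scale pair but a *_dq_list of decomposed sub-projections, so the converter now collects a (weight, scale) pair from every list entry. A minimal standalone sketch of the same gathering logic, assuming each list entry exposes .weight and .scale as in the patch:

    # Hypothetical sketch equivalent to the loop added above.
    proj_lists = [attn_layer.q_proj_dq_list,
                  attn_layer.k_proj_dq_list,
                  attn_layer.v_proj_dq_list,
                  attn_layer.o_proj_dq_list,
                  mlp_layer.gate_proj_dq_list,
                  mlp_layer.up_proj_dq_list,
                  mlp_layer.down_proj_dq_list]

    # zip(*proj_lists) walks the decomposed groups in lock-step, so the pairs come out
    # in the same per-group q/k/v/o/gate/up/down order as the patch produces.
    weights = [(proj.weight, proj.scale)
               for group in zip(*proj_lists)
               for proj in group]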