From 4c3e493b2dde203f2a484d4bc7fceee4947e6279 Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Wed, 3 Apr 2024 22:15:32 +0800 Subject: [PATCH] fix stablelm2 1.6b (#10656) * fix stablelm2 1.6b * meet code review --- .../llm/src/ipex_llm/transformers/models/stablelm.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py index 3e96e4bf..e15f8974 100644 --- a/python/llm/src/ipex_llm/transformers/models/stablelm.py +++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py @@ -69,7 +69,16 @@ def merge_qkv(module: torch.nn.Module): module.v_proj.weight.data, ], dim=0) - qkv_proj = torch.nn.Linear(0, 0, bias=False) + if module.q_proj.bias is not None: + qkv_proj = torch.nn.Linear(0, 0, bias=True) + new_bias = torch.cat([ + module.q_proj.bias.data, + module.k_proj.bias.data, + module.v_proj.bias.data, + ], dim=0) + qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False) + else: + qkv_proj = torch.nn.Linear(0, 0, bias=False) qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False) qkv_proj.in_features = new_weight.size(1) qkv_proj.out_features = new_weight.size(0)