diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py index 3e96e4bf..e15f8974 100644 --- a/python/llm/src/ipex_llm/transformers/models/stablelm.py +++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py @@ -69,7 +69,16 @@ def merge_qkv(module: torch.nn.Module): module.v_proj.weight.data, ], dim=0) - qkv_proj = torch.nn.Linear(0, 0, bias=False) + if module.q_proj.bias is not None: + qkv_proj = torch.nn.Linear(0, 0, bias=True) + new_bias = torch.cat([ + module.q_proj.bias.data, + module.k_proj.bias.data, + module.v_proj.bias.data, + ], dim=0) + qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False) + else: + qkv_proj = torch.nn.Linear(0, 0, bias=False) qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False) qkv_proj.in_features = new_weight.size(1) qkv_proj.out_features = new_weight.size(0)