diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 8117db90..11777464 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -373,3 +373,44 @@ def moe_softmax_topk(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool
         router_logits, top_k, norm_topk_prob
     )
     return selected_experts, routing_weights
+
+
+# q_proj, k_proj and v_proj should be ipex-llm quantized linears
+def merge_quantized_qkv(q_proj, k_proj, v_proj, module):
+    from ipex_llm.transformers.low_bit_linear import FP4Params
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+    # all three projections must share the same supported ggml quantization type
+    has_qtype = (hasattr(q_proj.weight, 'qtype')
+                 and hasattr(k_proj.weight, 'qtype')
+                 and hasattr(v_proj.weight, 'qtype'))
+    invalidInputError((has_qtype
+                       and q_proj.weight.qtype == k_proj.weight.qtype
+                       and q_proj.weight.qtype == v_proj.weight.qtype
+                       and q_proj.weight.qtype in ggml_tensor_qtype.values()),
+                      f"{getattr(q_proj.weight, 'qtype', None)} is not supported, "
+                      f"only {ggml_tensor_qtype.values()} are supported now.")
+    # concatenate the quantized weights on CPU, stacking output rows (dim 0)
+    origin_device = q_proj.weight.device
+    q_proj.weight = q_proj.weight.to('cpu')
+    k_proj.weight = k_proj.weight.to('cpu')
+    v_proj.weight = v_proj.weight.to('cpu')
+    linears = [q_proj, k_proj, v_proj]
+    new_weight = torch.cat([linear.weight.data for linear in linears], dim=0)
+    if q_proj.has_bias:
+        new_bias = torch.cat([linear.bias.data for linear in linears], dim=0)
+        q_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    new_out_features = sum(linear.out_features for linear in linears)
+    # rebuild a single low-bit parameter covering all three projections
+    new_params_low_bit = FP4Params(data=new_weight.data,
+                                   requires_grad=False,
+                                   quantized=True,
+                                   _shape=[new_out_features, q_proj.in_features],
+                                   convert_shape_only=False,
+                                   qtype=q_proj.weight.qtype,
+                                   in_features=q_proj.in_features,
+                                   enable_scale_search=False)
+    q_proj.out_features = new_out_features
+    q_proj.weight = new_params_low_bit.to(origin_device)
+    # expose the merged projection as qkv_proj and drop the separate modules
+    module.qkv_proj = module.q_proj
+    del module.k_proj, module.v_proj, module.q_proj
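
Note on why the dim-0 concatenation in this patch is sound: stacking weight matrices along dim 0 stacks output rows, so one matmul through the merged linear reproduces the three separate projections side by side (and the same holds for the row-blocked ggml quantized buffers). A minimal sketch with plain float linears, illustrative only and not part of the patch (all names and sizes here are made up):

```python
import torch

# three separate projections sharing the same input size
q = torch.nn.Linear(8, 8, bias=False)
k = torch.nn.Linear(8, 4, bias=False)
v = torch.nn.Linear(8, 4, bias=False)

# merged projection: concatenate weights along dim 0 (output rows)
qkv = torch.nn.Linear(8, 16, bias=False)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([m.weight for m in (q, k, v)], dim=0))

x = torch.randn(2, 8)
# one matmul now yields q(x), k(x) and v(x) concatenated on the last dim
assert torch.allclose(qkv(x), torch.cat([q(x), k(x), v(x)], dim=-1), atol=1e-6)
```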
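For context, a hedged usage sketch (not part of this patch): merges like this are typically applied per attention module with `model.apply(...)` after low-bit conversion. The attribute layout below (modules exposing `q_proj`/`k_proj`/`v_proj`) and the `merge_qkv` wrapper are assumptions for illustration; the attention forward must afterwards read `module.qkv_proj` and split its output.

```python
import torch
from ipex_llm.transformers.models.common import merge_quantized_qkv

def merge_qkv(module: torch.nn.Module):
    # only touch modules that still carry separate low-bit projections
    # (attribute names are assumptions; adjust to the target architecture)
    if all(hasattr(module, name) for name in ('q_proj', 'k_proj', 'v_proj')):
        merge_quantized_qkv(module.q_proj, module.k_proj, module.v_proj, module)

# model = AutoModelForCausalLM.from_pretrained(path, load_in_4bit=True)
# model.apply(merge_qkv)
```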