Add merge quantized qkv (#13160)

* add merge quantized qkv

* fix style & device

* add check
This commit is contained in:
Yina Chen 2025-05-16 15:46:47 +08:00 committed by GitHub
parent 1e4e1353a0
commit 8ba57b41cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -373,3 +373,41 @@ def moe_softmax_topk(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bo
router_logits, top_k, norm_topk_prob router_logits, top_k, norm_topk_prob
) )
return selected_experts, routing_weights return selected_experts, routing_weights
# q,k,v_proj should be ipex-llm quantized linears
def merge_quantized_qkv(q_proj, k_proj, v_proj, module):
from ipex_llm.transformers.low_bit_linear import FP4Params
from ipex_llm.ggml.quantize import ggml_tensor_qtype
has_qtype = (hasattr(q_proj.weight, 'qtype')
and hasattr(k_proj.weight, 'qtype')
and hasattr(v_proj.weight, 'qtype'))
invalidInputError((has_qtype
and q_proj.weight.qtype == k_proj.weight.qtype
and q_proj.weight.qtype == v_proj.weight.qtype
and q_proj.weight.qtype in ggml_tensor_qtype.values()),
f"{q_proj.weight.qtype} is not supported, "
f"only {ggml_tensor_qtype.values()} are supported now.")
origin_device = q_proj.weight.device
q_proj.weight = q_proj.weight.to('cpu')
k_proj.weight = k_proj.weight.to('cpu')
v_proj.weight = v_proj.weight.to('cpu')
linears = [q_proj, k_proj, v_proj]
new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
if q_proj.has_bias:
new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
q_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
new_out_features = sum(layer.out_features for layer in linears)
new_params_low_bit = FP4Params(data=new_weight.data,
requires_grad=False,
quantized=True,
_shape=[new_out_features, q_proj.in_features],
convert_shape_only=False,
qtype=q_proj.weight.qtype,
in_features=q_proj.in_features,
enable_scale_search=False)
q_proj.out_features = new_out_features
q_proj.weight = new_params_low_bit.to(origin_device)
del module.q_proj.weight
module.qkv_proj = module.q_proj
del module.k_proj, module.v_proj, module.q_proj