Add merge quantized qkv (#13160)
* add merge quantized qkv * fix style & device * add check
This commit is contained in:
parent
1e4e1353a0
commit
8ba57b41cd
1 changed files with 38 additions and 0 deletions
|
|
@ -373,3 +373,41 @@ def moe_softmax_topk(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bo
|
||||||
router_logits, top_k, norm_topk_prob
|
router_logits, top_k, norm_topk_prob
|
||||||
)
|
)
|
||||||
return selected_experts, routing_weights
|
return selected_experts, routing_weights
|
||||||
|
|
||||||
|
|
||||||
|
# q,k,v_proj should be ipex-llm quantized linears
|
||||||
|
def merge_quantized_qkv(q_proj, k_proj, v_proj, module):
|
||||||
|
from ipex_llm.transformers.low_bit_linear import FP4Params
|
||||||
|
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||||
|
has_qtype = (hasattr(q_proj.weight, 'qtype')
|
||||||
|
and hasattr(k_proj.weight, 'qtype')
|
||||||
|
and hasattr(v_proj.weight, 'qtype'))
|
||||||
|
invalidInputError((has_qtype
|
||||||
|
and q_proj.weight.qtype == k_proj.weight.qtype
|
||||||
|
and q_proj.weight.qtype == v_proj.weight.qtype
|
||||||
|
and q_proj.weight.qtype in ggml_tensor_qtype.values()),
|
||||||
|
f"{q_proj.weight.qtype} is not supported, "
|
||||||
|
f"only {ggml_tensor_qtype.values()} are supported now.")
|
||||||
|
origin_device = q_proj.weight.device
|
||||||
|
q_proj.weight = q_proj.weight.to('cpu')
|
||||||
|
k_proj.weight = k_proj.weight.to('cpu')
|
||||||
|
v_proj.weight = v_proj.weight.to('cpu')
|
||||||
|
linears = [q_proj, k_proj, v_proj]
|
||||||
|
new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
|
||||||
|
if q_proj.has_bias:
|
||||||
|
new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
|
||||||
|
q_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
||||||
|
new_out_features = sum(layer.out_features for layer in linears)
|
||||||
|
new_params_low_bit = FP4Params(data=new_weight.data,
|
||||||
|
requires_grad=False,
|
||||||
|
quantized=True,
|
||||||
|
_shape=[new_out_features, q_proj.in_features],
|
||||||
|
convert_shape_only=False,
|
||||||
|
qtype=q_proj.weight.qtype,
|
||||||
|
in_features=q_proj.in_features,
|
||||||
|
enable_scale_search=False)
|
||||||
|
q_proj.out_features = new_out_features
|
||||||
|
q_proj.weight = new_params_low_bit.to(origin_device)
|
||||||
|
del module.q_proj.weight
|
||||||
|
module.qkv_proj = module.q_proj
|
||||||
|
del module.k_proj, module.v_proj, module.q_proj
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue