Add merge quantized qkv (#13160)
* add merge quantized qkv
* fix style & device
* add check
parent 1e4e1353a0
commit 8ba57b41cd

1 changed file with 38 additions and 0 deletions
@@ -373,3 +373,41 @@ def moe_softmax_topk(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bo
         router_logits, top_k, norm_topk_prob
     )
     return selected_experts, routing_weights
+
+
+# q,k,v_proj should be ipex-llm quantized linears
+def merge_quantized_qkv(q_proj, k_proj, v_proj, module):
+    from ipex_llm.transformers.low_bit_linear import FP4Params
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+    has_qtype = (hasattr(q_proj.weight, 'qtype')
+                 and hasattr(k_proj.weight, 'qtype')
+                 and hasattr(v_proj.weight, 'qtype'))
+    invalidInputError((has_qtype
+                       and q_proj.weight.qtype == k_proj.weight.qtype
+                       and q_proj.weight.qtype == v_proj.weight.qtype
+                       and q_proj.weight.qtype in ggml_tensor_qtype.values()),
+                      f"{q_proj.weight.qtype} is not supported, "
+                      f"only {ggml_tensor_qtype.values()} are supported now.")
+    origin_device = q_proj.weight.device
+    q_proj.weight = q_proj.weight.to('cpu')
+    k_proj.weight = k_proj.weight.to('cpu')
+    v_proj.weight = v_proj.weight.to('cpu')
+    linears = [q_proj, k_proj, v_proj]
+    new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
+    if q_proj.has_bias:
+        new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
+        q_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    new_out_features = sum(layer.out_features for layer in linears)
+    new_params_low_bit = FP4Params(data=new_weight.data,
+                                   requires_grad=False,
+                                   quantized=True,
+                                   _shape=[new_out_features, q_proj.in_features],
+                                   convert_shape_only=False,
+                                   qtype=q_proj.weight.qtype,
+                                   in_features=q_proj.in_features,
+                                   enable_scale_search=False)
+    q_proj.out_features = new_out_features
+    q_proj.weight = new_params_low_bit.to(origin_device)
+    del module.q_proj.weight
+    module.qkv_proj = module.q_proj
+    del module.k_proj, module.v_proj, module.q_proj
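
For context, a minimal sketch of how the new helper might be applied during model conversion, assuming a Llama-style attention module that exposes q_proj/k_proj/v_proj; the fuse_attention_qkv/fuse_model_qkv wrappers and the .self_attn / model.model.layers layout are illustrative assumptions, not part of this commit:

    # Hypothetical wrapper: fuse the already-quantized q/k/v projections of one
    # attention module in place. After the call, attn.qkv_proj replaces
    # attn.q_proj / attn.k_proj / attn.v_proj.
    def fuse_attention_qkv(attn):
        merge_quantized_qkv(attn.q_proj, attn.k_proj, attn.v_proj, attn)

    # Hypothetical traversal: fuse every self-attention block of a converted
    # model (assumes decoder layers live under model.model.layers, as in
    # Llama-style checkpoints).
    def fuse_model_qkv(model):
        for layer in model.model.layers:
            fuse_attention_qkv(layer.self_attn)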
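Because the merged module keeps only qkv_proj, a caller's forward pass has to split the single fused projection back into query, key and value. A sketch of that split, assuming the usual num_heads / num_key_value_heads / head_dim attributes and the q-k-v concatenation order used above; the actual forward-pass rewrite is not part of this diff:

    # Hypothetical forward-side counterpart of the fusion above: one matmul on
    # the fused low-bit weight, then a split along the feature dimension.
    def project_qkv(attn, hidden_states):
        qkv = attn.qkv_proj(hidden_states)
        q_size = attn.num_heads * attn.head_dim
        kv_size = attn.num_key_value_heads * attn.head_dim
        # split sizes follow the torch.cat order in merge_quantized_qkv: q, k, v
        query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1)
        return query, key, value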