vLLM: set convert_to_half to False by default (#13172)
* init * remove * fix
This commit is contained in:
		
							parent
							
								
									1576347892
								
							
						
					
					
						commit
						154af7d7f7
					
				
					 2 changed files with 3 additions and 1 deletions
				
			
		| 
						 | 
					@ -293,6 +293,7 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
 | 
				
			||||||
                mp_group=mp_group,
 | 
					                mp_group=mp_group,
 | 
				
			||||||
                optimize_lm_head=optimize_lm_head,
 | 
					                optimize_lm_head=optimize_lm_head,
 | 
				
			||||||
                enable_scale_search=enable_scale_search,
 | 
					                enable_scale_search=enable_scale_search,
 | 
				
			||||||
 | 
					                conver_to_half=False,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
    return new_linear
 | 
					    return new_linear
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -589,6 +590,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
				
			||||||
                            optimize_lm_head=False,
 | 
					                            optimize_lm_head=False,
 | 
				
			||||||
                            act_order=act_order,
 | 
					                            act_order=act_order,
 | 
				
			||||||
                            enable_scale_search=enable_scale_search,
 | 
					                            enable_scale_search=enable_scale_search,
 | 
				
			||||||
 | 
					                            conver_to_half=False,
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
                        device = module.qweight.data.device
 | 
					                        device = module.qweight.data.device
 | 
				
			||||||
                        invalidInputError(device.type != "meta",
 | 
					                        invalidInputError(device.type != "meta",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -654,7 +654,7 @@ class LowBitLinear(nn.Linear):
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    w = self.weight.data
 | 
					                    w = self.weight.data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
 | 
					                if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and self.conver_to_half:
 | 
				
			||||||
                    import xe_batch
 | 
					                    import xe_batch
 | 
				
			||||||
                    result = xe_batch.batch_forward(x_2d, w, self.qtype)
 | 
					                    result = xe_batch.batch_forward(x_2d, w, self.qtype)
 | 
				
			||||||
                elif not is_training and self.conver_to_half \
 | 
					                elif not is_training and self.conver_to_half \
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue