Fix abnormal output for Qwen2-7B when sym_int8 (#12446)
Parent commit: 71e1f11aa6
This commit: 303b104c10
1 changed file with 5 additions and 1 deletion
@@ -128,7 +128,11 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         from ipex_llm.transformers.npu_models.common import split_linears
         if quantization_group_size == 0:
             n_splits_linear = 1
-            n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
+            if qtype == "sym_int8_rtn":
+                # do not split mlp down_proj for Qwen2-7B & sym_int8
+                n_splits_down_proj = 1
+            else:
+                n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
         else:
             invalidInputError(
                 model.config.hidden_size % quantization_group_size == 0 and
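In plain terms: when weights are quantized per-channel (quantization_group_size == 0) and the qtype is "sym_int8_rtn", mlp.down_proj is no longer split in two for Qwen2-7B (whose intermediate_size is 18944), since that split produced abnormal output under sym_int8. A minimal sketch of the resulting decision, isolated from optimize_llm_pre; the helper name pick_down_proj_splits is hypothetical, and "sym_int4_rtn" is assumed here as the alternative qtype on this path:

# Hypothetical sketch of the post-commit split selection, for illustration
# only; pick_down_proj_splits is not part of ipex-llm.

QWEN2_7B_INTERMEDIATE_SIZE = 18944  # Qwen2-7B's mlp intermediate_size

def pick_down_proj_splits(qtype: str, intermediate_size: int) -> int:
    """Return how many chunks to split mlp.down_proj into (group size 0)."""
    if qtype == "sym_int8_rtn":
        # sym_int8: never split down_proj; splitting it produced
        # abnormal output for Qwen2-7B (#12446)
        return 1
    # other qtypes ("sym_int4_rtn" assumed): split Qwen2-7B's large
    # down_proj in two, keep all other models whole
    return 2 if intermediate_size == QWEN2_7B_INTERMEDIATE_SIZE else 1

assert pick_down_proj_splits("sym_int8_rtn", 18944) == 1
assert pick_down_proj_splits("sym_int4_rtn", 18944) == 2
assert pick_down_proj_splits("sym_int4_rtn", 11008) == 1

Only the sym_int8 + Qwen2-7B combination changes behavior; every other qtype and model keeps the previous split logic.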