fix npu lm_head cpu condition (#11976)
* fix
* fix
* fix
* fix style
* fix style
* fix style
parent 60aa1a2c0f
commit 573c20bae6

2 changed files with 15 additions and 1 deletion
@@ -66,6 +66,7 @@ if __name__ == "__main__":
         intra_pp=args.intra_pp,
         inter_pp=args.inter_pp,
         transpose_value_cache=not args.disable_transpose_value_cache,
+        modules_to_not_convert=['vpm', 'resampler']
     )
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
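For context, the keyword arguments in the hunk above belong to the example script's model-loading call. Below is a minimal sketch of what that call plausibly looks like, assuming the example loads through ipex_llm.transformers.npu_model.AutoModel with sym_int4 quantization; only the four keyword arguments visible in the diff are taken from this commit, everything else (loader class, model path, quantization setting) is an assumption:

# Sketch, not the example's actual code: loader class, model path, and
# quantization setting are assumptions; the diff only shows the kwargs.
import argparse
from ipex_llm.transformers.npu_model import AutoModel  # assumed loader

parser = argparse.ArgumentParser()
parser.add_argument("--intra-pp", type=int, default=2)
parser.add_argument("--inter-pp", type=int, default=2)
parser.add_argument("--disable-transpose-value-cache", action="store_true")
args = parser.parse_args()

model_path = "openbmb/MiniCPM-V-2_6"  # assumed; any supported checkpoint
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    load_in_low_bit="sym_int4",  # assumed quantization type
    intra_pp=args.intra_pp,
    inter_pp=args.inter_pp,
    transpose_value_cache=not args.disable_transpose_value_cache,
    # added by this commit: exclude the vision tower ('vpm') and the
    # resampler from low-bit conversion
    modules_to_not_convert=['vpm', 'resampler']
)

The added modules_to_not_convert line keeps the vision components out of low-bit conversion, so only the language model is quantized.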
@@ -42,18 +42,31 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
             from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
             model.apply(pre_compute_inv_freq)
 
+    # MiniCPM-V 2.6 and minicpm-2b must put lm_head on CPU now
+    cpu_lm_head = (
+        (model.config.model_type == "minicpmv" and model.config.hidden_size == 3584 and
+         model.config.vocab_size == 151666)
+        or (
+            model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40
+        )
+        or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
+    )
+
     if model.config.model_type == "minicpmv" and hasattr(model, "llm"):
         # MiniCPM-V
         if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
             # MiniCPM-V 2
             model.llm.config.model_type = "minicpm"
+        elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+            # MiniCPM-V 2.6
+            model.llm.config.model_type = "qwen2"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             # MiniCPM-V 2.5
             model.llm.config.model_type = "llama"
         model = model.llm
 
     # lm_head to cpu optimization
-    if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0":
+    if cpu_lm_head:
         # disable the optimization by default
         from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
         if qtype == "sym_int4_rtn":
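The net effect of this hunk: the CPU lm_head path is no longer gated only by the IPEX_LLM_CPU_LM_HEAD environment variable; MiniCPM-V 2.6 (hidden_size 3584, vocab_size 151666) and MiniCPM-2B (40 hidden layers) now take it unconditionally. The hunk also maps MiniCPM-V 2.6 onto the qwen2 code path before the gate is applied. A self-contained sketch of the new gate, using an illustrative DummyConfig stand-in (should_use_cpu_lm_head and DummyConfig are hypothetical names, not ipex-llm APIs):

# Standalone sketch of the new cpu_lm_head condition, runnable without
# loading a model. DummyConfig mimics the fields of model.config used here.
import os
from dataclasses import dataclass

@dataclass
class DummyConfig:
    # Illustrative stand-in for model.config; not an ipex-llm type.
    model_type: str
    hidden_size: int = 0
    vocab_size: int = 0
    num_hidden_layers: int = 0

def should_use_cpu_lm_head(config):
    # Same three-way condition as the patch: MiniCPM-V 2.6, MiniCPM-2B,
    # or an explicit opt-in through IPEX_LLM_CPU_LM_HEAD.
    return (
        (config.model_type == "minicpmv" and config.hidden_size == 3584
         and config.vocab_size == 151666)
        or (config.model_type == "minicpm" and config.num_hidden_layers == 40)
        or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
    )

print(should_use_cpu_lm_head(DummyConfig("minicpmv", 3584, 151666)))          # True
print(should_use_cpu_lm_head(DummyConfig("minicpm", num_hidden_layers=40)))   # True
print(should_use_cpu_lm_head(DummyConfig("llama", 4096, 32000)))              # False unless env var is set

Setting IPEX_LLM_CPU_LM_HEAD to any value other than "0" still forces the CPU lm_head path for every model, which preserves the previous opt-in behavior.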