[NPU] Further fix saving of generation config (#12657)
* Further fix saving of generation config
* Fix based on comments
* Small fix
This commit is contained in:

parent 381d448ee2
commit ebdf19fa7e

1 changed file with 5 additions and 8 deletions
@@ -473,10 +473,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "n_splits_linear": n_splits_linear,
                        "n_splits_down_proj": n_splits_down_proj,
                        "lm_head_low_bit": lm_head_low_bit}
-        model.config.update(update_dict)
-        model.config.save_pretrained(save_directory)
-        if model.can_generate():
-            model.generation_config.save_pretrained(save_directory)
 
         from .qwen import convert_qwen_layer, convert_fused_qwen_layer
         from .qwen import convert_lm_head_and_embedding
@@ -537,8 +533,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "n_splits_linear": n_splits_linear,
                        "n_splits_down_proj": n_splits_down_proj,
                        "lm_head_low_bit": lm_head_low_bit}
-        model.config.update(update_dict)
-        model.config.save_pretrained(save_directory)
 
         from .llama import convert_llama_layer, convert_fused_llama_layer
         from .llama import convert_lm_head_and_embedding
@@ -577,8 +571,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "n_splits_linear": n_splits_linear,
                        "n_splits_down_proj": n_splits_down_proj,
                        "lm_head_low_bit": lm_head_low_bit}
-        model.config.update(update_dict)
-        model.config.save_pretrained(save_directory)
 
         from .minicpm import convert_minicpm_layer, convert_fused_minicpm_layer
         from .minicpm import convert_lm_head_and_embedding
@@ -595,3 +587,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                                       save_directory, weight_dir,
                                       convert_model=True,
                                       max_prompt_len=max_prompt_len)
+
+    model.config.update(update_dict)
+    model.config.save_pretrained(save_directory)
+    if model.can_generate():
+        model.generation_config.save_pretrained(save_directory)
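In short, the commit hoists the config-saving code out of the per-architecture branches (qwen, llama, minicpm) of convert_llm_for_deploy and runs it once at the end of the function, behind a model.can_generate() guard. Below is a minimal sketch of the resulting control flow, assuming a Hugging Face transformers-style model (config.update, config.save_pretrained, can_generate, and generation_config.save_pretrained are the standard transformers APIs); the model_type strings, the default lm_head_low_bit value, and the elided conversion bodies are illustrative assumptions, not the actual ipex-llm code:

import torch

def convert_llm_for_deploy(model: torch.nn.Module,
                           save_directory: str,
                           lm_head_low_bit: str = "sym_int4") -> None:
    # Each architecture branch now only builds its metadata and converts
    # layers; nothing is written to disk inside the branches any more.
    if model.config.model_type == "qwen2":          # assumed branch key
        update_dict = {"lm_head_low_bit": lm_head_low_bit}
        # ... convert qwen layers, lm_head and embedding ...
    elif model.config.model_type == "llama":        # assumed branch key
        update_dict = {"lm_head_low_bit": lm_head_low_bit}
        # ... convert llama layers ...
    elif model.config.model_type == "minicpm":      # assumed branch key
        update_dict = {"lm_head_low_bit": lm_head_low_bit}
        # ... convert minicpm layers ...
    else:
        raise ValueError(f"unsupported model type {model.config.model_type}")

    # The five added lines: update and save the config exactly once,
    # and save the generation config only when the model can generate.
    model.config.update(update_dict)
    model.config.save_pretrained(save_directory)
    if model.can_generate():
        model.generation_config.save_pretrained(save_directory)

Saving once at the tail means a new architecture branch cannot forget the save step, and the can_generate() check avoids writing a generation config for models that do not have one.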