Support save/load model for hf generate (#12499)
* change dummy model
* style
* meet review
parent 7d27f134dd
commit b89ea1b0cf

1 changed file with 24 additions and 11 deletions
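In short: on the enable_cpp_backend path, from_pretrained used to hand back a bare torch.nn.Module ("dummy_model") with a patched generate. It now builds a real AutoModelForCausalLM skeleton on PyTorch's meta device from the saved config.json, so the returned object is a genuine transformers PreTrainedModel (with config handling and save/load hooks) while the actual weights stay behind the C++ backend's model_ptr. A minimal sketch of that meta-device pattern, independent of this repo (the "gpt2" config is purely illustrative):

    import torch
    import transformers
    from transformers import AutoConfig

    # Illustrative config; the diff reads it from <model_dir>/config.json instead.
    config = AutoConfig.from_pretrained("gpt2")

    with torch.device("meta"):
        # Parameters are created on the meta device: shape and dtype only, no
        # storage, so instantiating even a large model allocates no weight memory.
        model = transformers.AutoModelForCausalLM.from_config(config)

    print(isinstance(model, transformers.PreTrainedModel))  # True
    print(next(model.parameters()).device)                  # meta

Meta-device parameters carry shape and dtype but no storage, so the model skeleton is essentially free to construct while still exposing the full PreTrainedModel interface.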
@@ -423,23 +423,36 @@ class _BaseAutoModelClass:
         if enable_cpp_backend:
             from .npu_models.npu_llm_cpp import load_model_from_file
-            from .npu_models.convert import generate
-            dummy_model = torch.nn.Module()
+            from .npu_models.convert import generate, general_convert
+            from .npu_models.convert import prepare_input_ids, causal_lm_forward
+            config = AutoConfig.from_pretrained(
+                os.path.join(pretrained_model_name_or_path, "config.json"),
+                trust_remote_code=trust_remote_code)
+            with torch.device('meta'):
+                model = transformers.AutoModelForCausalLM.from_config(
+                    config, trust_remote_code=trust_remote_code)
             try:
                 model_ptr = load_model_from_file(pretrained_model_name_or_path)
-                dummy_model.config = PretrainedConfig.from_dict(config_dict)
-                dummy_model.model_ptr = model_ptr
-                dummy_model.save_directory = pretrained_model_name_or_path
-                dummy_model.kv_len = config_dict['kv_len']
-                dummy_model.vocab_size = config_dict['vocab_size']
+                model.config = config
+                model.model_ptr = model_ptr
+                model.save_directory = pretrained_model_name_or_path
+                model.kv_len = config_dict['kv_len']
+                model.vocab_size = config_dict['vocab_size']
+                model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
             except:
                 invalidInputError(False,
-                                  "False to InitLLMPipeline.")
-            dummy_model.eval()
+                                  "Fail to InitLLMPipeline.")
+            model.eval()
+            # patch model forward
+            from transformers.modeling_utils import PreTrainedModel
+            general_convert(model, PreTrainedModel, prepare_input_ids,
+                            "prepare_inputs_for_generation")
+            general_convert(model, PreTrainedModel, causal_lm_forward)
             # patch generate function
             import types
-            dummy_model.generate = types.MethodType(generate, dummy_model)
-            return dummy_model
+            model.original_generate = model.generate
+            model.generate = types.MethodType(generate, model)
+            return model

         has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
         has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
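The generate patch at the end of the hunk is the standard types.MethodType idiom: bind a free function to one object so it overrides the inherited method for that instance only, while original_generate keeps the stock Hugging Face implementation reachable. A standalone sketch of the idiom; npu_generate here is a hypothetical stand-in for .npu_models.convert.generate, whose body is not shown in this diff:

    import types

    import torch

    class Model(torch.nn.Module):
        def generate(self):
            return "stock generate"

    def npu_generate(self):
        # Hypothetical stand-in for .npu_models.convert.generate: a real
        # implementation would drive the C++ pipeline via self.model_ptr.
        return "patched generate on " + type(self).__name__

    model = Model()
    model.original_generate = model.generate                # keep the original reachable
    model.generate = types.MethodType(npu_generate, model)  # override this instance only

    print(model.generate())           # patched generate on Model
    print(model.original_generate())  # stock generate
    print(Model().generate())         # other instances keep the class method

general_convert is also passed the PreTrainedModel class and, at one call site, a target method name ("prepare_inputs_for_generation"), which the hunk uses to rebind the input-preparation and forward paths; its implementation lives in .npu_models.convert and is not part of this diff.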
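With a patched PreTrainedModel returned instead of a bare Module, generation reads like any Hugging Face call. A hedged usage sketch: the import path, the assumption that enable_cpp_backend is a from_pretrained() keyword, and the model directory are illustrative rather than taken from this diff:

    from transformers import AutoTokenizer
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

    model_dir = "./converted-npu-model"  # illustrative: produced by a prior save
    model = AutoModelForCausalLM.from_pretrained(model_dir, enable_cpp_backend=True)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    inputs = tokenizer("What is AI?", return_tensors="pt")
    # Assuming the patched generate returns token ids the way HF's does.
    output = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))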