diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 65390aea..15605d07 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -423,23 +423,36 @@ class _BaseAutoModelClass:
 
         if enable_cpp_backend:
             from .npu_models.npu_llm_cpp import load_model_from_file
-            from .npu_models.convert import generate
-            dummy_model = torch.nn.Module()
+            from .npu_models.convert import generate, general_convert
+            from .npu_models.convert import prepare_input_ids, causal_lm_forward
+            config = AutoConfig.from_pretrained(
+                os.path.join(pretrained_model_name_or_path, "config.json"),
+                trust_remote_code=trust_remote_code)
+            with torch.device('meta'):
+                model = transformers.AutoModelForCausalLM.from_config(
+                    config, trust_remote_code=trust_remote_code)
             try:
                 model_ptr = load_model_from_file(pretrained_model_name_or_path)
-                dummy_model.config = PretrainedConfig.from_dict(config_dict)
-                dummy_model.model_ptr = model_ptr
-                dummy_model.save_directory = pretrained_model_name_or_path
-                dummy_model.kv_len = config_dict['kv_len']
-                dummy_model.vocab_size = config_dict['vocab_size']
+                model.config = config
+                model.model_ptr = model_ptr
+                model.save_directory = pretrained_model_name_or_path
+                model.kv_len = config_dict['kv_len']
+                model.vocab_size = config_dict['vocab_size']
+                model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
             except:
                 invalidInputError(False,
-                                  "False to InitLLMPipeline.")
-            dummy_model.eval()
+                                  "Fail to InitLLMPipeline.")
+            model.eval()
+            # patch model forward
+            from transformers.modeling_utils import PreTrainedModel
+            general_convert(model, PreTrainedModel, prepare_input_ids,
+                            "prepare_inputs_for_generation")
+            general_convert(model, PreTrainedModel, causal_lm_forward)
             # patch generate function
             import types
-            dummy_model.generate = types.MethodType(generate, dummy_model)
-            return dummy_model
+            model.original_generate = model.generate
+            model.generate = types.MethodType(generate, model)
+            return model
 
         has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
         has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
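
For context, the change swaps the bare `torch.nn.Module()` dummy for a real `AutoModelForCausalLM` skeleton built on the meta device (shape-only tensors, no weight storage, since the actual weights live behind the C++ `model_ptr`), then rebinds its generation entry points via `general_convert`. A minimal sketch of that patching pattern, assuming `general_convert` rebinds the named method (defaulting to `"forward"`, matching the "patch model forward" comment) on every submodule that is an instance of the target class; the real helper in `npu_models/convert.py` may differ in details:

```python
import types
import torch.nn as nn

def general_convert(module: nn.Module, target_cls, new_fn, method_name="forward"):
    # Hypothetical sketch: rebind `method_name` as a bound method on
    # `module` and on every submodule that is an instance of `target_cls`.
    # Setting an instance attribute shadows the class method for that
    # instance only, so the class definition itself is left untouched.
    if isinstance(module, target_cls):
        setattr(module, method_name, types.MethodType(new_fn, module))
    for child in module.children():
        general_convert(child, target_cls, new_fn, method_name)
```

Under that reading, the two calls in the diff route `prepare_inputs_for_generation` to `prepare_input_ids` and `forward` to `causal_lm_forward`, so the HuggingFace generation loop drives the C++ NPU pipeline through `model.model_ptr` while the meta-device parameters are never materialized. Keeping the original `generate` around as `model.original_generate` lets the patched `generate` fall back to the standard HuggingFace path when needed.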