Support save/load model for hf generate (#12499)

* change dummy model

* style

* meet review
This commit is contained in:
Kai Huang 2024-12-04 18:26:39 +08:00 committed by GitHub
parent 7d27f134dd
commit b89ea1b0cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -423,23 +423,36 @@ class _BaseAutoModelClass:
         if enable_cpp_backend:
             from .npu_models.npu_llm_cpp import load_model_from_file
-            from .npu_models.convert import generate
-            dummy_model = torch.nn.Module()
+            from .npu_models.convert import generate, general_convert
+            from .npu_models.convert import prepare_input_ids, causal_lm_forward
+            config = AutoConfig.from_pretrained(
+                os.path.join(pretrained_model_name_or_path, "config.json"),
+                trust_remote_code=trust_remote_code)
+            with torch.device('meta'):
+                model = transformers.AutoModelForCausalLM.from_config(
+                    config, trust_remote_code=trust_remote_code)
             try:
                 model_ptr = load_model_from_file(pretrained_model_name_or_path)
-                dummy_model.config = PretrainedConfig.from_dict(config_dict)
-                dummy_model.model_ptr = model_ptr
-                dummy_model.save_directory = pretrained_model_name_or_path
-                dummy_model.kv_len = config_dict['kv_len']
-                dummy_model.vocab_size = config_dict['vocab_size']
+                model.config = config
+                model.model_ptr = model_ptr
+                model.save_directory = pretrained_model_name_or_path
+                model.kv_len = config_dict['kv_len']
+                model.vocab_size = config_dict['vocab_size']
+                model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
             except:
                 invalidInputError(False,
-                                  "False to InitLLMPipeline.")
-            dummy_model.eval()
+                                  "Fail to InitLLMPipeline.")
+            model.eval()
+            # patch model forward
+            from transformers.modeling_utils import PreTrainedModel
+            general_convert(model, PreTrainedModel, prepare_input_ids,
+                            "prepare_inputs_for_generation")
+            general_convert(model, PreTrainedModel, causal_lm_forward)
             # patch generate function
             import types
-            dummy_model.generate = types.MethodType(generate, dummy_model)
-            return dummy_model
+            model.original_generate = model.generate
+            model.generate = types.MethodType(generate, model)
+            return model
         has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
         has_local_code = type(config) in cls.HF_Model._model_mapping.keys()