Support save/load model for hf generate (#12499)
* change dummy model * style * meet review
This commit is contained in:
parent
7d27f134dd
commit
b89ea1b0cf
1 changed file with 24 additions and 11 deletions
|
|
@@ -423,23 +423,36 @@ class _BaseAutoModelClass:
|
|||
|
||||
if enable_cpp_backend:
|
||||
from .npu_models.npu_llm_cpp import load_model_from_file
|
||||
from .npu_models.convert import generate
|
||||
dummy_model = torch.nn.Module()
|
||||
from .npu_models.convert import generate, general_convert
|
||||
from .npu_models.convert import prepare_input_ids, causal_lm_forward
|
||||
config = AutoConfig.from_pretrained(
|
||||
os.path.join(pretrained_model_name_or_path, "config.json"),
|
||||
trust_remote_code=trust_remote_code)
|
||||
with torch.device('meta'):
|
||||
model = transformers.AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=trust_remote_code)
|
||||
try:
|
||||
model_ptr = load_model_from_file(pretrained_model_name_or_path)
|
||||
dummy_model.config = PretrainedConfig.from_dict(config_dict)
|
||||
dummy_model.model_ptr = model_ptr
|
||||
dummy_model.save_directory = pretrained_model_name_or_path
|
||||
dummy_model.kv_len = config_dict['kv_len']
|
||||
dummy_model.vocab_size = config_dict['vocab_size']
|
||||
model.config = config
|
||||
model.model_ptr = model_ptr
|
||||
model.save_directory = pretrained_model_name_or_path
|
||||
model.kv_len = config_dict['kv_len']
|
||||
model.vocab_size = config_dict['vocab_size']
|
||||
model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
|
||||
except:
|
||||
invalidInputError(False,
|
||||
"False to InitLLMPipeline.")
|
||||
dummy_model.eval()
|
||||
"Fail to InitLLMPipeline.")
|
||||
model.eval()
|
||||
# patch model forward
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
general_convert(model, PreTrainedModel, prepare_input_ids,
|
||||
"prepare_inputs_for_generation")
|
||||
general_convert(model, PreTrainedModel, causal_lm_forward)
|
||||
# patch generate function
|
||||
import types
|
||||
dummy_model.generate = types.MethodType(generate, dummy_model)
|
||||
return dummy_model
|
||||
model.original_generate = model.generate
|
||||
model.generate = types.MethodType(generate, model)
|
||||
return model
|
||||
|
||||
has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
|
||||
has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
|
||||
|
|
|
|||
Loading…
Reference in a new issue