Support save/load model for hf generate (#12499)
* change dummy model * style * meet review
This commit is contained in:
parent
7d27f134dd
commit
b89ea1b0cf
1 changed files with 24 additions and 11 deletions
|
|
@ -423,23 +423,36 @@ class _BaseAutoModelClass:
|
||||||
|
|
||||||
if enable_cpp_backend:
|
if enable_cpp_backend:
|
||||||
from .npu_models.npu_llm_cpp import load_model_from_file
|
from .npu_models.npu_llm_cpp import load_model_from_file
|
||||||
from .npu_models.convert import generate
|
from .npu_models.convert import generate, general_convert
|
||||||
dummy_model = torch.nn.Module()
|
from .npu_models.convert import prepare_input_ids, causal_lm_forward
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
os.path.join(pretrained_model_name_or_path, "config.json"),
|
||||||
|
trust_remote_code=trust_remote_code)
|
||||||
|
with torch.device('meta'):
|
||||||
|
model = transformers.AutoModelForCausalLM.from_config(
|
||||||
|
config, trust_remote_code=trust_remote_code)
|
||||||
try:
|
try:
|
||||||
model_ptr = load_model_from_file(pretrained_model_name_or_path)
|
model_ptr = load_model_from_file(pretrained_model_name_or_path)
|
||||||
dummy_model.config = PretrainedConfig.from_dict(config_dict)
|
model.config = config
|
||||||
dummy_model.model_ptr = model_ptr
|
model.model_ptr = model_ptr
|
||||||
dummy_model.save_directory = pretrained_model_name_or_path
|
model.save_directory = pretrained_model_name_or_path
|
||||||
dummy_model.kv_len = config_dict['kv_len']
|
model.kv_len = config_dict['kv_len']
|
||||||
dummy_model.vocab_size = config_dict['vocab_size']
|
model.vocab_size = config_dict['vocab_size']
|
||||||
|
model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
|
||||||
except:
|
except:
|
||||||
invalidInputError(False,
|
invalidInputError(False,
|
||||||
"False to InitLLMPipeline.")
|
"Fail to InitLLMPipeline.")
|
||||||
dummy_model.eval()
|
model.eval()
|
||||||
|
# patch model forward
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
general_convert(model, PreTrainedModel, prepare_input_ids,
|
||||||
|
"prepare_inputs_for_generation")
|
||||||
|
general_convert(model, PreTrainedModel, causal_lm_forward)
|
||||||
# patch generate function
|
# patch generate function
|
||||||
import types
|
import types
|
||||||
dummy_model.generate = types.MethodType(generate, dummy_model)
|
model.original_generate = model.generate
|
||||||
return dummy_model
|
model.generate = types.MethodType(generate, model)
|
||||||
|
return model
|
||||||
|
|
||||||
has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
|
has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
|
||||||
has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
|
has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue