Update save/load comments (#12500)

This commit is contained in:
Kai Huang 2024-12-04 18:51:38 +08:00 committed by GitHub
parent b89ea1b0cf
commit d8b14a6305
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -429,11 +429,10 @@ class _BaseAutoModelClass:
os.path.join(pretrained_model_name_or_path, "config.json"), os.path.join(pretrained_model_name_or_path, "config.json"),
trust_remote_code=trust_remote_code) trust_remote_code=trust_remote_code)
with torch.device('meta'): with torch.device('meta'):
model = transformers.AutoModelForCausalLM.from_config( model = cls.HF_Model.from_config(
config, trust_remote_code=trust_remote_code) config, trust_remote_code=trust_remote_code)
try: try:
model_ptr = load_model_from_file(pretrained_model_name_or_path) model_ptr = load_model_from_file(pretrained_model_name_or_path)
model.config = config
model.model_ptr = model_ptr model.model_ptr = model_ptr
model.save_directory = pretrained_model_name_or_path model.save_directory = pretrained_model_name_or_path
model.kv_len = config_dict['kv_len'] model.kv_len = config_dict['kv_len']