Update save/load comments (#12500)
This commit is contained in:
parent
b89ea1b0cf
commit
d8b14a6305
1 changed files with 1 additions and 2 deletions
|
|
@ -429,11 +429,10 @@ class _BaseAutoModelClass:
|
||||||
os.path.join(pretrained_model_name_or_path, "config.json"),
|
os.path.join(pretrained_model_name_or_path, "config.json"),
|
||||||
trust_remote_code=trust_remote_code)
|
trust_remote_code=trust_remote_code)
|
||||||
with torch.device('meta'):
|
with torch.device('meta'):
|
||||||
model = transformers.AutoModelForCausalLM.from_config(
|
model = cls.HF_Model.from_config(
|
||||||
config, trust_remote_code=trust_remote_code)
|
config, trust_remote_code=trust_remote_code)
|
||||||
try:
|
try:
|
||||||
model_ptr = load_model_from_file(pretrained_model_name_or_path)
|
model_ptr = load_model_from_file(pretrained_model_name_or_path)
|
||||||
model.config = config
|
|
||||||
model.model_ptr = model_ptr
|
model.model_ptr = model_ptr
|
||||||
model.save_directory = pretrained_model_name_or_path
|
model.save_directory = pretrained_model_name_or_path
|
||||||
model.kv_len = config_dict['kv_len']
|
model.kv_len = config_dict['kv_len']
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue