[NPU] fix load logic of glm-edge models (#12698)

commit da8bcb7db1
parent 584c1c5373

1 changed file with 7 additions and 3 deletions
@@ -182,13 +182,17 @@ class _BaseAutoModelClass:
         if hasattr(model, "config") and model.config.model_type == "glm":
             # convert to llama structure
             from .npu_models.glm_edge import convert_config, load_weights, convert_state_dict
-            import json
             original_path = model.config._name_or_path
             del model
 
-            with open(os.path.join(original_path, "config.json")) as f:
-                original_config = json.load(f)
+            original_config, _ = PretrainedConfig.get_config_dict(original_path)
             config = convert_config(original_config)
+
+            if not os.path.isdir(original_path):
+                # all model files are already cached
+                from transformers.utils.hub import cached_file
+                resolved_file = cached_file(original_path, "config.json")
+                original_path = os.path.dirname(resolved_file)
             original_state_dict = load_weights(original_path)
             new_dict, _ = convert_state_dict(original_state_dict, config,
                                              original_config.get("partial_rotary_factor", 1.0),
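For context, here is a minimal standalone sketch (not part of the patch) of the two loading paths the hunk switches to. It assumes a working transformers install; the model ID is an illustrative assumption, and any local checkpoint directory works in its place.

# Sketch: why PretrainedConfig.get_config_dict plus cached_file handles
# both local checkpoint directories and Hugging Face Hub repo IDs.
import os

from transformers import PretrainedConfig
from transformers.utils.hub import cached_file

path_or_repo = "THUDM/glm-edge-1.5b-chat"  # illustrative; or a local dir

# get_config_dict resolves config.json from a local directory *or* from the
# Hub cache, unlike open(os.path.join(path_or_repo, "config.json")), which
# assumes a local directory and fails for a plain repo ID.
original_config, _ = PretrainedConfig.get_config_dict(path_or_repo)

# For a repo ID, the downloaded snapshot lives in the HF cache. Resolving a
# file that is guaranteed to exist (config.json) and taking its parent
# directory yields the local folder that also holds the cached weight files.
if not os.path.isdir(path_or_repo):
    resolved_file = cached_file(path_or_repo, "config.json")
    path_or_repo = os.path.dirname(resolved_file)

print(original_config.get("partial_rotary_factor", 1.0))
print(path_or_repo)  # a local directory at this point

This mirrors the patched hunk: original_path is normalized to a local directory before load_weights(original_path) runs, which appears to be the step that previously failed when a glm-edge model was loaded by Hub ID rather than by local path.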