LLM: support models hosted by modelscope (#10106)

parent 1710ecb990
commit 925f82107e

2 changed files with 31 additions and 2 deletions
python/llm/src/bigdl/llm/transformers/model.py

@@ -124,11 +124,24 @@ class _BaseAutoModelClass:
         :param imatrix: str value, represent filename of importance matrix pretrained on
             specific datasets for use with the improved quantization methods recently
             added to llama.cpp.
+        :param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``,
+            specify the model hub. Default to be ``'huggingface'``.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
             if len(args) == 0 else args[0]
-        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+        model_hub = kwargs.pop("model_hub", "huggingface")
+        invalidInputError(model_hub in ["huggingface", "modelscope"],
+                          "The parameter `model_hub` is supposed to be `huggingface` or "
+                          f"`modelscope`, but got {model_hub}.")
+        if model_hub == "huggingface":
+            config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+        elif model_hub == "modelscope":
+            import modelscope
+            from modelscope.utils.hf_util import get_wrapped_class
+            cls.HF_Model = get_wrapped_class(cls.HF_Model)
+            from .utils import get_modelscope_hf_config
+            config_dict, _ = get_modelscope_hf_config(pretrained_model_name_or_path)
         bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
         invalidInputError(not bigdl_transformers_low_bit,
                           f"Detected model is a low-bit({bigdl_transformers_low_bit}) model, "
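
For context, a minimal usage sketch of the new option (not part of this commit; the model id is an illustrative ModelScope id, and load_in_4bit is just one of the existing loading flags):

from bigdl.llm.transformers import AutoModelForCausalLM

# Resolve config and weights from ModelScope instead of Hugging Face Hub.
# model_hub defaults to "huggingface"; any other value fails the
# invalidInputError check added above.
model = AutoModelForCausalLM.from_pretrained("qwen/Qwen-7B-Chat",  # hypothetical model id
                                             load_in_4bit=True,
                                             model_hub="modelscope")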
python/llm/src/bigdl/llm/transformers/utils.py

@@ -43,7 +43,7 @@ import os
 from transformers.modeling_utils import _add_variant
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from ..utils.common import invalidInputError
-from typing import Union
+from typing import Union, Optional
 import torch
 from torch import nn
 import logging
@@ -258,3 +258,19 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
         cur_qtype = qtype
 
     return cur_qtype, cur_imatrix
+
+
+def get_modelscope_hf_config(model_id_or_path: str,
+                             revision: Optional[str] = None):
+    # Read hf config dictionary from modelscope hub or local path
+    from modelscope.utils.constant import ModelFile
+    from modelscope.hub.file_download import model_file_download
+    from modelscope.utils.config import Config
+    if not os.path.exists(model_id_or_path):
+        local_path = model_file_download(
+            model_id_or_path, ModelFile.CONFIG, revision=revision)
+    elif os.path.isdir(model_id_or_path):
+        local_path = os.path.join(model_id_or_path, ModelFile.CONFIG)
+    elif os.path.isfile(model_id_or_path):
+        local_path = model_id_or_path
+    return Config._file2dict(local_path)
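
Likewise, a sketch of calling the new helper directly (the model id is again illustrative; Config._file2dict returns a (dict, text) pair, which is why the commit's caller unpacks two values):

from bigdl.llm.transformers.utils import get_modelscope_hf_config

# Accepts a remote ModelScope id (downloads config.json), a local model
# directory, or a direct path to a config file.
config_dict, _ = get_modelscope_hf_config("qwen/Qwen-7B-Chat")  # hypothetical id
print(config_dict.get("model_type"))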