LLM: support models hosted by modelscope (#10106)
This commit is contained in:
parent
1710ecb990
commit
925f82107e
2 changed files with 31 additions and 2 deletions
|
|
@ -124,11 +124,24 @@ class _BaseAutoModelClass:
|
|||
:param imatrix: str value, represent filename of importance matrix pretrained on
|
||||
specific datasets for use with the improved quantization methods recently
|
||||
added to llama.cpp.
|
||||
:param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``,
|
||||
specify the model hub. Default to be ``'huggingface'``.
|
||||
:return: a model instance
|
||||
"""
|
||||
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
|
||||
if len(args) == 0 else args[0]
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
|
||||
model_hub = kwargs.pop("model_hub", "huggingface")
|
||||
invalidInputError(model_hub in ["huggingface", "modelscope"],
|
||||
"The parameter `model_hub` is supposed to be `huggingface` or "
|
||||
f"`modelscope`, but got {model_hub}.")
|
||||
if model_hub == "huggingface":
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
|
||||
elif model_hub == "modelscope":
|
||||
import modelscope
|
||||
from modelscope.utils.hf_util import get_wrapped_class
|
||||
cls.HF_Model = get_wrapped_class(cls.HF_Model)
|
||||
from .utils import get_modelscope_hf_config
|
||||
config_dict, _ = get_modelscope_hf_config(pretrained_model_name_or_path)
|
||||
bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
|
||||
invalidInputError(not bigdl_transformers_low_bit,
|
||||
f"Detected model is a low-bit({bigdl_transformers_low_bit}) model, "
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ import os
|
|||
from transformers.modeling_utils import _add_variant
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from ..utils.common import invalidInputError
|
||||
from typing import Union
|
||||
from typing import Union, Optional
|
||||
import torch
|
||||
from torch import nn
|
||||
import logging
|
||||
|
|
@ -258,3 +258,19 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
|
|||
cur_qtype = qtype
|
||||
|
||||
return cur_qtype, cur_imatrix
|
||||
|
||||
|
||||
def get_modelscope_hf_config(model_id_or_path: str,
                             revision: Optional[str] = None):
    """Resolve and parse the HuggingFace-style config for a ModelScope model.

    :param model_id_or_path: a ModelScope hub model id, a local directory
        containing the config file, or a direct path to the config file itself.
    :param revision: optional hub revision forwarded to the ModelScope
        downloader; ignored for local paths.
    :return: whatever ``Config._file2dict`` yields for the located config file
        (callers unpack it as ``config_dict, _``).
    """
    # Imported lazily so modelscope is only required when this code path
    # (model_hub='modelscope') is actually taken.
    from modelscope.utils.constant import ModelFile
    from modelscope.hub.file_download import model_file_download
    from modelscope.utils.config import Config
    if not os.path.exists(model_id_or_path):
        # Not a local path: treat it as a hub model id and fetch the config.
        local_path = model_file_download(
            model_id_or_path, ModelFile.CONFIG, revision=revision)
    elif os.path.isdir(model_id_or_path):
        local_path = os.path.join(model_id_or_path, ModelFile.CONFIG)
    elif os.path.isfile(model_id_or_path):
        local_path = model_id_or_path
    else:
        # Path exists but is neither a file nor a directory (e.g. a FIFO).
        # Previously this fell through and raised an opaque
        # UnboundLocalError; fail with an actionable message instead.
        invalidInputError(False,
                          f"Unable to locate a config file for "
                          f"`{model_id_or_path}`: it is neither a ModelScope "
                          f"model id, a model directory, nor a config file.")
    # NOTE(review): relies on the private ModelScope helper
    # `Config._file2dict`; confirm it stays available across modelscope
    # upgrades.
    return Config._file2dict(local_path)
|
||||
|
|
|
|||
Loading…
Reference in a new issue