[LLM] Add transformers-like API from_pretrained (#8271)
* Init commit for bigdl.llm.transformers.AutoModelForCausalLM
* Temp change to avoid name conflicts with external transformers lib
* Support downloading model from huggingface
* Small python style fix
* Change location of transformers to avoid library conflicts
* Add return value for converted ggml binary ckpt path for convert_model
* Avoid repeated loading of shared library and add some comments
* Small fix
* Path type fix and docstring fix
* Small fix
* Small fix
* Change cache dir to pwd
parent 2ed5842448 · commit 64bc123dd3
4 changed files with 137 additions and 4 deletions
@@ -39,6 +39,8 @@ def convert_model(input_path: str,
     :param dtype: Which quantized precision will be converted.
           Now only int4 supported.
     :param tmp_path: Which path to store the intermediate model during the conversion process.
+
+    :return: the path str to the converted lower precision checkpoint
     """

     dtype = dtype.lower()
@@ -54,7 +56,7 @@ def convert_model(input_path: str,
     tmp_ggml_file_path = next(Path(tmp_ggml_file_path).iterdir())

-    quantize(input_path=tmp_ggml_file_path,
+    return quantize(input_path=tmp_ggml_file_path,
              output_path=output_path,
              model_family=model_family,
              dtype=dtype)
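With the added return, convert_model now hands the converted checkpoint path back to the caller, so conversion can be chained directly into loading. A minimal sketch of the new behavior; the input and output paths below are hypothetical placeholders:

from bigdl.llm.ggml import convert_model

# hypothetical paths, for illustration only
ggml_ckpt_path = convert_model(input_path='./llama-7b-hf',  # huggingface checkpoint dir
                               output_path='./converted',   # dir for the ggml binary output
                               model_family='llama',
                               dtype='int4')
print(ggml_ckpt_path)  # the path str to the converted lower precision checkpoint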
@@ -60,6 +60,8 @@ def quantize(input_path: str, output_path: str=None,
         llama : "q4_0", "q4_1", "q4_2"
         bloom : "q4_0", "q4_1"
         gptneox : "q4_0", "q4_1", "q4_2", "q5_0", "q5_1", "q8_0"
+
+    :return: the path str to the converted ggml binary checkpoint
     """
     invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
                       "Now we only support quantization of model \
@@ -92,3 +94,4 @@ def quantize(input_path: str, output_path: str=None,
     p.communicate()
     invalidInputError(not p.returncode,
                       "Fail to quantize {}.".format(str(input_path)))
+    return str(output_path)
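quantize likewise now returns its output path. A hedged sketch; the import below assumes quantize is exposed from bigdl.llm.ggml (this diff does not show its module path), and the file names are placeholders:

# assumption: quantize is importable from bigdl.llm.ggml; not confirmed by this diff
from bigdl.llm.ggml import quantize

# "q4_0" is accepted by all three families listed in the docstring above
out_path = quantize(input_path='./llama-7b-ggml-f16.bin',    # placeholder ggml checkpoint
                    output_path='./llama-7b-ggml-q4_0.bin',  # placeholder output file
                    model_family='llama',
                    dtype='q4_0')
print(out_path)  # the path str to the quantized ggml binary checkpoint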
python/llm/src/bigdl/llm/ggml/transformers/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+from .model import AutoModelForCausalLM
python/llm/src/bigdl/llm/ggml/transformers/model.py (new file, 106 lines)
@@ -0,0 +1,106 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+import os
+import traceback
+from huggingface_hub import snapshot_download
+from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.ggml import convert_model
+
+
+class AutoModelForCausalLM:
+    """
+    A generic model class that mimics the behavior of the
+    ``transformers.AutoModelForCausalLM.from_pretrained`` API.
+    """
+
+    @classmethod
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path: str,
+                        model_family: str = 'llama',
+                        dtype: str = 'int4',
+                        cache_dir: str = './',
+                        **kwargs):
+        """
+        :param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoint
+
+               1. path to a huggingface checkpoint that was directly pulled from the huggingface hub.
+                  This should be a dir path that contains: weight bin, tokenizer config,
+                  tokenizer.model (required for llama) and added_tokens.json (if applied).
+                  For a lora fine-tuned model, the path should point to a merged weight.
+               2. path to a converted ggml binary checkpoint. The checkpoint should be converted by
+                  ``bigdl.llm.ggml.convert_model``.
+               3. a str for a huggingface hub repo id.
+
+        :param model_family: the model family of the pretrained checkpoint.
+               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
+        :param dtype: (optional) the data type for weight. Currently we only support ``"int4"``.
+        :param cache_dir: (optional) this parameter will only be used when
+               ``pretrained_model_name_or_path`` is a huggingface checkpoint or hub repo id.
+               It indicates the saving path for the converted low precision model.
+        :param **kwargs: keyword arguments which will be passed to the model instance.
+
+        :return: a model instance
+        """
+        invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
+                          "Now we only support model family: 'llama', 'gptneox', 'bloom', "
+                          "'{}' is not in the list.".format(model_family))
+        invalidInputError(dtype.lower() == 'int4',
+                          "Now we only support int4 as data type for weight")
+
+        # check whether pretrained_model_name_or_path exists.
+        # if not, it is likely that the user wants to pass in a repo id.
+        if not os.path.exists(pretrained_model_name_or_path):
+            try:
+                # download from huggingface based on repo id
+                pretrained_model_name_or_path = snapshot_download(
+                    repo_id=pretrained_model_name_or_path)
+            except Exception:
+                traceback.print_exc()
+                # if downloading fails, it could be the case that the repo id is invalid,
+                # or the user passed in the wrong path for the checkpoint
+                invalidInputError(False,
+                                  "Downloading from huggingface repo id {} failed. "
+                                  "Please input a valid huggingface hub repo id, "
+                                  "or provide a valid path to a huggingface / "
+                                  "ggml binary checkpoint for pretrained_model_name_or_path"
+                                  .format(pretrained_model_name_or_path))
+
+        ggml_model_path = pretrained_model_name_or_path
+        # check whether pretrained_model_name_or_path is a file.
+        # if not, it is likely that pretrained_model_name_or_path
+        # points to a huggingface checkpoint
+        if not os.path.isfile(pretrained_model_name_or_path):
+            # huggingface checkpoint
+            ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
+                                            output_path=cache_dir,
+                                            model_family=model_family,
+                                            dtype=dtype)
+
+        if model_family == 'llama':
+            from bigdl.llm.ggml.model.llama import Llama
+            return Llama(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'gptneox':
+            from bigdl.llm.ggml.model.gptneox import Gptneox
+            return Gptneox(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'bloom':
+            from bigdl.llm.ggml.model.bloom import Bloom
+            return Bloom(model_path=ggml_model_path, **kwargs)
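Putting it together, a short usage sketch of the new from_pretrained API covering the three accepted inputs; the repo id and local paths below are placeholders, not part of this PR:

from bigdl.llm.ggml.transformers import AutoModelForCausalLM

# 1. huggingface hub repo id: downloaded via snapshot_download, then converted to int4
model = AutoModelForCausalLM.from_pretrained('some-org/llama-7b-hf',  # placeholder repo id
                                             model_family='llama',
                                             dtype='int4',
                                             cache_dir='./')

# 2. local huggingface checkpoint dir: converted via convert_model into cache_dir
model = AutoModelForCausalLM.from_pretrained('./llama-7b-hf', model_family='llama')

# 3. already-converted ggml binary checkpoint: loaded directly, no conversion
model = AutoModelForCausalLM.from_pretrained('./llama-7b-ggml-q4_0.bin',  # placeholder file
                                             model_family='llama')

# the returned object is the underlying ggml model class (e.g. Llama for model_family='llama')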