transformer api refactor (#8389)
* transformer api refactor
* fix style
* add huggingface tokenizer usage in example and make ggml tokenizer option 1 and huggingface tokenizer option 2
* fix style
parent ce6d06eb0a
commit 446175cc05
6 changed files with 139 additions and 217 deletions
@@ -75,40 +75,46 @@ llm-cli -x llama -h
 ```
 
 #### Transformers like API
-Users could load converted model or even the unconverted huggingface model directly by `AutoModelForCausalLM.from_pretrained`.
+You can also load the converted model using `BigdlForCausalLM` with a transformers-like API,
+```python
+from bigdl.llm.transformers import BigdlForCausalLM
+llm = BigdlForCausalLM.from_pretrained("/path/to/llama-7b-int4/bigdl-llm-xxx.bin",
+                                       model_family="llama")
+prompt = "What is AI?"
+```
+and simply do inference end-to-end like
+```python
+output = llm(prompt, max_tokens=32)
+```
+If you need to separate the tokenization and generation, you can also do inference like
+```python
+tokens_id = llm.tokenize(prompt)
+output_tokens_id = llm.generate(tokens_id, max_new_tokens=32)
+output = llm.batch_decode(output_tokens_id)
+```
+
+Alternatively, you can load a huggingface model directly using `AutoModelForCausalLM.from_pretrained`.
 
 ```python
-from bigdl.llm.ggml.transformers import AutoModelForCausalLM
+from bigdl.llm.transformers import AutoModelForCausalLM
 
-# option 1: load converted model
-llm = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b-int4/bigdl-llm-xxx.bin",
-                                           model_family="llama")
-
-# option 2: load huggingface checkpoint
+# option 1: load huggingface checkpoint
 llm = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b-hf/",
                                            model_family="llama")
 
-# option 3: load from huggingface hub repo
+# option 2: load from huggingface hub repo
 llm = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
                                            model_family="llama")
 ```
 
-Users could use llm to do the inference. Apart from end-to-end fast forward, we also support split the tokenization and model inference in our API.
+You can then use the model the same way as you use transformers.
 
 ```python
-# end-to-end fast forward w/o spliting the tokenization and model inferencing
-result = llm("what is ai")
-
 # Use transformers tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
 tokens = tokenizer("what is ai").input_ids
 tokens_id = llm.generate(tokens, max_new_tokens=32)
 tokenizer.batch_decode(tokens_id)
-
-# Use bigdl-llm tokenizer
-tokens = llm.tokenize("what is ai")
-tokens_id = llm.generate(tokens, max_new_tokens=32)
-decoded = llm.batch_decode(tokens_id)
 ```
 
 #### llama-cpp-python like API
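For reference, the transformers-tokenizer snippet above leaves `model_ckpt` and the `AutoTokenizer` import implicit; a self-contained version of that flow might look like the sketch below. The checkpoint path is the README's placeholder, not a value fixed by this commit.

```python
# Self-contained version of the README snippet above (paths are placeholders).
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_ckpt = "/path/to/llama-7b-hf/"  # huggingface checkpoint folder

llm = AutoModelForCausalLM.from_pretrained(model_ckpt, model_family="llama")
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

tokens = tokenizer("what is ai").input_ids
tokens_id = llm.generate(tokens, max_new_tokens=32)
print(tokenizer.batch_decode(tokens_id))
```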
@@ -18,54 +18,49 @@ import time
 import argparse
 
 
-def convert_and_load(repo_id_or_model_path, model_family, n_threads):
+def convert(repo_id_or_model_path, model_family, tmp_path):
+    from bigdl.llm import llm_convert
+    original_llm_path = repo_id_or_model_path
+    bigdl_llm_path = llm_convert(
+        model=original_llm_path,
+        outfile='./',
+        outtype='int4',
+        tmp_path=tmp_path,
+        model_family=model_family)
 
-    from bigdl.llm.ggml.transformers import AutoModelForCausalLM
+    return bigdl_llm_path
 
-    # here you may input the HuggingFace repo id directly as the value of `pretrained_model_name_or_path`.
-    # This will allow the pre-trained model to be downloaded directly from the HuggingFace repository.
-    # The downloaded model will then be converted to binary format with int4 dtype weights,
-    # and saved into the cache_dir folder.
-    #
-    # if you already have the pre-trained model downloaded, you can provide the path to
-    # the downloaded folder as the value of `pretrained_model_name_or_path``
-    llm = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_name_or_path=repo_id_or_model_path,
+def load(model_path, model_family, n_threads):
+    from bigdl.llm.transformers import BigdlForCausalLM
+    llm = BigdlForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=model_path,
         model_family=model_family,
-        dtype='int4',
-        cache_dir='./',
         n_threads=n_threads)
 
-    # if you want to explicitly convert the pre-trained model, you can use the `llm_convert` API
-    # to convert the downloaded Huggungface checkpoint first,
-    # and then load the binary checkpoint directly.
-    #
-    # from bigdl.llm import llm_convert
-    #
-    # model_path = repo_id_or_model_path
-    # output_ckpt_path = llm_convert(
-    #     model=model_path,
-    #     outfile='./',
-    #     outtype='int4',
-    #     model_family=model_family)
-    #
-    # llm = AutoModelForCausalLM.from_pretrained(
-    #     pretrained_model_name_or_path=output_ckpt_path,
-    #     model_family=model_family,
-    #     n_threads=n_threads)
-
     return llm
 
 
 def inference(llm, repo_id_or_model_path, model_family, prompt):
 
     if model_family in ['llama', 'gptneox']:
-        # Option 1: Use HuggingFace transformers tokenizer
+        # ------ Option 1: Use bigdl-llm based tokenizer
+        print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
+        st = time.time()
+
+        # please note that the prompt here can either be a string or a list of string
+        tokens_id = llm.tokenize(prompt)
+        output_tokens_id = llm.generate(tokens_id, max_new_tokens=32)
+        output = llm.batch_decode(output_tokens_id)
+
+        print(f'Inference time: {time.time()-st} s')
+        print(f'Output:\n{output}')
+
+        # ------- Option 2: Use HuggingFace transformers tokenizer
         print('-'*20, ' HuggingFace transformers tokenizer ', '-'*20)
 
         print('Please note that the loading of HuggingFace transformers tokenizer may take some time.\n')
         # here is only a workaround for default example model 'decapoda-research/llama-7b-hf' in LLaMA family,
         # due to its out-of-date 'tokenizer_class' defined in its tokenizer_config.json.
-        #
         # for most cases, you could use `AutoTokenizer`.
         if model_family == 'llama':
             from transformers import LlamaTokenizer
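The `convert` and `load` helpers introduced in this hunk are thin wrappers. Written directly against the public APIs they call, the same two steps look roughly like the sketch below; the checkpoint folder, `tmp_path` and `n_threads` values are illustrative, not defaults mandated by this commit.

```python
# Roughly what convert() followed by load() does, using the public APIs directly.
from bigdl.llm import llm_convert
from bigdl.llm.transformers import BigdlForCausalLM

# convert(): huggingface checkpoint (a local folder here) -> int4 ggml binary
bigdl_llm_path = llm_convert(model='/path/to/llama-7b-hf/',
                             outfile='./',
                             outtype='int4',
                             tmp_path='/tmp',
                             model_family='llama')

# load(): ggml binary -> model object; n_threads is forwarded to the ggml model
llm = BigdlForCausalLM.from_pretrained(pretrained_model_name_or_path=bigdl_llm_path,
                                       model_family='llama',
                                       n_threads=2)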
@@ -84,17 +79,6 @@ def inference(llm, repo_id_or_model_path, model_family, prompt):
         print(f'Inference time: {time.time()-st} s')
         print(f'Output:\n{output}')
 
-        # Option 2: Use bigdl-llm based tokenizer
-        print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
-        st = time.time()
-
-        # please note that the prompt here can either be a string or a list of string
-        tokens_id = llm.tokenize(prompt)
-        output_tokens_id = llm.generate(tokens_id, max_new_tokens=32)
-        output = llm.batch_decode(output_tokens_id)
-
-        print(f'Inference time: {time.time()-st} s')
-        print(f'Output:\n{output}')
-
     if model_family in ['llama', 'gptneox', 'bloom']:
         # Option 3: fast forward
@@ -121,6 +105,8 @@ def main():
                                 ', or the path to the huggingface checkpoint folder')
     parser.add_argument('--prompt', type=str, default='Q: What is CPU? A:',
                         help='Prompt to infer')
+    parser.add_argument('--tmp-path', type=str, default='/tmp',
+                        help='path to store intermediate model during the conversion process')
     args = parser.parse_args()
 
     repo_id_or_model_path = args.repo_id_or_model_path
@@ -132,12 +118,18 @@ def main():
     elif args.model_family == 'bloom':
         repo_id_or_model_path = 'bigscience/bloomz-7b1'
 
-    # Step 1: convert and load int4 model
-    llm = convert_and_load(repo_id_or_model_path=repo_id_or_model_path,
-                           model_family=args.model_family,
-                           n_threads=args.thread_num)
+    # Step 1: convert original model to BigDL llm model
+    bigdl_llm_path = convert(repo_id_or_model_path=repo_id_or_model_path,
+                             model_family=args.model_family,
+                             tmp_path=args.tmp_path)
+
+    # Step 2: load int4 model
+    llm = load(model_path=bigdl_llm_path,
+               model_family=args.model_family,
+               n_threads=args.thread_num)
 
-    # Step 2: conduct inference
+    # Step 3: inference
     inference(llm=llm,
               repo_id_or_model_path=repo_id_or_model_path,
               model_family=args.model_family,
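Outside of `main()`, a standalone run of the Option 1 (bigdl-llm tokenizer) inference path that the example now prints first could look like the sketch below; the ggml checkpoint path is a placeholder and `max_new_tokens=32` simply mirrors the example.

```python
# Standalone sketch of the Option 1 (bigdl-llm tokenizer) inference path.
import time
from bigdl.llm.transformers import BigdlForCausalLM

llm = BigdlForCausalLM.from_pretrained("/path/to/llama-7b-int4/bigdl-llm-xxx.bin",
                                       model_family="llama")

st = time.time()
# the prompt can either be a string or a list of strings
tokens_id = llm.tokenize('Q: What is CPU? A:')
output_tokens_id = llm.generate(tokens_id, max_new_tokens=32)
output = llm.batch_decode(output_tokens_id)
print(f'Inference time: {time.time()-st} s')
print(f'Output:\n{output}')
```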
@@ -1,22 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# This would makes sure Python is aware there is more than one sub-package within bigdl,
-# physically located elsewhere.
-# Otherwise there would be module not found error in non-pip's setting as Python would
-# only search the first bigdl package and end up finding only one sub-package.
-
-from .model import AutoModelForCausalLM
@@ -1,128 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# This would makes sure Python is aware there is more than one sub-package within bigdl,
-# physically located elsewhere.
-# Otherwise there would be module not found error in non-pip's setting as Python would
-# only search the first bigdl package and end up finding only one sub-package.
-
-import os
-import traceback
-from bigdl.llm.utils.common import invalidInputError
-
-
-class AutoModelForCausalLM:
-    """
-    A generic model class that mimics the behavior of
-    ``transformers.AutoModelForCausalLM.from_pretrained`` API
-    """
-
-    @classmethod
-    def from_pretrained(cls,
-                        pretrained_model_name_or_path: str,
-                        model_format: str = 'pth',
-                        model_family: str = 'llama',
-                        dtype: str = 'int4',
-                        cache_dir: str = './',
-                        tmp_path: str = None,
-                        **kwargs):
-        """
-        :param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoint
-
-               1. Path to directory for Hugging Face checkpoint that are directly pulled from
-                  Hugging Face hub.
-
-                  If ``model_format='pth'``, the folder should contain: weight bin, tokenizer
-                  config, tokenizer.model (required for llama) and added_tokens.json (if applied).
-                  For lora fine tuned model, the path should be pointed to a merged weight.
-
-                  If ``model_format='gptq'``, the folder should be be a Hugging Face checkpoint
-                  in GPTQ format, which contains weights in pytorch's .pt format,
-                  and ``tokenizer.model``.
-
-               2. Path for converted BigDL-LLM optimized ggml binary checkpoint.
-                  The checkpoint should be converted by ``bigdl.llm.llm_convert``.
-               3. A str for Hugging Face hub repo id.
-
-        :param model_format: Specify the model format to be converted. ``pth`` is for
-               PyTorch model checkpoint from Hugging Face. ``gptq`` is for GPTQ format
-               model from Hugging Face.
-        :param model_family: The model family of the pretrained checkpoint.
-               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
-        :param dtype: Which quantized precision will be converted.
-               Now only `int4` and `int8` are supported, and `int8` only works for `llama`
-               , `gptneox` and `starcoder`.
-        :param cache_dir: (optional) This parameter will only be used when
-               ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
-               It indicates the saving path for the converted low precision model.
-        :param tmp_path: (optional) Which path to store the intermediate fp16 model during the
-               conversion process. Default to `None` so that intermediate model will not be saved.
-        :param **kwargs: keyword arguments which will be passed to the model instance
-
-        :return: a model instance
-        """
-        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
-                          "Now we only support model family: 'llama', 'gptneox', 'bloom',"
-                          " 'starcoder', '{}' is not in the list.".format(model_family))
-        invalidInputError(dtype.lower() in ['int4', 'int8'],
-                          "Now we only support int4 and int8 as date type for weight")
-
-        # check whether pretrained_model_name_or_path exists.
-        # if not, it is likely that the user wants to pass in the repo id.
-        if not os.path.exists(pretrained_model_name_or_path):
-            try:
-                # download from Hugging Face based on repo id
-                from huggingface_hub import snapshot_download
-                pretrained_model_name_or_path = snapshot_download(
-                    repo_id=pretrained_model_name_or_path)
-            except Exception as e:
-                traceback.print_exc()
-                # if downloading fails, it could be the case that repo id is invalid,
-                # or the user pass in the wrong path for checkpoint
-                invalidInputError(False,
-                                  "Downloadng from Hugging Face repo id {} failed. "
-                                  "Please input valid Hugging Face hub repo id, "
-                                  "or provide the valid path to Hugging Face / "
-                                  "BigDL-LLM optimized ggml binary checkpoint, "
-                                  "for pretrained_model_name_or_path"
-                                  .format(pretrained_model_name_or_path))
-
-        ggml_model_path = pretrained_model_name_or_path
-        # check whether pretrained_model_name_or_path is a file.
-        # if not, it is likely that pretrained_model_name_or_path
-        # points to a Hugging Face checkpoint
-        if not os.path.isfile(pretrained_model_name_or_path):
-            # Hugging Face checkpoint
-            from bigdl.llm import llm_convert
-            ggml_model_path = llm_convert(model=pretrained_model_name_or_path,
-                                          outfile=cache_dir,
-                                          model_family=model_family,
-                                          outtype=dtype,
-                                          model_format=model_format,
-                                          tmp_path=tmp_path)
-
-        if model_family == 'llama':
-            from bigdl.llm.ggml.model.llama import Llama
-            return Llama(model_path=ggml_model_path, **kwargs)
-        elif model_family == 'gptneox':
-            from bigdl.llm.ggml.model.gptneox import Gptneox
-            return Gptneox(model_path=ggml_model_path, **kwargs)
-        elif model_family == 'bloom':
-            from bigdl.llm.ggml.model.bloom import Bloom
-            return Bloom(model_path=ggml_model_path, **kwargs)
-        elif model_family == 'starcoder':
-            from bigdl.llm.ggml.model.starcoder import Starcoder
-            return Starcoder(model_path=ggml_model_path, **kwargs)
@@ -16,3 +16,4 @@
 
 from .convert import ggml_convert_int4
 from .model import AutoModelForCausalLM, AutoModel
+from .modelling_bigdl import BigdlForCausalLM
python/llm/src/bigdl/llm/transformers/modelling_bigdl.py (new file, 73 lines)

@@ -0,0 +1,73 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This would makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be module not found error in non-pip's setting as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+from bigdl.llm.utils.common import invalidInputError
+
+
+class BigdlForCausalLM:
+    """
+    A generic model class that mimics the behavior of
+    ``transformers.LlamaForCausalLM.from_pretrained`` API
+    """
+
+    @classmethod
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path: str,
+                        model_family: str = 'llama',
+                        dtype: str = 'int4',
+                        **kwargs):
+        """
+        :param pretrained_model_name_or_path: Path for converted BigDL-LLM optimized ggml
+               binary checkpoint. The checkpoint should be converted by ``bigdl.llm.llm_convert``.
+        :param model_family: The model family of the pretrained checkpoint.
+               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
+        :param dtype: Which quantized precision will be converted.
+               Now only `int4` and `int8` are supported, and `int8` only works for `llama`
+               , `gptneox` and `starcoder`.
+        :param cache_dir: (optional) This parameter will only be used when
+               ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
+               It indicates the saving path for the converted low precision model.
+        :param tmp_path: (optional) Which path to store the intermediate fp16 model during the
+               conversion process. Default to `None` so that intermediate model will not be saved.
+        :param **kwargs: keyword arguments which will be passed to the model instance
+
+        :return: a model instance
+        """
+        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
+                          "Now we only support model family: 'llama', 'gptneox', 'bloom',"
+                          " 'starcoder', '{}' is not in the list.".format(model_family))
+        invalidInputError(dtype.lower() in ['int4', 'int8'],
+                          "Now we only support int4 and int8 as date type for weight")
+
+        ggml_model_path = pretrained_model_name_or_path
+
+        if model_family == 'llama':
+            from bigdl.llm.ggml.model.llama import Llama
+            return Llama(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'gptneox':
+            from bigdl.llm.ggml.model.gptneox import Gptneox
+            return Gptneox(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'bloom':
+            from bigdl.llm.ggml.model.bloom import Bloom
+            return Bloom(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'starcoder':
+            from bigdl.llm.ggml.model.starcoder import Starcoder
+            return Starcoder(model_path=ggml_model_path, **kwargs)
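Since `BigdlForCausalLM.from_pretrained` only validates its arguments and dispatches on `model_family`, typical usage is a single call. The sketch below assumes an already-converted ggml binary (the path is a placeholder) and passes `n_threads` through as an example of the forwarded kwargs.

```python
# Minimal usage of the new BigdlForCausalLM class (placeholder checkpoint path).
from bigdl.llm.transformers import BigdlForCausalLM

llm = BigdlForCausalLM.from_pretrained(
    pretrained_model_name_or_path="/path/to/llama-7b-int4/bigdl-llm-xxx.bin",
    model_family="llama",   # 'llama', 'gptneox', 'bloom' or 'starcoder'
    n_threads=4)            # extra kwargs are passed to the underlying ggml model

output = llm("What is AI?", max_tokens=32)
print(output)
```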