LLM: support int8 and tmp_path for from_pretrained (#8338)

This commit is contained in:
Ruonan Wang 2023-06-15 14:48:21 +08:00 committed by GitHub
parent b30aa49c4e
commit f7f4e65788

View file

@ -36,6 +36,7 @@ class AutoModelForCausalLM:
model_family: str = 'llama',
dtype: str = 'int4',
cache_dir: str = './',
tmp_path: str = None,
**kwargs):
"""
:param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoint
@ -50,10 +51,14 @@ class AutoModelForCausalLM:
:param model_family: the model family of the pretrained checkpoint.
Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
:param dtype: (optional) the data type for weight. Currently we only support ``"int4"``
:param dtype: Which quantized precision will be converted.
Now only `int4` and `int8` are supported, and `int8` only works for `llama`
and `gptneox`.
:param cache_dir: (optional) this parameter will only be used when
``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
It indicates the saving path for the converted low precision model.
:param tmp_path: (optional) Which path to store the intermediate fp16 model during the
conversion process. Default to `None` so that intermediate model will not be saved.
:param **kwargs: keyword arguments which will be passed to the model instance
:return: a model instance
@ -61,8 +66,8 @@ class AutoModelForCausalLM:
invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
"Now we only support model family: 'llama', 'gptneox', 'bloom', "
"'{}' is not in the list.".format(model_family))
invalidInputError(dtype.lower() == 'int4',
"Now we only support int4 as date type for weight")
invalidInputError(dtype.lower() in ['int4', 'int8'],
"Now we only support int4 and int8 as date type for weight")
# check whether pretrained_model_name_or_path exists.
# if not, it is likely that the user wants to pass in the repo id.
@ -93,7 +98,8 @@ class AutoModelForCausalLM:
ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
output_path=cache_dir,
model_family=model_family,
dtype=dtype)
dtype=dtype,
tmp_path=tmp_path)
if model_family == 'llama':
from bigdl.llm.ggml.model.llama import Llama