LLM: support int8 and tmp_path for from_pretrained (#8338)
This commit is contained in:
parent
b30aa49c4e
commit
f7f4e65788
1 changed files with 10 additions and 4 deletions
|
|
@ -36,6 +36,7 @@ class AutoModelForCausalLM:
|
||||||
model_family: str = 'llama',
|
model_family: str = 'llama',
|
||||||
dtype: str = 'int4',
|
dtype: str = 'int4',
|
||||||
cache_dir: str = './',
|
cache_dir: str = './',
|
||||||
|
tmp_path: str = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
:param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoints
|
:param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoints
|
||||||
|
|
@ -50,10 +51,14 @@ class AutoModelForCausalLM:
|
||||||
|
|
||||||
:param model_family: the model family of the pretrained checkpoint.
|
:param model_family: the model family of the pretrained checkpoint.
|
||||||
Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
|
Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
|
||||||
:param dtype: (optional) the data type for weight. Currently we only support ``"int4"``
|
:param dtype: The quantized precision the model will be converted to.
|
||||||
|
Now only `int4` and `int8` are supported, and `int8` only works for `llama`
|
||||||
|
and `gptneox`.
|
||||||
:param cache_dir: (optional) this parameter will only be used when
|
:param cache_dir: (optional) this parameter will only be used when
|
||||||
``pretrained_model_name_or_path`` is a huggingface checkpoint or hub repo id.
|
``pretrained_model_name_or_path`` is a huggingface checkpoint or hub repo id.
|
||||||
It indicates the saving path for the converted low precision model.
|
It indicates the saving path for the converted low precision model.
|
||||||
|
:param tmp_path: (optional) Which path to store the intermediate fp16 model during the
|
||||||
|
conversion process. Default to `None` so that intermediate model will not be saved.
|
||||||
:param **kwargs: keyword arguments which will be passed to the model instance
|
:param **kwargs: keyword arguments which will be passed to the model instance
|
||||||
|
|
||||||
:return: a model instance
|
:return: a model instance
|
||||||
|
|
@ -61,8 +66,8 @@ class AutoModelForCausalLM:
|
||||||
invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
|
invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
|
||||||
"Now we only support model family: 'llama', 'gptneox', 'bloom', "
|
"Now we only support model family: 'llama', 'gptneox', 'bloom', "
|
||||||
"'{}' is not in the list.".format(model_family))
|
"'{}' is not in the list.".format(model_family))
|
||||||
invalidInputError(dtype.lower() == 'int4',
|
invalidInputError(dtype.lower() in ['int4', 'int8'],
|
||||||
"Now we only support int4 as data type for weight")
|
"Now we only support int4 and int8 as data type for weight")
|
||||||
|
|
||||||
# check whether pretrained_model_name_or_path exists.
|
# check whether pretrained_model_name_or_path exists.
|
||||||
# if not, it is likely that the user wants to pass in the repo id.
|
# if not, it is likely that the user wants to pass in the repo id.
|
||||||
|
|
@ -93,7 +98,8 @@ class AutoModelForCausalLM:
|
||||||
ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
|
ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
|
||||||
output_path=cache_dir,
|
output_path=cache_dir,
|
||||||
model_family=model_family,
|
model_family=model_family,
|
||||||
dtype=dtype)
|
dtype=dtype,
|
||||||
|
tmp_path=tmp_path)
|
||||||
|
|
||||||
if model_family == 'llama':
|
if model_family == 'llama':
|
||||||
from bigdl.llm.ggml.model.llama import Llama
|
from bigdl.llm.ggml.model.llama import Llama
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue