From f7f4e65788dc1a220f4d1daa2db7ebd11acec604 Mon Sep 17 00:00:00 2001
From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com>
Date: Thu, 15 Jun 2023 14:48:21 +0800
Subject: [PATCH] LLM: support int8 and tmp_path for `from_pretrained` (#8338)

---
 .../llm/src/bigdl/llm/ggml/transformers/model.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/bigdl/llm/ggml/transformers/model.py b/python/llm/src/bigdl/llm/ggml/transformers/model.py
index 9225cbeb..43ede645 100644
--- a/python/llm/src/bigdl/llm/ggml/transformers/model.py
+++ b/python/llm/src/bigdl/llm/ggml/transformers/model.py
@@ -36,6 +36,7 @@ class AutoModelForCausalLM:
                         model_family: str = 'llama',
                         dtype: str = 'int4',
                         cache_dir: str = './',
+                        tmp_path: str = None,
                         **kwargs):
         """
         :param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoint
@@ -50,10 +51,14 @@ class AutoModelForCausalLM:
 
         :param model_family: the model family of the pretrained checkpoint.
                Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
-        :param dtype: (optional) the data type for weight. Currently we only support ``"int4"``
+        :param dtype: (optional) the quantized precision of the converted model.
+               Currently only ``int4`` and ``int8`` are supported, and ``int8`` only works
+               for ``llama`` and ``gptneox``.
         :param cache_dir: (optional) this parameter will only be used when
               ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
              It indicates the saving path for the converted low precision model.
+        :param tmp_path: (optional) the path used to store the intermediate fp16 model during
+               the conversion process. Defaults to ``None``, so the intermediate model will not be saved.
         :param **kwargs: keyword arguments which will be passed to the model instance
 
         :return: a model instance
@@ -61,8 +66,8 @@ class AutoModelForCausalLM:
         invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
                           "Now we only support model family: 'llama', 'gptneox', 'bloom', "
                           "'{}' is not in the list.".format(model_family))
-        invalidInputError(dtype.lower() == 'int4',
-                          "Now we only support int4 as date type for weight")
+        invalidInputError(dtype.lower() in ['int4', 'int8'],
+                          "Now we only support int4 and int8 as data type for weight")
 
         # check whether pretrained_model_name_or_path exists.
         # if not, it is likely that the user wants to pass in the repo id.
@@ -93,7 +98,8 @@ class AutoModelForCausalLM:
             ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
                                             output_path=cache_dir,
                                             model_family=model_family,
-                                            dtype=dtype)
+                                            dtype=dtype,
+                                            tmp_path=tmp_path)
 
         if model_family == 'llama':
             from bigdl.llm.ggml.model.llama import Llama
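
A minimal usage sketch of the two options added above. The import assumes
`AutoModelForCausalLM` is re-exported from `bigdl.llm.ggml.transformers`; the
repo id and the tmp directory below are placeholder values, not part of the
library:

    from bigdl.llm.ggml.transformers import AutoModelForCausalLM

    # Convert a hugging face checkpoint to int8 ggml format and load it.
    # int8 currently only works for the 'llama' and 'gptneox' model families.
    llm = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path='decapoda-research/llama-7b-hf',  # placeholder repo id
        model_family='llama',
        dtype='int8',            # 'int4' is also supported
        cache_dir='./',          # saving path for the converted low precision model
        tmp_path='/tmp/bigdl',   # placeholder; stores the intermediate fp16 model,
                                 # the default None skips saving it
    )

Passing `tmp_path` lets the intermediate fp16 model from a previous run be
reused instead of being regenerated on every conversion.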