LLM: update convert_model to support int8 (#8326)

* update example and convert_model for int8

* reset example

* fix style
Author: Ruonan Wang
Date: 2023-06-15 09:25:07 +08:00
Committed by: GitHub
parent f64e703083
commit 5094970175

@@ -41,7 +41,8 @@ def convert_model(input_path: str,
     :param model_family: Which model family your input model belongs to.
            Now only `llama`/`bloom`/`gptneox` are supported.
     :param dtype: Which quantized precision will be converted.
-           Now only int4 is supported.
+           Now only `int4` and `int8` are supported, and `int8` only works for `llama`
+           and `gptneox`.
     :param tmp_path: Which path to store the intermediate model during the conversion process.
            Default to `None` so that intermediate model will not be saved.
@@ -58,8 +59,8 @@ def convert_model(input_path: str,
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.isdir(output_path),
                       "The output_path {} was not a directory".format(output_path))
-    invalidInputError(dtype == 'int4',
-                      "Now only int4 is supported.")
+    invalidInputError(dtype in ['int4', 'int8'],
+                      "Now only int4 and int8 are supported.")
     # check for input_path
     invalidInputError(os.path.exists(input_path),
                       "The input path {} was not found".format(input_path))
@@ -69,6 +70,12 @@ def convert_model(input_path: str,
     if dtype == 'int4':
         dtype = 'q4_0'
+    elif dtype == 'int8':
+        dtype = 'q8_0'
+        invalidInputError(model_family in ['llama', 'gptneox'],
+                          "Now we only support int8 quantization of model \
+                           family('llama', 'gptneox')",
+                          "{} is not in the list.".format(model_family))
     if tmp_path is not None:
         model_name = Path(input_path).stem
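
With this change applied, callers can request 8-bit quantization directly. A hypothetical call is sketched below; the import path is assumed from BigDL's layout around this commit and may differ in your version, and the model paths are placeholders:

from bigdl.llm.ggml import convert_model

# `int8` is mapped to ggml's `q8_0` quantization internally and is only
# accepted for the `llama` and `gptneox` families, so passing `bloom`
# here would trip the new invalidInputError check.
convert_model(input_path='/path/to/llama-7b-hf',
              output_path='/path/to/converted/',
              model_family='llama',
              dtype='int8')

Requesting `dtype='int4'` (mapped to `q4_0`) still works for all three supported families, as before.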