LLM: update convert_model to support int8 (#8326)
* update example and convert_model for int8
* reset example
* fix style
This commit is contained in:
parent
f64e703083
commit
5094970175
1 changed files with 10 additions and 3 deletions
|
|
@ -41,7 +41,8 @@ def convert_model(input_path: str,
|
||||||
:param model_family: Which model family your input model belongs to.
|
:param model_family: Which model family your input model belongs to.
|
||||||
Now only `llama`/`bloom`/`gptneox` are supported.
|
Now only `llama`/`bloom`/`gptneox` are supported.
|
||||||
:param dtype: Which quantized precision will be converted.
|
:param dtype: Which quantized precision will be converted.
|
||||||
Now only int4 is supported.
|
Now only `int4` and `int8` are supported, and `int8` only works for `llama`
|
||||||
|
and `gptneox`.
|
||||||
:param tmp_path: Which path to store the intermediate model during the conversion process.
|
:param tmp_path: Which path to store the intermediate model during the conversion process.
|
||||||
Default to `None` so that intermediate model will not be saved.
|
Default to `None` so that intermediate model will not be saved.
|
||||||
|
|
||||||
|
|
@ -58,8 +59,8 @@ def convert_model(input_path: str,
|
||||||
"{} is not in the list.".format(model_family))
|
"{} is not in the list.".format(model_family))
|
||||||
invalidInputError(os.path.isdir(output_path),
|
invalidInputError(os.path.isdir(output_path),
|
||||||
"The output_path {} was not a directory".format(output_path))
|
"The output_path {} was not a directory".format(output_path))
|
||||||
invalidInputError(dtype == 'int4',
|
invalidInputError(dtype in ['int4', 'int8'],
|
||||||
"Now only int4 is supported.")
|
"Now only int4 and int8 are supported.")
|
||||||
# check for input_path
|
# check for input_path
|
||||||
invalidInputError(os.path.exists(input_path),
|
invalidInputError(os.path.exists(input_path),
|
||||||
"The input path {} was not found".format(input_path))
|
"The input path {} was not found".format(input_path))
|
||||||
|
|
@ -69,6 +70,12 @@ def convert_model(input_path: str,
|
||||||
|
|
||||||
if dtype == 'int4':
|
if dtype == 'int4':
|
||||||
dtype = 'q4_0'
|
dtype = 'q4_0'
|
||||||
|
elif dtype == 'int8':
|
||||||
|
dtype = 'q8_0'
|
||||||
|
invalidInputError(model_family in ['llama', 'gptneox'],
|
||||||
|
"Now we only support int8 quantization of model \
|
||||||
|
family('llama', 'gptneox')",
|
||||||
|
"{} is not in the list.".format(model_family))
|
||||||
|
|
||||||
if tmp_path is not None:
|
if tmp_path is not None:
|
||||||
model_name = Path(input_path).stem
|
model_name = Path(input_path).stem
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue