LLM: add chatglm native int4 transformers API (#8695)

This commit is contained in:
binbin Deng 2023-08-07 17:52:47 +08:00 committed by GitHub
parent 6da830cf7e
commit ea5d7aff5b
10 changed files with 94 additions and 19 deletions

View file

@ -124,7 +124,7 @@ def main():
help=("outfile,save path of output quantized model."))
parser.add_argument('-x', '--model-family', type=str, required=True,
help=("--model-family: Which model family your input model belongs to."
"Now only `llama`/`bloom`/`gptneox` are supported."))
"Now only `llama`/`bloom`/`gptneox`/`chatglm` are supported."))
parser.add_argument('-f', '--model-format', type=str, required=True,
help=("The model type to be convert to a ggml compatible file."
"Now only `pth`/`gptq` are supported."))

View file

@ -77,7 +77,7 @@ def _convert_starcoder(model_path, outfile_dir, outtype):
def _convert_chatglm(model_path, outfile_dir, outtype):
_convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype)
return _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype)
def _convert_to_ggml(model_path: str, outfile_dir: str,

View file

@ -80,10 +80,9 @@ def convert_model(input_path: str,
# chatglm merges convertion and quantization into one operation.
if model_family == 'chatglm':
_convert_chatglm(model_path=input_path,
outfile_dir=output_path,
outtype=dtype)
return
return _convert_chatglm(model_path=input_path,
outfile_dir=output_path,
outtype=dtype)
if tmp_path is not None:
model_name = Path(input_path).stem

View file

@ -18,3 +18,5 @@
# physically located elsewhere.
# Otherwise there would be module not found error in non-pip's setting as Python would
# only search the first bigdl package and end up finding only one sub-package.
from .chatglm import ChatGLM

View file

@ -56,7 +56,7 @@ import uuid
import warnings
class ChatGLM:
class ChatGLM(GenerationMixin):
"""High-level Python wrapper for a chatglm.cpp model."""
def __init__(
@ -327,7 +327,7 @@ class ChatGLM:
}
}
def _tokenize(self, text: bytes) -> List[int]:
def _tokenize(self, text: bytes, *args) -> List[int]:
"""Tokenize a string.
Args:
@ -339,9 +339,10 @@ class ChatGLM:
Returns:
A list of tokens.
"""
warnings.warn("The parameter `add_bos` is unsupported, please use the default value.")
return chatglm_tokenize(self.ctx, text)
def detokenize(self, tokens: List[int]) -> bytes:
def detokenize(self, tokens: List[int]) -> str:
"""Detokenize a list of tokens.
Args:
@ -371,3 +372,65 @@ class ChatGLM:
def eos_token(self) -> int:
return chatglm_eos_token(self.ctx)
def _generate(
self,
tokens: Sequence[int],
top_k: int = 0,
top_p: float = 0.7,
temp: float = 0.95,
repeat_penalty: float = 1.1,
reset: bool = True,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
) -> Generator[int, Optional[Sequence[int]], None]:
"""Create a generator of tokens from a prompt.
Examples:
>>> llm = ChatGLM(your_model_path)
>>> tokens = llm._tokenize(b"Learning English is")
>>> for token in llm._generate(tokens):
>>> print(llm.detokenize([token]).decode("utf-8", errors="ignore"))
Args:
tokens: The prompt tokens.
Yields:
The generated tokens.
"""
# TODO: Some parameters are temporarily not supported
# Unsupported parameters are checked in `_supported_generate`
return self._supported_generate(tokens, top_k, top_p, temp, repeat_penalty, reset,
frequency_penalty, presence_penalty, tfs_z, mirostat_mode,
mirostat_tau, mirostat_eta)
def _supported_generate(self, tokens: Sequence[int], top_k: int = 0, top_p: float = 0.7,
temp: float = 0.95, *args):
# Check unsupporeted parameters
unsupported_arg = ['repeat_penalty', 'reset', 'frequency_penalty', 'presence_penalty',
'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']
defult_value = {'repeat_penalty': 1.1, 'reset': True, 'frequency_penalty': 0.0,
'presence_penalty': 0.0, 'tfs_z': 1.0, 'mirostat_mode': 0,
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)):
if args[index] != defult_value[unsupported_arg[index]]:
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.")
invalidInputError(self.ctx is not None, "The attribute `ctx` of `ChatGLM` object is None.")
n_past = 0
while True:
token = self.forward(input_ids=tokens,
n_past=n_past,
top_k=top_k,
top_p=top_p,
temperature=temp)
n_past += len(tokens)
tokens_or_none = yield token
tokens = [token]
if tokens_or_none is not None:
tokens.extend(tokens_or_none)

View file

@ -72,7 +72,11 @@ class GenerationMixin:
:param tokens: list of ids that indicates the tokens, mostly generated by generate
:return: decoded string
'''
return self.detokenize(tokens).decode()
output = self.detokenize(tokens)
if isinstance(output, str):
return output
else:
return output.decode()
def batch_decode(self,
tokens: Union[List[int], List[List[int]]]) -> str:

View file

@ -23,3 +23,5 @@ from bigdl.llm.ggml.model.llama import Llama
from bigdl.llm.ggml.model.gptneox import Gptneox
from bigdl.llm.ggml.model.bloom import Bloom
from bigdl.llm.ggml.model.starcoder import Starcoder
# temporarily disable until linux binary file for chatglm ready
# from bigdl.llm.ggml.model.chatglm import ChatGLM

View file

@ -38,12 +38,13 @@ class BigdlNativeForCausalLM:
:param pretrained_model_name_or_path: Path for converted BigDL-LLM optimized ggml
binary checkpoint. The checkpoint should be converted by ``bigdl.llm.llm_convert``.
:param model_family: The model family of the pretrained checkpoint.
Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``, ``"starcoder"``
and ``"chatglm"``.
:param dtype: Which quantized precision will be converted.
Now only `int4` and `int8` are supported, and `int8` only works for `llama`
, `gptneox` and `starcoder`.
:param cache_dir: (optional) This parameter will only be used when
``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
``pretrained_model_name_or_path`` is a huggingface checkpoint or hub repo id.
It indicates the saving path for the converted low precision model.
:param tmp_path: (optional) Which path to store the intermediate fp16 model during the
conversion process. Default to `None` so that intermediate model will not be saved.
@ -51,9 +52,9 @@ class BigdlNativeForCausalLM:
:return: a model instance
"""
invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
"Now we only support model family: 'llama', 'gptneox', 'bloom',"
" 'starcoder', '{}' is not in the list.".format(model_family))
" 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
invalidInputError(dtype.lower() in ['int4', 'int8'],
"Now we only support int4 and int8 as date type for weight")
@ -71,3 +72,6 @@ class BigdlNativeForCausalLM:
elif model_family == 'starcoder':
from bigdl.llm.ggml.model.starcoder import Starcoder
return Starcoder(model_path=ggml_model_path, **kwargs)
elif model_family == 'chatglm':
from bigdl.llm.ggml.model.chatglm import ChatGLM
return ChatGLM(model_path=ggml_model_path, **kwargs)

View file

@ -261,6 +261,7 @@ class BaseConverter:
cls.dump_model(f, model, ggml_type)
print(f"{cls.MODEL_TYPE.name} GGML model saved to {save_path}")
return save_path
class ChatGLMConverter(BaseConverter):
@ -397,9 +398,9 @@ def _convert_chatglm_hf_to_ggml_(model_path, outfile_dir, outtype):
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
if hasattr(model.config, "multi_query_attention"):
ChatGLM2Converter.convert(model, tokenizer, ggml_type, outfile_dir)
return ChatGLM2Converter.convert(model, tokenizer, ggml_type, outfile_dir)
else:
ChatGLMConverter.convert(model, tokenizer, ggml_type, outfile_dir)
return ChatGLMConverter.convert(model, tokenizer, ggml_type, outfile_dir)
def main():

View file

@ -1596,6 +1596,6 @@ def _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype):
"For now we only support quantization type 'q4_0' and 'q4_1' "
"in chatglm family.")
from bigdl.llm.utils.convert_chatglm import _convert_chatglm_hf_to_ggml_
_convert_chatglm_hf_to_ggml_(model_path,
outfile,
outtype)
return _convert_chatglm_hf_to_ggml_(model_path,
outfile,
outtype)