LLM: add chatglm native int4 transformers API (#8695)
parent 6da830cf7e
commit ea5d7aff5b
10 changed files with 94 additions and 19 deletions
@@ -124,7 +124,7 @@ def main():
                         help=("outfile,save path of output quantized model."))
     parser.add_argument('-x', '--model-family', type=str, required=True,
                         help=("--model-family: Which model family your input model belongs to."
-                              "Now only `llama`/`bloom`/`gptneox` are supported."))
+                              "Now only `llama`/`bloom`/`gptneox`/`chatglm` are supported."))
     parser.add_argument('-f', '--model-format', type=str, required=True,
                         help=("The model type to be convert to a ggml compatible file."
                               "Now only `pth`/`gptq` are supported."))
@@ -77,7 +77,7 @@ def _convert_starcoder(model_path, outfile_dir, outtype):


 def _convert_chatglm(model_path, outfile_dir, outtype):
-    _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype)
+    return _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype)


 def _convert_to_ggml(model_path: str, outfile_dir: str,
@@ -80,10 +80,9 @@ def convert_model(input_path: str,

     # chatglm merges convertion and quantization into one operation.
     if model_family == 'chatglm':
-        _convert_chatglm(model_path=input_path,
+        return _convert_chatglm(model_path=input_path,
                          outfile_dir=output_path,
                          outtype=dtype)
-        return

     if tmp_path is not None:
         model_name = Path(input_path).stem
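For context, a minimal conversion sketch built from the call shown in this hunk. The import path and file locations are assumptions for illustration, not part of this diff, and the dtype values are only as documented elsewhere in this commit:

from bigdl.llm.ggml import convert_model   # assumed import path for this function

# Convert a Hugging Face ChatGLM checkpoint; for model_family='chatglm' the
# conversion and quantization happen in a single step, and the path of the
# resulting ggml binary is now returned to the caller.
ggml_path = convert_model(input_path="/path/to/chatglm2-6b",   # placeholder checkpoint dir
                          output_path="/path/to/out_dir",      # placeholder output dir
                          model_family="chatglm",
                          dtype="int4")                        # chatglm currently supports 4-bit (q4_0/q4_1)
print(ggml_path)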
@@ -18,3 +18,5 @@
 # physically located elsewhere.
 # Otherwise there would be module not found error in non-pip's setting as Python would
 # only search the first bigdl package and end up finding only one sub-package.
+
+from .chatglm import ChatGLM
@@ -56,7 +56,7 @@ import uuid
 import warnings


-class ChatGLM:
+class ChatGLM(GenerationMixin):
     """High-level Python wrapper for a chatglm.cpp model."""

     def __init__(
@@ -327,7 +327,7 @@ class ChatGLM:
             }
         }

-    def _tokenize(self, text: bytes) -> List[int]:
+    def _tokenize(self, text: bytes, *args) -> List[int]:
         """Tokenize a string.

         Args:
@@ -339,9 +339,10 @@ class ChatGLM:
         Returns:
             A list of tokens.
         """
+        warnings.warn("The parameter `add_bos` is unsupported, please use the default value.")
         return chatglm_tokenize(self.ctx, text)

-    def detokenize(self, tokens: List[int]) -> bytes:
+    def detokenize(self, tokens: List[int]) -> str:
         """Detokenize a list of tokens.

         Args:
@@ -371,3 +372,65 @@ class ChatGLM:

     def eos_token(self) -> int:
         return chatglm_eos_token(self.ctx)
+
+    def _generate(
+        self,
+        tokens: Sequence[int],
+        top_k: int = 0,
+        top_p: float = 0.7,
+        temp: float = 0.95,
+        repeat_penalty: float = 1.1,
+        reset: bool = True,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+    ) -> Generator[int, Optional[Sequence[int]], None]:
+        """Create a generator of tokens from a prompt.
+
+        Examples:
+            >>> llm = ChatGLM(your_model_path)
+            >>> tokens = llm._tokenize(b"Learning English is")
+            >>> for token in llm._generate(tokens):
+            >>>     print(llm.detokenize([token]).decode("utf-8", errors="ignore"))
+
+        Args:
+            tokens: The prompt tokens.
+
+        Yields:
+            The generated tokens.
+        """
+        # TODO: Some parameters are temporarily not supported
+        # Unsupported parameters are checked in `_supported_generate`
+        return self._supported_generate(tokens, top_k, top_p, temp, repeat_penalty, reset,
+                                        frequency_penalty, presence_penalty, tfs_z, mirostat_mode,
+                                        mirostat_tau, mirostat_eta)
+
+    def _supported_generate(self, tokens: Sequence[int], top_k: int = 0, top_p: float = 0.7,
+                            temp: float = 0.95, *args):
+        # Check unsupporeted parameters
+        unsupported_arg = ['repeat_penalty', 'reset', 'frequency_penalty', 'presence_penalty',
+                           'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']
+        defult_value = {'repeat_penalty': 1.1, 'reset': True, 'frequency_penalty': 0.0,
+                        'presence_penalty': 0.0, 'tfs_z': 1.0, 'mirostat_mode': 0,
+                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
+        for index in range(len(args)):
+            if args[index] != defult_value[unsupported_arg[index]]:
+                warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
+                              "unsupported, please use the default value.")
+
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `ChatGLM` object is None.")
+        n_past = 0
+        while True:
+            token = self.forward(input_ids=tokens,
+                                 n_past=n_past,
+                                 top_k=top_k,
+                                 top_p=top_p,
+                                 temperature=temp)
+            n_past += len(tokens)
+            tokens_or_none = yield token
+            tokens = [token]
+            if tokens_or_none is not None:
+                tokens.extend(tokens_or_none)
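A short usage sketch of this low-level API, expanded from the docstring example above. The model path is a placeholder and the 32-token cut-off is only for illustration:

from bigdl.llm.ggml.model.chatglm import ChatGLM

llm = ChatGLM(model_path="/path/to/bigdl_llm_chatglm_q4_0.bin")   # placeholder path

# _tokenize/_generate form the token-in/token-out API; _generate yields one
# token id at a time, so the caller decides when to stop iterating.
tokens = llm._tokenize(b"Learning English is")
for i, token in enumerate(llm._generate(tokens)):
    if i >= 32:
        break
    # decode() is inherited from GenerationMixin and handles both the str
    # returned by ChatGLM.detokenize and the bytes returned by other families.
    print(llm.decode([token]), end="", flush=True)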
@@ -72,7 +72,11 @@ class GenerationMixin:
         :param tokens: list of ids that indicates the tokens, mostly generated by generate
         :return: decoded string
         '''
-        return self.detokenize(tokens).decode()
+        output = self.detokenize(tokens)
+        if isinstance(output, str):
+            return output
+        else:
+            return output.decode()

     def batch_decode(self,
                      tokens: Union[List[int], List[List[int]]]) -> str:
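The isinstance branch is needed because ChatGLM.detokenize now returns str while the other ggml wrappers keep returning bytes. A self-contained toy illustration of what the change buys callers (the Dummy classes are hypothetical stand-ins, not library code):

from typing import List

class _DummyChatGLMLike:
    def detokenize(self, tokens: List[int]) -> str:
        return "hello"      # chatglm.cpp already hands back text

class _DummyLlamaLike:
    def detokenize(self, tokens: List[int]) -> bytes:
        return b"hello"     # llama/gptneox/bloom/starcoder hand back bytes

def decode(model, tokens: List[int]) -> str:
    # mirrors GenerationMixin.decode after this change
    output = model.detokenize(tokens)
    return output if isinstance(output, str) else output.decode()

assert decode(_DummyChatGLMLike(), [0]) == decode(_DummyLlamaLike(), [0]) == "hello"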
@@ -23,3 +23,5 @@ from bigdl.llm.ggml.model.llama import Llama
 from bigdl.llm.ggml.model.gptneox import Gptneox
 from bigdl.llm.ggml.model.bloom import Bloom
 from bigdl.llm.ggml.model.starcoder import Starcoder
+# temporarily disable until linux binary file for chatglm ready
+# from bigdl.llm.ggml.model.chatglm import ChatGLM
@@ -38,12 +38,13 @@ class BigdlNativeForCausalLM:
         :param pretrained_model_name_or_path: Path for converted BigDL-LLM optimized ggml
                binary checkpoint. The checkpoint should be converted by ``bigdl.llm.llm_convert``.
         :param model_family: The model family of the pretrained checkpoint.
-               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
+               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``, ``"starcoder"``
+               and ``"chatglm"``.
         :param dtype: Which quantized precision will be converted.
                Now only `int4` and `int8` are supported, and `int8` only works for `llama`
                , `gptneox` and `starcoder`.
         :param cache_dir: (optional) This parameter will only be used when
-               ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
+               ``pretrained_model_name_or_path`` is a huggingface checkpoint or hub repo id.
                It indicates the saving path for the converted low precision model.
         :param tmp_path: (optional) Which path to store the intermediate fp16 model during the
                conversion process. Default to `None` so that intermediate model will not be saved.
@@ -51,9 +52,9 @@ class BigdlNativeForCausalLM:

         :return: a model instance
         """
-        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
+        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
                           "Now we only support model family: 'llama', 'gptneox', 'bloom',"
-                          " 'starcoder', '{}' is not in the list.".format(model_family))
+                          " 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
         invalidInputError(dtype.lower() in ['int4', 'int8'],
                           "Now we only support int4 and int8 as date type for weight")
@@ -71,3 +72,6 @@ class BigdlNativeForCausalLM:
         elif model_family == 'starcoder':
             from bigdl.llm.ggml.model.starcoder import Starcoder
             return Starcoder(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'chatglm':
+            from bigdl.llm.ggml.model.chatglm import ChatGLM
+            return ChatGLM(model_path=ggml_model_path, **kwargs)
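A usage sketch of the transformers-style entry point with the new family. The import path and checkpoint path are assumptions for illustration:

from bigdl.llm.transformers import BigdlNativeForCausalLM   # assumed import path

# Load an already-converted chatglm ggml binary; from_pretrained dispatches to
# the ChatGLM wrapper through the new branch shown above.
llm = BigdlNativeForCausalLM.from_pretrained(
    "/path/to/bigdl_llm_chatglm_q4_0.bin",   # placeholder path to the converted checkpoint
    model_family="chatglm",
    dtype="int4")

# The returned object is the ChatGLM wrapper, so the low-level token API from
# the earlier hunk is available here as well.
ids = llm._tokenize(b"What is AI?")
print(llm.decode(ids))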
@@ -261,6 +261,7 @@ class BaseConverter:
             cls.dump_model(f, model, ggml_type)

         print(f"{cls.MODEL_TYPE.name} GGML model saved to {save_path}")
+        return save_path


 class ChatGLMConverter(BaseConverter):
@@ -397,9 +398,9 @@ def _convert_chatglm_hf_to_ggml_(model_path, outfile_dir, outtype):
     model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

     if hasattr(model.config, "multi_query_attention"):
-        ChatGLM2Converter.convert(model, tokenizer, ggml_type, outfile_dir)
+        return ChatGLM2Converter.convert(model, tokenizer, ggml_type, outfile_dir)
     else:
-        ChatGLMConverter.convert(model, tokenizer, ggml_type, outfile_dir)
+        return ChatGLMConverter.convert(model, tokenizer, ggml_type, outfile_dir)


 def main():
@@ -1596,6 +1596,6 @@ def _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype):
                       "For now we only support quantization type 'q4_0' and 'q4_1' "
                       "in chatglm family.")
     from bigdl.llm.utils.convert_chatglm import _convert_chatglm_hf_to_ggml_
-    _convert_chatglm_hf_to_ggml_(model_path,
+    return _convert_chatglm_hf_to_ggml_(model_path,
                                  outfile,
                                  outtype)