LLM: add chatglm native int4 transformers API (#8695)
parent 6da830cf7e
commit ea5d7aff5b

10 changed files with 94 additions and 19 deletions
@@ -124,7 +124,7 @@ def main():
                         help=("outfile,save path of output quantized model."))
     parser.add_argument('-x', '--model-family', type=str, required=True,
                         help=("--model-family: Which model family your input model belongs to."
-                              "Now only `llama`/`bloom`/`gptneox` are supported."))
+                              "Now only `llama`/`bloom`/`gptneox`/`chatglm` are supported."))
     parser.add_argument('-f', '--model-format', type=str, required=True,
                         help=("The model type to be convert to a ggml compatible file."
                               "Now only `pth`/`gptq` are supported."))
@@ -77,7 +77,7 @@ def _convert_starcoder(model_path, outfile_dir, outtype):


 def _convert_chatglm(model_path, outfile_dir, outtype):
-    _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype)
+    return _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype)


 def _convert_to_ggml(model_path: str, outfile_dir: str,
@@ -80,10 +80,9 @@ def convert_model(input_path: str,

     # chatglm merges convertion and quantization into one operation.
     if model_family == 'chatglm':
-        _convert_chatglm(model_path=input_path,
+        return _convert_chatglm(model_path=input_path,
                          outfile_dir=output_path,
                          outtype=dtype)
-        return

     if tmp_path is not None:
         model_name = Path(input_path).stem
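With the hunk above, `convert_model` short-circuits for the chatglm family: conversion and int4 quantization run as a single step, and the path of the generated ggml file is passed back to the caller. A minimal usage sketch; the import location of `convert_model` is not visible in this diff and the paths are placeholders:

    # Hedged sketch: module path and file paths are assumptions, not part of the diff.
    from bigdl.llm.ggml.convert_model import convert_model

    # For model_family='chatglm', conversion and quantization happen in one pass,
    # and the call returns the location of the produced ggml binary.
    ggml_path = convert_model(input_path="/path/to/chatglm2-6b",   # HF checkpoint dir
                              output_path="/path/to/output-dir",
                              model_family="chatglm",
                              dtype="int4")
    print(ggml_path)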
@@ -18,3 +18,5 @@
 # physically located elsewhere.
 # Otherwise there would be module not found error in non-pip's setting as Python would
 # only search the first bigdl package and end up finding only one sub-package.
+
+from .chatglm import ChatGLM
@@ -56,7 +56,7 @@ import uuid
 import warnings


-class ChatGLM:
+class ChatGLM(GenerationMixin):
     """High-level Python wrapper for a chatglm.cpp model."""

     def __init__(
@@ -327,7 +327,7 @@ class ChatGLM:
             }
         }

-    def _tokenize(self, text: bytes) -> List[int]:
+    def _tokenize(self, text: bytes, *args) -> List[int]:
         """Tokenize a string.

         Args:
@@ -339,9 +339,10 @@ class ChatGLM:
         Returns:
             A list of tokens.
         """
+        warnings.warn("The parameter `add_bos` is unsupported, please use the default value.")
         return chatglm_tokenize(self.ctx, text)

-    def detokenize(self, tokens: List[int]) -> bytes:
+    def detokenize(self, tokens: List[int]) -> str:
         """Detokenize a list of tokens.

         Args:
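The two hunks above make the chatglm wrapper call-compatible with the other ggml wrappers: `_tokenize` now accepts (and ignores, with a warning) an extra positional argument such as a llama-style `add_bos` flag, and `detokenize` is annotated as returning `str` because chatglm.cpp hands back decoded text rather than raw bytes. A standalone sketch of that accept-and-warn pattern; the stand-in tokenizer below is invented for illustration, and the conditional warning is a variation on the diff, which warns unconditionally:

    import warnings
    from typing import List

    def tokenize_compat(text: bytes, *args) -> List[int]:
        # Stand-in tokenizer: extra positional args (e.g. add_bos) are accepted
        # for signature compatibility but ignored with a warning.
        if args:
            warnings.warn("The parameter `add_bos` is unsupported, "
                          "please use the default value.")
        return list(range(len(text.split())))  # fake ids, one per whitespace token

    tokenize_compat(b"Learning English is", True)   # warns, still returns ids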
@@ -371,3 +372,65 @@ class ChatGLM:

     def eos_token(self) -> int:
         return chatglm_eos_token(self.ctx)
+
+    def _generate(
+        self,
+        tokens: Sequence[int],
+        top_k: int = 0,
+        top_p: float = 0.7,
+        temp: float = 0.95,
+        repeat_penalty: float = 1.1,
+        reset: bool = True,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+    ) -> Generator[int, Optional[Sequence[int]], None]:
+        """Create a generator of tokens from a prompt.
+
+        Examples:
+            >>> llm = ChatGLM(your_model_path)
+            >>> tokens = llm._tokenize(b"Learning English is")
+            >>> for token in llm._generate(tokens):
+            >>> print(llm.detokenize([token]).decode("utf-8", errors="ignore"))
+
+        Args:
+            tokens: The prompt tokens.
+
+        Yields:
+            The generated tokens.
+        """
+        # TODO: Some parameters are temporarily not supported
+        # Unsupported parameters are checked in `_supported_generate`
+        return self._supported_generate(tokens, top_k, top_p, temp, repeat_penalty, reset,
+                                        frequency_penalty, presence_penalty, tfs_z, mirostat_mode,
+                                        mirostat_tau, mirostat_eta)
+
+    def _supported_generate(self, tokens: Sequence[int], top_k: int = 0, top_p: float = 0.7,
+                            temp: float = 0.95, *args):
+        # Check unsupporeted parameters
+        unsupported_arg = ['repeat_penalty', 'reset', 'frequency_penalty', 'presence_penalty',
+                           'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']
+        defult_value = {'repeat_penalty': 1.1, 'reset': True, 'frequency_penalty': 0.0,
+                        'presence_penalty': 0.0, 'tfs_z': 1.0, 'mirostat_mode': 0,
+                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
+        for index in range(len(args)):
+            if args[index] != defult_value[unsupported_arg[index]]:
+                warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
+                              "unsupported, please use the default value.")
+
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `ChatGLM` object is None.")
+        n_past = 0
+        while True:
+            token = self.forward(input_ids=tokens,
+                                 n_past=n_past,
+                                 top_k=top_k,
+                                 top_p=top_p,
+                                 temperature=temp)
+            n_past += len(tokens)
+            tokens_or_none = yield token
+            tokens = [token]
+            if tokens_or_none is not None:
+                tokens.extend(tokens_or_none)
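The newly added `_generate` method turns the low-level `forward` call into a streaming generator: it samples one token per iteration, advances `n_past` by however many tokens were just fed in, and then continues from the sampled token plus anything the caller `send()`s back into the generator. A hedged usage sketch; the checkpoint path is a placeholder, and the EOS/length stops are added for the demo rather than taken from the diff:

    # Assumes the ChatGLM wrapper is importable as shown elsewhere in this commit
    # and that a converted ggml checkpoint exists at the given path.
    from bigdl.llm.ggml.model.chatglm import ChatGLM

    llm = ChatGLM(model_path="/path/to/bigdl_llm_chatglm_q4_0.bin")

    prompt_tokens = llm._tokenize(b"Learning English is")
    pieces = []
    for token in llm._generate(prompt_tokens, top_k=0, top_p=0.7, temp=0.95):
        if token == llm.eos_token():        # stop once the model emits end-of-sequence
            break
        pieces.append(llm.detokenize([token]))  # chatglm's detokenize returns str
        if len(pieces) >= 64:                   # cap the demo at 64 generated tokens
            break
    print("".join(pieces))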
@@ -72,7 +72,11 @@ class GenerationMixin:
         :param tokens: list of ids that indicates the tokens, mostly generated by generate
         :return: decoded string
         '''
-        return self.detokenize(tokens).decode()
+        output = self.detokenize(tokens)
+        if isinstance(output, str):
+            return output
+        else:
+            return output.decode()

     def batch_decode(self,
                      tokens: Union[List[int], List[List[int]]]) -> str:
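`GenerationMixin.decode` now tolerates both return types: chatglm's `detokenize` hands back `str` (see the earlier hunk), while the llama, gptneox, bloom and starcoder wrappers still return `bytes`. A standalone sketch of the same dispatch, using a hypothetical helper name:

    from typing import Union

    def decode_output(output: Union[str, bytes]) -> str:
        # chatglm-style wrappers already return text; the other ggml wrappers
        # return raw bytes that still need decoding.
        if isinstance(output, str):
            return output
        return output.decode()

    assert decode_output("already text") == "already text"
    assert decode_output(b"raw bytes") == "raw bytes"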
@@ -23,3 +23,5 @@ from bigdl.llm.ggml.model.llama import Llama
 from bigdl.llm.ggml.model.gptneox import Gptneox
 from bigdl.llm.ggml.model.bloom import Bloom
 from bigdl.llm.ggml.model.starcoder import Starcoder
+# temporarily disable until linux binary file for chatglm ready
+# from bigdl.llm.ggml.model.chatglm import ChatGLM
@@ -38,12 +38,13 @@ class BigdlNativeForCausalLM:
        :param pretrained_model_name_or_path: Path for converted BigDL-LLM optimized ggml
               binary checkpoint. The checkpoint should be converted by ``bigdl.llm.llm_convert``.
        :param model_family: The model family of the pretrained checkpoint.
-              Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
+              Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``, ``"starcoder"``
+              and ``"chatglm"``.
        :param dtype: Which quantized precision will be converted.
               Now only `int4` and `int8` are supported, and `int8` only works for `llama`
               , `gptneox` and `starcoder`.
        :param cache_dir: (optional) This parameter will only be used when
-              ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
+              ``pretrained_model_name_or_path`` is a huggingface checkpoint or hub repo id.
               It indicates the saving path for the converted low precision model.
        :param tmp_path: (optional) Which path to store the intermediate fp16 model during the
               conversion process. Default to `None` so that intermediate model will not be saved.
@@ -51,9 +52,9 @@ class BigdlNativeForCausalLM:

        :return: a model instance
        """
-        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
+        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
                          "Now we only support model family: 'llama', 'gptneox', 'bloom',"
-                          " 'starcoder', '{}' is not in the list.".format(model_family))
+                          " 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
        invalidInputError(dtype.lower() in ['int4', 'int8'],
                          "Now we only support int4 and int8 as date type for weight")
@@ -71,3 +72,6 @@ class BigdlNativeForCausalLM:
        elif model_family == 'starcoder':
            from bigdl.llm.ggml.model.starcoder import Starcoder
            return Starcoder(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'chatglm':
+            from bigdl.llm.ggml.model.chatglm import ChatGLM
+            return ChatGLM(model_path=ggml_model_path, **kwargs)
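With this branch in place, the native API can return a chatglm wrapper as well. A usage sketch; the `bigdl.llm.transformers` import path and the `from_pretrained` entry point are assumptions inferred from the docstring above, and the checkpoint path is a placeholder:

    # Hedged sketch: import path, method name and file path are assumptions.
    from bigdl.llm.transformers import BigdlNativeForCausalLM

    llm = BigdlNativeForCausalLM.from_pretrained(
        pretrained_model_name_or_path="/path/to/bigdl_llm_chatglm_q4_0.bin",
        model_family="chatglm",
        dtype="int4")   # per the docstring above, int8 is not offered for chatglm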
@@ -261,6 +261,7 @@ class BaseConverter:
            cls.dump_model(f, model, ggml_type)

        print(f"{cls.MODEL_TYPE.name} GGML model saved to {save_path}")
+        return save_path


class ChatGLMConverter(BaseConverter):
@@ -397,9 +398,9 @@ def _convert_chatglm_hf_to_ggml_(model_path, outfile_dir, outtype):
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

    if hasattr(model.config, "multi_query_attention"):
-        ChatGLM2Converter.convert(model, tokenizer, ggml_type, outfile_dir)
+        return ChatGLM2Converter.convert(model, tokenizer, ggml_type, outfile_dir)
    else:
-        ChatGLMConverter.convert(model, tokenizer, ggml_type, outfile_dir)
+        return ChatGLMConverter.convert(model, tokenizer, ggml_type, outfile_dir)


def main():
|
@ -1596,6 +1596,6 @@ def _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype):
|
||||||
"For now we only support quantization type 'q4_0' and 'q4_1' "
|
"For now we only support quantization type 'q4_0' and 'q4_1' "
|
||||||
"in chatglm family.")
|
"in chatglm family.")
|
||||||
from bigdl.llm.utils.convert_chatglm import _convert_chatglm_hf_to_ggml_
|
from bigdl.llm.utils.convert_chatglm import _convert_chatglm_hf_to_ggml_
|
||||||
_convert_chatglm_hf_to_ggml_(model_path,
|
return _convert_chatglm_hf_to_ggml_(model_path,
|
||||||
outfile,
|
outfile,
|
||||||
outtype)
|
outtype)
|
||||||
|
|
|
||||||