LLM: add chatglm native int4 transformers API (#8695)
This commit is contained in:
		
							parent
							
								
									6da830cf7e
								
							
						
					
					
						commit
						ea5d7aff5b
					
				
					 10 changed files with 94 additions and 19 deletions
				
			
		| 
						 | 
				
			
			@ -124,7 +124,7 @@ def main():
 | 
			
		|||
                        help=("outfile,save path of output quantized model."))
 | 
			
		||||
    parser.add_argument('-x', '--model-family', type=str, required=True,
 | 
			
		||||
                        help=("--model-family: Which model family your input model belongs to."
 | 
			
		||||
                              "Now only `llama`/`bloom`/`gptneox` are supported."))
 | 
			
		||||
                              "Now only `llama`/`bloom`/`gptneox`/`chatglm` are supported."))
 | 
			
		||||
    parser.add_argument('-f', '--model-format', type=str, required=True,
 | 
			
		||||
                        help=("The model type to be convert to a ggml compatible file."
 | 
			
		||||
                              "Now only `pth`/`gptq` are supported."))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -77,7 +77,7 @@ def _convert_starcoder(model_path, outfile_dir, outtype):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def _convert_chatglm(model_path, outfile_dir, outtype):
 | 
			
		||||
    _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype)
 | 
			
		||||
    return _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _convert_to_ggml(model_path: str, outfile_dir: str,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -80,10 +80,9 @@ def convert_model(input_path: str,
 | 
			
		|||
 | 
			
		||||
    # chatglm merges convertion and quantization into one operation.
 | 
			
		||||
    if model_family == 'chatglm':
 | 
			
		||||
        _convert_chatglm(model_path=input_path,
 | 
			
		||||
                         outfile_dir=output_path,
 | 
			
		||||
                         outtype=dtype)
 | 
			
		||||
        return
 | 
			
		||||
        return _convert_chatglm(model_path=input_path,
 | 
			
		||||
                                outfile_dir=output_path,
 | 
			
		||||
                                outtype=dtype)
 | 
			
		||||
 | 
			
		||||
    if tmp_path is not None:
 | 
			
		||||
        model_name = Path(input_path).stem
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,3 +18,5 @@
 | 
			
		|||
# physically located elsewhere.
 | 
			
		||||
# Otherwise there would be module not found error in non-pip's setting as Python would
 | 
			
		||||
# only search the first bigdl package and end up finding only one sub-package.
 | 
			
		||||
 | 
			
		||||
from .chatglm import ChatGLM
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -56,7 +56,7 @@ import uuid
 | 
			
		|||
import warnings
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatGLM:
 | 
			
		||||
class ChatGLM(GenerationMixin):
 | 
			
		||||
    """High-level Python wrapper for a chatglm.cpp model."""
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
| 
						 | 
				
			
			@ -327,7 +327,7 @@ class ChatGLM:
 | 
			
		|||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
    def _tokenize(self, text: bytes) -> List[int]:
 | 
			
		||||
    def _tokenize(self, text: bytes, *args) -> List[int]:
 | 
			
		||||
        """Tokenize a string.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
| 
						 | 
				
			
			@ -339,9 +339,10 @@ class ChatGLM:
 | 
			
		|||
        Returns:
 | 
			
		||||
            A list of tokens.
 | 
			
		||||
        """
 | 
			
		||||
        warnings.warn("The parameter `add_bos` is unsupported, please use the default value.")
 | 
			
		||||
        return chatglm_tokenize(self.ctx, text)
 | 
			
		||||
 | 
			
		||||
    def detokenize(self, tokens: List[int]) -> bytes:
 | 
			
		||||
    def detokenize(self, tokens: List[int]) -> str:
 | 
			
		||||
        """Detokenize a list of tokens.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
| 
						 | 
				
			
			@ -371,3 +372,65 @@ class ChatGLM:
 | 
			
		|||
 | 
			
		||||
    def eos_token(self) -> int:
 | 
			
		||||
        return chatglm_eos_token(self.ctx)
 | 
			
		||||
 | 
			
		||||
    def _generate(
 | 
			
		||||
        self,
 | 
			
		||||
        tokens: Sequence[int],
 | 
			
		||||
        top_k: int = 0,
 | 
			
		||||
        top_p: float = 0.7,
 | 
			
		||||
        temp: float = 0.95,
 | 
			
		||||
        repeat_penalty: float = 1.1,
 | 
			
		||||
        reset: bool = True,
 | 
			
		||||
        frequency_penalty: float = 0.0,
 | 
			
		||||
        presence_penalty: float = 0.0,
 | 
			
		||||
        tfs_z: float = 1.0,
 | 
			
		||||
        mirostat_mode: int = 0,
 | 
			
		||||
        mirostat_tau: float = 5.0,
 | 
			
		||||
        mirostat_eta: float = 0.1,
 | 
			
		||||
    ) -> Generator[int, Optional[Sequence[int]], None]:
 | 
			
		||||
        """Create a generator of tokens from a prompt.
 | 
			
		||||
 | 
			
		||||
        Examples:
 | 
			
		||||
            >>> llm = ChatGLM(your_model_path)
 | 
			
		||||
            >>> tokens = llm._tokenize(b"Learning English is")
 | 
			
		||||
            >>> for token in llm._generate(tokens):
 | 
			
		||||
            >>>     print(llm.detokenize([token]).decode("utf-8", errors="ignore"))
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            tokens: The prompt tokens.
 | 
			
		||||
 | 
			
		||||
        Yields:
 | 
			
		||||
            The generated tokens.
 | 
			
		||||
        """
 | 
			
		||||
        # TODO: Some parameters are temporarily not supported
 | 
			
		||||
        # Unsupported parameters are checked in `_supported_generate`
 | 
			
		||||
        return self._supported_generate(tokens, top_k, top_p, temp, repeat_penalty, reset,
 | 
			
		||||
                                        frequency_penalty, presence_penalty, tfs_z, mirostat_mode,
 | 
			
		||||
                                        mirostat_tau, mirostat_eta)
 | 
			
		||||
 | 
			
		||||
    def _supported_generate(self, tokens: Sequence[int], top_k: int = 0, top_p: float = 0.7,
 | 
			
		||||
                            temp: float = 0.95, *args):
 | 
			
		||||
        # Check unsupporeted parameters
 | 
			
		||||
        unsupported_arg = ['repeat_penalty', 'reset', 'frequency_penalty', 'presence_penalty',
 | 
			
		||||
                           'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']
 | 
			
		||||
        defult_value = {'repeat_penalty': 1.1, 'reset': True, 'frequency_penalty': 0.0,
 | 
			
		||||
                        'presence_penalty': 0.0, 'tfs_z': 1.0, 'mirostat_mode': 0,
 | 
			
		||||
                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
 | 
			
		||||
        for index in range(len(args)):
 | 
			
		||||
            if args[index] != defult_value[unsupported_arg[index]]:
 | 
			
		||||
                warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
 | 
			
		||||
                              "unsupported, please use the default value.")
 | 
			
		||||
 | 
			
		||||
        invalidInputError(self.ctx is not None, "The attribute `ctx` of `ChatGLM` object is None.")
 | 
			
		||||
        n_past = 0
 | 
			
		||||
        while True:
 | 
			
		||||
            token = self.forward(input_ids=tokens,
 | 
			
		||||
                                 n_past=n_past,
 | 
			
		||||
                                 top_k=top_k,
 | 
			
		||||
                                 top_p=top_p,
 | 
			
		||||
                                 temperature=temp)
 | 
			
		||||
            n_past += len(tokens)
 | 
			
		||||
            tokens_or_none = yield token
 | 
			
		||||
            tokens = [token]
 | 
			
		||||
            if tokens_or_none is not None:
 | 
			
		||||
                tokens.extend(tokens_or_none)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -72,7 +72,11 @@ class GenerationMixin:
 | 
			
		|||
        :param tokens: list of ids that indicates the tokens, mostly generated by generate
 | 
			
		||||
        :return: decoded string
 | 
			
		||||
        '''
 | 
			
		||||
        return self.detokenize(tokens).decode()
 | 
			
		||||
        output = self.detokenize(tokens)
 | 
			
		||||
        if isinstance(output, str):
 | 
			
		||||
            return output
 | 
			
		||||
        else:
 | 
			
		||||
            return output.decode()
 | 
			
		||||
 | 
			
		||||
    def batch_decode(self,
 | 
			
		||||
                     tokens: Union[List[int], List[List[int]]]) -> str:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,3 +23,5 @@ from bigdl.llm.ggml.model.llama import Llama
 | 
			
		|||
from bigdl.llm.ggml.model.gptneox import Gptneox
 | 
			
		||||
from bigdl.llm.ggml.model.bloom import Bloom
 | 
			
		||||
from bigdl.llm.ggml.model.starcoder import Starcoder
 | 
			
		||||
# temporarily disable until linux binary file for chatglm ready
 | 
			
		||||
# from bigdl.llm.ggml.model.chatglm import ChatGLM
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,12 +38,13 @@ class BigdlNativeForCausalLM:
 | 
			
		|||
        :param pretrained_model_name_or_path: Path for converted BigDL-LLM optimized ggml
 | 
			
		||||
               binary checkpoint. The checkpoint should be converted by ``bigdl.llm.llm_convert``.
 | 
			
		||||
        :param model_family: The model family of the pretrained checkpoint.
 | 
			
		||||
               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
 | 
			
		||||
               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``, ``"starcoder"``
 | 
			
		||||
               and ``"chatglm"``.
 | 
			
		||||
        :param dtype: Which quantized precision will be converted.
 | 
			
		||||
                Now only `int4` and `int8` are supported, and `int8` only works for `llama`
 | 
			
		||||
                , `gptneox` and `starcoder`.
 | 
			
		||||
        :param cache_dir: (optional) This parameter will only be used when
 | 
			
		||||
               ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
 | 
			
		||||
               ``pretrained_model_name_or_path`` is a huggingface checkpoint or hub repo id.
 | 
			
		||||
               It indicates the saving path for the converted low precision model.
 | 
			
		||||
        :param tmp_path: (optional) Which path to store the intermediate fp16 model during the
 | 
			
		||||
               conversion process. Default to `None` so that intermediate model will not be saved.
 | 
			
		||||
| 
						 | 
				
			
			@ -51,9 +52,9 @@ class BigdlNativeForCausalLM:
 | 
			
		|||
 | 
			
		||||
        :return: a model instance
 | 
			
		||||
        """
 | 
			
		||||
        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
 | 
			
		||||
        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder', 'chatglm'],
 | 
			
		||||
                          "Now we only support model family: 'llama', 'gptneox', 'bloom',"
 | 
			
		||||
                          " 'starcoder', '{}' is not in the list.".format(model_family))
 | 
			
		||||
                          " 'starcoder', 'chatglm', '{}' is not in the list.".format(model_family))
 | 
			
		||||
        invalidInputError(dtype.lower() in ['int4', 'int8'],
 | 
			
		||||
                          "Now we only support int4 and int8 as date type for weight")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -71,3 +72,6 @@ class BigdlNativeForCausalLM:
 | 
			
		|||
        elif model_family == 'starcoder':
 | 
			
		||||
            from bigdl.llm.ggml.model.starcoder import Starcoder
 | 
			
		||||
            return Starcoder(model_path=ggml_model_path, **kwargs)
 | 
			
		||||
        elif model_family == 'chatglm':
 | 
			
		||||
            from bigdl.llm.ggml.model.chatglm import ChatGLM
 | 
			
		||||
            return ChatGLM(model_path=ggml_model_path, **kwargs)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -261,6 +261,7 @@ class BaseConverter:
 | 
			
		|||
            cls.dump_model(f, model, ggml_type)
 | 
			
		||||
 | 
			
		||||
        print(f"{cls.MODEL_TYPE.name} GGML model saved to {save_path}")
 | 
			
		||||
        return save_path
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatGLMConverter(BaseConverter):
 | 
			
		||||
| 
						 | 
				
			
			@ -397,9 +398,9 @@ def _convert_chatglm_hf_to_ggml_(model_path, outfile_dir, outtype):
 | 
			
		|||
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
    if hasattr(model.config, "multi_query_attention"):
 | 
			
		||||
        ChatGLM2Converter.convert(model, tokenizer, ggml_type, outfile_dir)
 | 
			
		||||
        return ChatGLM2Converter.convert(model, tokenizer, ggml_type, outfile_dir)
 | 
			
		||||
    else:
 | 
			
		||||
        ChatGLMConverter.convert(model, tokenizer, ggml_type, outfile_dir)
 | 
			
		||||
        return ChatGLMConverter.convert(model, tokenizer, ggml_type, outfile_dir)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1596,6 +1596,6 @@ def _convert_chatglm_hf_to_ggml(model_path, outfile_dir, outtype):
 | 
			
		|||
                      "For now we only support quantization type 'q4_0' and 'q4_1' "
 | 
			
		||||
                      "in chatglm family.")
 | 
			
		||||
    from bigdl.llm.utils.convert_chatglm import _convert_chatglm_hf_to_ggml_
 | 
			
		||||
    _convert_chatglm_hf_to_ggml_(model_path,
 | 
			
		||||
                                 outfile,
 | 
			
		||||
                                 outtype)
 | 
			
		||||
    return _convert_chatglm_hf_to_ggml_(model_path,
 | 
			
		||||
                                        outfile,
 | 
			
		||||
                                        outtype)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue