LLM: fix bugs in Bloom support for LangChain (#8362)
This commit is contained in:
parent 066f53232d
commit 5f4f399ca7
1 changed file with 66 additions and 52 deletions
@@ -46,7 +46,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
-from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval
+from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval, bloom_embed
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
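Note: the added bloom_embed import backs the reworked embed() method in the last hunk of this commit. A minimal usage sketch, assuming the bigdl.llm.models import path and a placeholder model file (both are illustrative, not confirmed by this diff):

    from bigdl.llm.models import Bloom  # assumed import path

    # placeholder path; embedding=True is now required before calling embed()
    llm = Bloom(model_path="./bloom-ggml.bin", embedding=True)
    vec = llm.embed("hello world")  # returns List[float]
    print(len(vec))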
@@ -127,7 +127,7 @@ class Bloom(GenerationMixin):
         # TODO: Some parameters are temporarily not supported
         unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True, 'logits_all': False,
                            'vocab_only': False, 'use_mmap': True, 'use_mlock': False,
-                           'embedding': False, 'last_n_tokens_size': 64, 'lora_base': None,
+                           'last_n_tokens_size': 64, 'lora_base': None,
                            'lora_path': None, 'verbose': True}
         for arg in unsupported_arg.keys():
             invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
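Note: dropping 'embedding' from this dict is what actually unlocks embedding support, because the constructor rejects any listed parameter whose value differs from its pinned default. A standalone sketch of the same dict-driven guard pattern (names are illustrative, not the library's API):

    # Illustrative re-creation of the constructor's argument guard
    def check_pinned_defaults(obj, pinned: dict):
        for arg, default in pinned.items():
            if getattr(obj, arg) != default:
                raise ValueError(f"The parameter {arg} is temporarily unsupported, "
                                 "please use the default value.")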
@@ -156,44 +156,50 @@ class Bloom(GenerationMixin):
     ):
-        # TODO: Some parameters are temporarily not supported
-        return self._supported_call(prompt, max_tokens, stream, stop,
-                                    suffix, temperature, top_p, logprobs, echo, frequency_penalty,
+        # Unsupported parameters are checked in `_supported_call`
+        return self._supported_call(prompt, max_tokens, stream, stop, echo, model,
+                                    suffix, temperature, top_p, logprobs, frequency_penalty,
                                     presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode,
-                                    mirostat_tau, mirostat_eta, model)
+                                    mirostat_tau, mirostat_eta)
 
     def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False,
-                        stop: Optional[List[str]] = [], *args):
+                        stop: Optional[List[str]] = [], echo: bool = False,
+                        model: Optional[str] = None, *args):
         # Check unsupporeted parameters
-        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs', 'echo',
+        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs',
                            'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k',
                            'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model']
         defult_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None,
-                        'echo': False, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                        'frequency_penalty': 0.0, 'presence_penalty': 0.0,
                         'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
-                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1, 'model': None}
+                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
         for index in range(len(args)):
             invalidInputError(args[index] == defult_value[unsupported_arg[index]],
                               f"The parameter {unsupported_arg[index]} is temporarily "
                               "unsupported, please use the default value.")
 
         if stream:
-            return self.stream(prompt, max_tokens, stop)
+            return self.stream(prompt, max_tokens, stop, echo, model)
         else:
-            return self._eval(prompt, max_tokens, False, stop)
+            return self._eval(prompt, max_tokens, False, stop, echo, model)
 
     def _eval(self, prompt: str, max_tokens: int, match_str: bool,
-              stop: Optional[List[str]] = []):
+              stop: Optional[List[str]] = [], echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_len = len(self.tokenize(prompt))
         if prompt.endswith("</s>") or max_tokens < 1:
             return {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
@@ -201,15 +207,14 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": 0,
+                    "total_tokens": prompt_len,
                 }
             }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
-        buf = bytes((len(prompt) + max_tokens) * 20)
+        buf = bytes((prompt_len + max_tokens) * 20)
         ret = bloom_run(ctx=self.ctx,
                         seed=self.seed,
                         n_threads=self.n_threads,
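Note: the early-return path now reports real usage numbers: the prompt is tokenized once into prompt_len, completion_tokens is 0 because nothing was generated, and total_tokens equals prompt_len. The buf sizing also switches from character count to token count. A small sketch of the accounting, with tokenize as a stand-in for self.tokenize:

    def early_return_usage(tokenize, prompt: str) -> dict:
        prompt_len = len(tokenize(prompt))  # stand-in for self.tokenize(prompt)
        return {"prompt_tokens": prompt_len,
                "completion_tokens": 0,
                "total_tokens": prompt_len}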
@@ -229,13 +234,14 @@ class Bloom(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
+        completion_len = len(self.tokenize(split_text))
         return {"id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt + split_text,
+                        "text": prompt + split_text if echo else split_text,
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -243,25 +249,31 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": completion_len,
+                    "total_tokens": prompt_len + completion_len,
                 }
                 }
 
-    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = [],
+               echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8"))
+        prompt_len = len(prompt_tokens)
         if prompt.endswith("</s>") or max_tokens < 1:
             yield {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
@@ -269,35 +281,35 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                     {
-                        # TODO: need tokenize
-                        "prompt_tokens": None
+                        "prompt_tokens": prompt_len
                 }
             }
         else:
             for i in range(max_tokens):
-                if prompt.endswith("</s>"):
+                token = self.forward(prompt_tokens)
+                prompt_tokens.append(token)
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                if text.endswith("</s>"):
                     print('\n')
                     break
-                else:
-                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": self.model_path,
-                        "choices": [
-                            {
-                                "text": prompt,
-                                "index": 0,
-                                "logprobs": None,
-                                "finish_reason": None,
-                            }
-                        ],
-                        "usage":
-                            {
-                                # TODO: need tokenize
-                                "prompt_tokens": None
+                yield {
+                    "id": completion_id,
+                    "object": "text_completion",
+                    "created": created,
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "text": text,
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                    "usage":
+                        {
+                            "prompt_tokens": prompt_len
                     }
                 }
 
     def free(self):
         bloom_free(self.ctx)
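Note: the old loop re-entered _eval once per step, regrowing and re-tokenizing the prompt each time; the new loop keeps prompt_tokens, asks forward() for one token, appends it to the context, and detokenizes just that token for the chunk text. A reduced sketch of the pattern, with forward and detokenize standing in for the model's methods:

    # Illustrative single-token decode loop
    def decode_stream(forward, detokenize, prompt_tokens, max_tokens):
        for _ in range(max_tokens):
            token = forward(prompt_tokens)      # predict the next token id
            prompt_tokens.append(token)         # extend the decoding context
            text = detokenize([token]).decode("utf-8", errors="ignore")
            if text.endswith("</s>"):           # end-of-sequence marker
                break
            yield text                          # emit one decoded chunk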
@@ -403,9 +415,11 @@ class Bloom(GenerationMixin):
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
 
-    def embed(self, prompt: Union[str, bytes]) -> List[float]:
+    def embed(self, input: str) -> List[float]:
         """Only used for langchain"""
-        input_ids = self.tokenize(prompt)
+        invalidInputError(self.embedding,
+                          "Bloom model must be created with embedding=True to call this method.")
+        input_ids = self.tokenize(input)
         return bloom_embed(ctx=self.ctx,
                            input_ids=input_ids,
                            seed=self.seed,
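Note: embed() now takes a plain str named input and fails fast unless the model was built with embedding=True, which is the shape LangChain's embedding interface expects. A hedged adapter sketch; the Embeddings import path is assumed from LangChain releases of this era:

    from typing import List
    from langchain.embeddings.base import Embeddings  # assumed import path

    class BloomEmbeddings(Embeddings):
        """Wraps a Bloom created with embedding=True for LangChain."""

        def __init__(self, llm):
            self.llm = llm

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return [self.llm.embed(t) for t in texts]

        def embed_query(self, text: str) -> List[float]:
            return self.llm.embed(text)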