LLM: add stop words and enhance output for bloom pybinding (#8280)
This commit is contained in:

parent 6990328e5c
commit f9e2bda04a

1 changed file with 105 additions and 10 deletions
@@ -47,6 +47,9 @@
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
 from bigdl.llm.utils.common import invalidInputError
+from typing import List, Optional
+import time
+import uuid
 
 
 class Bloom:
@@ -81,6 +84,7 @@ class Bloom:
         Returns:
             A Bloom instance.
         """
+        self.model_path = model_path
         self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
         invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
         self.n_ctx = n_ctx
@@ -91,15 +95,39 @@ class Bloom:
         self.last_n_tokens_size = last_n_tokens_size
         self.verbose = verbose
 
-    def __call__(self, prompt: str, max_tokens: int, stream: bool = False):
+    def __call__(self, prompt: str, max_tokens: int, stream: bool = False,
+                 stop: Optional[List[str]] = []):
         if stream:
-            return self.stream(prompt, max_tokens)
+            return self.stream(prompt, max_tokens, stop)
         else:
-            return self._eval(prompt, max_tokens, False)
+            return self._eval(prompt, max_tokens, False, stop)
 
-    def _eval(self, prompt: str, max_tokens: int, match_str: bool):
+    def _eval(self, prompt: str, max_tokens: int, match_str: bool,
+              stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            return prompt
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+            }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
         buf = bytes((len(prompt) + max_tokens) * 20)
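Taken together with the hunk below, this hunk changes the binding's return type from a bare string to an OpenAI-style completion dict, with an optional stop list threaded from __call__ into _eval and stream. A minimal sketch of the new non-streaming call, assuming the patched Bloom class is importable; the model path, prompt, and reliance on constructor defaults for n_ctx/n_threads are placeholders, not part of this commit:

# Hypothetical model path; n_ctx / n_threads left at their defaults.
llm = Bloom("./bloom-7b1-ggml.bin")
result = llm("Q: What is the capital of France?\nA:",
             max_tokens=32,
             stop=["\nQ:"])  # generation is cut at the first stop word
print(result["choices"][0]["text"])           # prompt + trimmed completion
print(result["choices"][0]["finish_reason"])  # "stop" if trimmed, else None

Note that choices[0]["text"] still echoes the prompt: _eval returns prompt + split_text, preserving the old string-returning behavior inside the new envelope.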
@@ -112,18 +140,85 @@ class Bloom:
                         prompt=bytes(prompt, encoding='utf-8'),
                         buf=buf)
         s = str(buf, encoding='utf-8').rstrip("\x00")
-        return s
-
-    def stream(self, prompt: str, max_tokens: int):
+        text = s.split(prompt)[1]
+        split_text = text
+        if stop != []:
+            for stop_word in stop:
+                split_text = split_text.split(stop_word)[0]
+        if split_text != text:
+            finish_reason = "stop"
+        else:
+            finish_reason = None
+        return {"id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt + split_text,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+                }
+
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            yield prompt
+            yield {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                    {
+                        # TODO: need tokenize
+                        "prompt_tokens": None
+                }
+            }
         else:
             for i in range(max_tokens):
                 if prompt.endswith("</s>"):
                     break
                 else:
-                    prompt = self._eval(prompt, 1, i != 0)
-                    yield prompt
+                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": self.model_path,
+                        "choices": [
+                            {
+                                "text": prompt,
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                        "usage":
+                            {
+                                # TODO: need tokenize
+                                "prompt_tokens": None
+                        }
+                    }
 
     def free(self):
         bloom_free(self.ctx)
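The trimming added to _eval is plain string splitting: the continuation is cut at the first occurrence of each stop word, and finish_reason records whether any cut happened. A self-contained sketch of just that loop (the helper name is ours, for illustration):

from typing import List, Optional

def truncate_at_stop(text: str, stop: Optional[List[str]] = None):
    # Keep only what precedes the first occurrence of each stop word,
    # mirroring the loop in _eval above.
    split_text = text
    for stop_word in (stop or []):
        split_text = split_text.split(stop_word)[0]
    finish_reason = "stop" if split_text != text else None
    return split_text, finish_reason

print(truncate_at_stop(" Paris.\nQ: And of Spain?", ["\nQ:"]))
# -> (' Paris.', 'stop')

One hazard worth flagging: the new signatures use a mutable default (stop: Optional[List[str]] = []). The lists are only read, never mutated, so it is benign here, but stop=None would be the more idiomatic default.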
				
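stream now yields one completion dict per generated token rather than a plain string. A consumption sketch under the same placeholder assumptions as above; note that in this commit each chunk's "text" field carries the full prompt-plus-generation so far, not a delta:

for chunk in llm("Q: What is the capital of France?\nA:",
                 max_tokens=32, stream=True, stop=["\nQ:"]):
    # Each chunk mirrors the non-streaming dict: id, object, created,
    # model, choices, and a partially filled usage block.
    print(chunk["choices"][0]["text"])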