LLM: add stop words and enhance output for bloom pybinding (#8280)

parent 6990328e5c
commit f9e2bda04a

1 changed file with 105 additions and 10 deletions
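For orientation, a minimal usage sketch of the pybinding after this change. The import path and model file below are assumptions for illustration; the `stop` parameter and the OpenAI-style completion dict do match what the diff adds.

# Usage sketch; the import path and model file are hypothetical.
from bigdl.llm.models import Bloom  # assumed location of the Bloom class

llm = Bloom("/path/to/bloomz-7b1-ggml.bin")  # hypothetical ggml model file
result = llm("Q: What is BigDL?\nA:", max_tokens=64, stop=["\nQ:"])
print(result["choices"][0]["text"])           # prompt + completion, cut at the first stop word
print(result["choices"][0]["finish_reason"])  # "stop" if a stop word was hit, else None

# Streaming: each chunk is a completion dict whose "text" is the text so far.
for chunk in llm("Once upon a time", max_tokens=32, stream=True, stop=["."]):
    print(chunk["choices"][0]["text"])

llm.free()  # releases the native context created by bloom_load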
				
			
@@ -47,6 +47,9 @@
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
 from bigdl.llm.utils.common import invalidInputError
+from typing import List, Optional
+import time
+import uuid
 
 
 class Bloom:
@@ -81,6 +84,7 @@ class Bloom:
         Returns:
             A Bloom instance.
         """
         self.model_path = model_path
         self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
+        invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
         self.n_ctx = n_ctx
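The added invalidInputError call makes the constructor fail fast when the native loader returns a null context, rather than crashing later inside bloom_run. A rough sketch of the assumed semantics (condition first, raise when it does not hold; the exact exception type is not visible in this diff):

# Sketch of the assumed fail-fast helper, not BigDL's actual implementation.
def invalid_input_error(condition: bool, err_msg: str) -> None:
    if not condition:  # assumed: raise when the check fails
        raise RuntimeError(err_msg)

ctx = None  # stand-in for bloom_load returning NULL on a bad path
invalid_input_error(ctx is not None, "Failed to load model from /bad/path")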
@@ -91,15 +95,39 @@ class Bloom:
         self.last_n_tokens_size = last_n_tokens_size
         self.verbose = verbose
 
-    def __call__(self, prompt: str, max_tokens: int, stream: bool = False):
+    def __call__(self, prompt: str, max_tokens: int, stream: bool = False,
+                 stop: Optional[List[str]] = []):
         if stream:
-            return self.stream(prompt, max_tokens)
+            return self.stream(prompt, max_tokens, stop)
         else:
-            return self._eval(prompt, max_tokens, False)
+            return self._eval(prompt, max_tokens, False, stop)
 
-    def _eval(self, prompt: str, max_tokens: int, match_str: bool):
+    def _eval(self, prompt: str, max_tokens: int, match_str: bool,
+              stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            return prompt
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+            }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
         buf = bytes((len(prompt) + max_tokens) * 20)
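A note on the buffer line above: bytes(n) allocates n zero bytes, sized by the commit's own heuristic of at most ~20 bytes per word; the native call writes the prompt plus generation into it, and the unused trailing NULs are stripped afterwards. A small sketch of that round trip with the C call stubbed out (a bytearray stands in for the buffer, since pure Python cannot mutate a bytes object the way the C side does through a pointer):

# Sketch of the output-buffer round trip; bloom_run is stubbed out.
prompt, max_tokens = "Hello", 8
buf = bytearray((len(prompt) + max_tokens) * 20)  # zero-filled, heuristic worst case
buf[:11] = b"Hello world"                         # stand-in for bloom_run writing into buf
s = str(bytes(buf), encoding='utf-8').rstrip("\x00")
print(repr(s))                                    # 'Hello world' -- trailing NULs dropped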
@@ -112,18 +140,85 @@ class Bloom:
                         prompt=bytes(prompt, encoding='utf-8'),
                         buf=buf)
         s = str(buf, encoding='utf-8').rstrip("\x00")
-        return s
+        text = s.split(prompt)[1]
+        split_text = text
+        if stop != []:
+            for stop_word in stop:
+                split_text = split_text.split(stop_word)[0]
+        if split_text != text:
+            finish_reason = "stop"
+        else:
+            finish_reason = None
+        return {"id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt + split_text,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+                }
 
-    def stream(self, prompt: str, max_tokens: int):
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            yield prompt
+            yield {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None
+                }
+            }
         else:
             for i in range(max_tokens):
                 if prompt.endswith("</s>"):
                     break
                 else:
-                    prompt = self._eval(prompt, 1, i != 0)
-                    yield prompt
+                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": self.model_path,
+                        "choices": [
+                            {
+                                "text": prompt,
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                        "usage":
+                        {
+                            # TODO: need tokenize
+                            "prompt_tokens": None
+                        }
+                    }
 
     def free(self):
         bloom_free(self.ctx)
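The core of the stop-word handling added in _eval is a repeated split: the generated text is cut at the first occurrence of each stop string, and finish_reason records whether anything was cut. A standalone sketch of that logic (the helper name is ours, not the commit's):

from typing import List, Optional, Tuple

def truncate_at_stop(text: str, stop: Optional[List[str]] = None) -> Tuple[str, Optional[str]]:
    # Cut at every stop word in turn, mirroring the commit's split-based approach.
    split_text = text
    for stop_word in (stop or []):
        split_text = split_text.split(stop_word)[0]
    finish_reason = "stop" if split_text != text else None
    return split_text, finish_reason

print(truncate_at_stop(" Paris.\nQ: And Germany?", stop=["\nQ:"]))
# -> (' Paris.', 'stop')

Note that when no stop word fires, the commit reports finish_reason as None rather than "length", even when max_tokens is exhausted; "length" is used only by the early-return branch for an already-terminated prompt.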