LLM: Support generate(max_new_tokens=...), tokenize and decode for transformers-like API (#8283)

* first push
* fix pep8

parent 11cd2a07e0
commit 2d14e593f0

3 changed files with 53 additions and 22 deletions
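For context, a minimal end-to-end sketch of the transformers-like API this commit adds. The import path is an assumption (it does not appear in this diff); the model file, the model_family value, and the tokenize/generate/decode calls follow the updated docstring example below.

    # Import path is assumed; only the class name AutoModelForCausalLM appears in the diff.
    from bigdl.llm.ggml.transformers import AutoModelForCausalLM

    llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
                                               model_family="llama")
    tokens = llm.tokenize("Q: Tell me something about Intel. A:")  # str -> List[int]
    tokens_id = llm.generate(tokens, max_new_tokens=32)            # at most 32 new token ids
    print(llm.decode(tokens_id))                                   # List[int] -> str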
@@ -31,6 +31,30 @@ class GenerationMixin:
     Pass custom parameter values to 'generate' .
     """
 
+    def tokenize(self, text: str, add_bos: bool = True) -> List[int]:
+        '''
+        Tokenize the text into a list of token ids.
+
+        :param text: The text to be tokenized
+        :param add_bos: Whether to prepend the BOS (beginning-of-sequence) token
+
+        :return: list of token ids
+        '''
+        if isinstance(text, str):
+            bstr = text.encode()
+        else:
+            bstr = text
+        return self._tokenize(bstr, add_bos)
+
+    def decode(self, tokens: List[int]) -> str:
+        '''
+        Decode the token ids back into words.
+
+        :param tokens: list of token ids, mostly generated by generate
+        :return: decoded string
+        '''
+        return self.detokenize(tokens).decode()
+
     def generate(
         self,
         inputs: Union[Optional[Sequence[int]], Sequence[gptneox_cpp.gptneox_token]]=None,
@@ -46,18 +70,18 @@ class GenerationMixin:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        stop: Optional[Union[str, List[str]]]=[],
+        stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
     ) -> Union[Optional[Sequence[int]], Optional[Sequence[gptneox_cpp.gptneox_token]], None]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.
 
         Examples:
-            >>> llama = Llama("models/ggml-7b.bin")
-            >>> tokens = llama.tokenize(b"Hello, world!")
-            >>> for token in llama.generate(tokens, top_k=40, top_p=0.95,
-            >>>                             temp=1.0, repeat_penalty=1.1):
-            ...     print(llama.detokenize([token]))
+            >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
+            ...                                            model_family="llama")
+            >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
+            >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
+            >>> llm.decode(tokens_id)
 
         Args:
             tokens: The prompt tokens.
@@ -70,8 +94,7 @@ class GenerationMixin:
         Yields:
             The generated tokens.
         """
-        # TODO: stop & max_token
-        self._generate(tokens=inputs,
+        tokens = self._generate(tokens=inputs,
                                 top_k=top_k,
                                 top_p=top_p,
                                 temp=temperature,
@@ -84,3 +107,11 @@ class GenerationMixin:
                                 mirostat_tau=mirostat_tau,
                                 mirostat_eta=mirostat_eta,
                                 **kwargs)
+        res_list = []
+        word_count = 0
+        for token in tokens:
+            if word_count >= max_new_tokens:
+                break
+            res_list.append(token)
+            word_count += 1
+        return res_list
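The loop added above simply collects at most max_new_tokens items from the lazy token stream returned by self._generate. A standalone sketch of that truncation, with a hypothetical dummy generator standing in for the real sampler:

    from itertools import islice
    from typing import Iterator, List

    def dummy_generate() -> Iterator[int]:
        # Hypothetical stand-in for self._generate: an unbounded lazy stream of token ids.
        i = 0
        while True:
            yield i
            i += 1

    max_new_tokens = 32
    # islice stops after max_new_tokens items, matching the word_count loop in the diff.
    res_list: List[int] = list(islice(dummy_generate(), max_new_tokens))
    assert len(res_list) == max_new_tokens

islice expresses the same cutoff as the word_count loop and stops pulling from the generator as soon as the budget is reached.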
@@ -225,7 +225,7 @@ class Gptneox:
         if self.verbose:
             print(gptneox_cpp.gptneox_print_system_info().decode("utf-8"), file=sys.stderr)
 
-    def tokenize(
+    def _tokenize(
         self, text: bytes, add_bos: bool = True
     ) -> List[gptneox_cpp.gptneox_token]:
         """Tokenize a string.
@@ -250,7 +250,7 @@ class Llama(GenerationMixin):
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
-    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
+    def _tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
         Args:
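The two renames above make the bytes-level tokenizers private, so GenerationMixin.tokenize, which accepts str and handles the str-to-bytes conversion, becomes the single public entry point. A minimal sketch of that layering; DummyBackend and its one-id-per-byte scheme are illustrative only:

    from typing import List

    class GenerationMixin:
        def tokenize(self, text: str, add_bos: bool = True) -> List[int]:
            # Public API: accept str, encode to bytes, delegate to the backend.
            bstr = text.encode() if isinstance(text, str) else text
            return self._tokenize(bstr, add_bos)

    class DummyBackend(GenerationMixin):
        def _tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
            # Stand-in for the real ggml tokenizer: one id per byte, 1 as a fake BOS id.
            ids = list(text)
            return [1] + ids if add_bos else ids

    print(DummyBackend().tokenize("hi"))  # [1, 104, 105]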