[LLM] llm transformers api support batch actions (#8288)

* llm transformers api support batch actions
* align with transformers
* address review comments
parent ea3cf6783e
commit 637b72f2ad

1 changed file with 88 additions and 31 deletions
@@ -30,33 +30,78 @@ class GenerationMixin:
 
     Pass custom parameter values to 'generate' .
     """
-    def tokenize(self, text: str, add_bos: bool = True) -> List[int]:
+    def tokenize(self,
+                 text: Union[str, List[str]],
+                 add_bos: bool = True) -> List[int]:
         '''
         Decode the id to words
 
-        :param text: The text to be tokenized
+        :param text: The text or batch of text to be tokenized
         :param add_bos:
 
         :return: list of ids that indicates the tokens
         '''
-        if isinstance(text, str):
-            bstr = text.encode()
-        else:
-            bstr = text
-        return self._tokenize(bstr, add_bos)
+        is_batched = True if isinstance(text, (list, tuple)) else False
+        if not is_batched:
+            text = [text]
+
+        result = []
+        for t in text:
+            if isinstance(t, str):
+                bstr = t.encode()
+            else:
+                bstr = t
+            result.append(self._tokenize(bstr, add_bos))
+
+        if not is_batched:
+            result = result[0]
+        return result
 
     def decode(self, tokens: List[int]) -> str:
         '''
         Decode the id to words
 
+        Examples:
+            >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
+                                                           model_family="llama")
+            >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
+            >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
+            >>> llm.decode(tokens_id[0])
+
         :param tokens: list of ids that indicates the tokens, mostly generated by generate
         :return: decoded string
         '''
         return self.detokenize(tokens).decode()
 
+    def batch_decode(self,
+               tokens: Union[List[int], List[List[int]]]) -> str:
+        '''
+        Decode the id to words
+
+        :param tokens: list or a batch of list of ids that indicates the tokens,
+                mostly generated by generate
+        :return: decoded string
+        '''
+        is_batched = False
+        if tokens is not None and len(tokens) > 0:
+            if isinstance(tokens[0], Sequence):
+                is_batched = True
+            else:
+                tokens = [tokens]
+        else:
+            return None
+
+        results = []
+        for t in tokens:
+            results.append(self.decode(t))
+        if not is_batched:
+            results = results[0]
+        return results
+
     def generate(
         self,
-        inputs: Optional[Sequence[int]]=None,
+        inputs: Union[Optional[Sequence[int]],
+                      Sequence[Sequence[int]]]=None,
         max_new_tokens: int = 128,
         top_k: int = 40,
         top_p: float = 0.95,
@@ -71,7 +116,9 @@ class GenerationMixin:
         mirostat_eta: float = 0.1,
         stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
-    ) -> Union[Optional[Sequence[int]], None]:
+    ) -> Union[Optional[Sequence[int]],
+               Sequence[Sequence[int]],
+               None]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.
 
@@ -80,7 +127,7 @@ class GenerationMixin:
                                                            model_family="llama")
             >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
             >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
-            >>> llm.decode(tokens_id)
+            >>> llm.batch_decode(tokens_id)
 
         Args:
             tokens: The prompt tokens.
@@ -93,24 +140,34 @@ class GenerationMixin:
         Yields:
             The generated tokens.
         """
-        tokens = self._generate(tokens=inputs,
-                                top_k=top_k,
-                                top_p=top_p,
-                                temp=temperature,
-                                repeat_penalty=repetition_penalty,
-                                reset=reset,
-                                frequency_penalty=frequency_penalty,
-                                presence_penalty=presence_penalty,
-                                tfs_z=tfs_z,
-                                mirostat_mode=mirostat_mode,
-                                mirostat_tau=mirostat_tau,
-                                mirostat_eta=mirostat_eta,
-                                **kwargs)
-        res_list = []
-        word_count = 0
-        for token in tokens:
-            if word_count > max_new_tokens:
-                break
-            res_list.append(token)
-            word_count += 1
-        return res_list
+        if inputs and len(inputs) > 0:
+            if not isinstance(inputs[0], Sequence):
+                inputs = [inputs]
+        else:
+            return None
+
+        results = []
+        for input in inputs:
+            tokens = self._generate(tokens=input,
+                                    top_k=top_k,
+                                    top_p=top_p,
+                                    temp=temperature,
+                                    repeat_penalty=repetition_penalty,
+                                    reset=reset,
+                                    frequency_penalty=frequency_penalty,
+                                    presence_penalty=presence_penalty,
+                                    tfs_z=tfs_z,
+                                    mirostat_mode=mirostat_mode,
+                                    mirostat_tau=mirostat_tau,
+                                    mirostat_eta=mirostat_eta,
+                                    **kwargs)
+            res_list = []
+            word_count = 0
+            for token in tokens:
+                if word_count > max_new_tokens:
+                    break
+                res_list.append(token)
+                word_count += 1
+            results.append(res_list)
+
+        return results
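Taken together, the changes let tokenize, generate, and batch_decode each accept either a single prompt or a batch. Below is a minimal usage sketch of the batched flow, assuming the model file and model_family from the docstring example; the import path is a placeholder guess, not verified against the package layout:

    # Sketch of the batch workflow added by this commit. The import path and
    # model file are assumptions taken from the docstring example.
    from bigdl.llm.transformers import AutoModelForCausalLM  # assumed import path

    llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
                                               model_family="llama")

    # tokenize() now accepts a str or a list/tuple of str; a batched input
    # returns one list of token ids per prompt.
    prompts = ["Q: Tell me something about Intel. A:",
               "Q: What is deep learning? A:"]
    tokens = llm.tokenize(prompts)

    # generate() wraps a single sequence into a batch internally and always
    # returns a list of output sequences, one per input prompt.
    tokens_id = llm.generate(tokens, max_new_tokens=32)

    # batch_decode() mirrors tokenize(): a batch of id lists yields a list of
    # strings, while a single flat id list yields one string.
    for answer in llm.batch_decode(tokens_id):
        print(answer)

The wrap-then-unwrap pattern keeps the single-prompt call signatures unchanged while one loop serves both cases; note that generate now always returns a list of sequences, which is why the docstring example switches from decode(tokens_id) to batch_decode(tokens_id).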