#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is adapted from
# https://github.com/huggingface/transformers/pull/27557

import math

import torch
from transformers.generation.logits_process import LogitsProcessor
from transformers.utils import add_start_docstrings

LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
            using beam search, or log softmax for each vocabulary token when using beam search.

    Return:
        `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.

"""


class GrammarConstrainedLogitsProcessor(LogitsProcessor):
    def __init__(self, grammar_constraint):
        # length of input_ids seen at the previous call; None until the first token is processed
        self.last_size = None
        self.grammar_constraint = grammar_constraint
        # one parser stack per sequence in the batch (batch x beam), created lazily on the first call
        self.batch_stacks = None

    def filter_logits(self, logits, device):
        # resolve each stack to a tensor of True/False for each token
        # indicating acceptance
        # acceptance = self.grammar_acceptor.filter_vocab(self.stacks, device)
        acceptance = self.grammar_constraint.batch_filter_vocab(self.batch_stacks, device)
        # logger.debug(acceptance)
        # Logits to -inf where False
        logits[~acceptance] = -math.inf
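        # Worked example of the masking above: if acceptance were tensor([[True, False, True]])
        # and logits were tensor([[0.2, 1.5, -0.3]]), the logits become
        # tensor([[0.2, -inf, -0.3]]), so only grammar-accepted tokens can be chosen at this step.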

    # TODO: batching
    def process_logits(self, input_ids, scores, parse_start_index=None):
        """
        :param input_ids: indices of the input sequence tokens in the vocabulary, one sequence per row
        :param scores: prediction scores of the language modeling head for the current generation step
        :param parse_start_index: default None, which means generate from scratch. Set to 0 to parse all input_ids
        :return: the processed prediction scores, with grammar-rejected tokens masked to -inf
        """
        # we dynamically create stacks at the first call, so that we know the batch size and beam size
        if self.batch_stacks is None:
            self.batch_stacks = [self.grammar_constraint.init_stacks() for _ in range(len(input_ids))]

        # if self.last_size is not set (which would be the case when processing the first token),
        # parse the prefix selected by parse_start_index; with the default of None the prefix is
        # empty, so the parser stacks are left untouched.
        if self.last_size is None:
            prefix_to_parse = [
                single_input_ids[parse_start_index:] if parse_start_index is not None else []
                for single_input_ids in input_ids
            ]
            # self.grammar_acceptor.accept_token_ids(prefix_to_parse, self.stacks)
            self.batch_stacks = [
                self.grammar_constraint.accept_token_ids(prefix, stack)
                for prefix, stack in zip(prefix_to_parse, self.batch_stacks)
            ]
        # if the length of the current input IDs (input_ids[0]) is exactly one more than
        # self.last_size, consume the newly generated token. This is the expected case when
        # inputs are processed incrementally, one token at a time.
        elif len(input_ids[0]) == self.last_size + 1:
            # self.stacks = self.grammar_acceptor.accept_token_id(input_ids[0][-1], self.stacks)
            self.batch_stacks = [
                self.grammar_constraint.accept_token_id(single_input_ids[-1], stack)
                for single_input_ids, stack in zip(input_ids, self.batch_stacks)
            ]
        # ensure that the input size is consistent with the expected incremental processing
        # (i.e., one token at a time).
        else:
            # Here we only check that input_ids are one token longer than the last time we
            # processed them; we don't check that the input_ids themselves are valid.
            # Imagine a scenario where we generate 10 tokens and then replace them with 10
            # different tokens: the length is still consistent with last_size, but the
            # input_ids are not valid. Checking validity would require reparsing the whole
            # input_ids at each call, which is not efficient, so we trust the caller to
            # provide valid input_ids; if they are not, the generation result will simply
            # be wrong.
            raise RuntimeError(
                "Input IDs' length is inconsistent with the current state of "
                "the GrammarConstrainedLogitsProcessor. If you want to process "
                "another input sequence, please instantiate a new "
                "GrammarConstrainedLogitsProcessor."
            )

        self.filter_logits(scores, scores.device)

        self.last_size = len(input_ids[0])
        return scores

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return self.process_logits(input_ids, scores)
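
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the adapted PR code): how this
# processor could be wired into `model.generate`. It assumes a grammar
# constraint object exposing the interface used above (init_stacks,
# accept_token_id, accept_token_ids, batch_filter_vocab), e.g. the
# IncrementalGrammarConstraint from the referenced PR; the model name and
# grammar text below are placeholders.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
#
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#
#     # grammar_constraint = IncrementalGrammarConstraint(grammar_text, "root", tokenizer)
#     grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint)
#
#     inputs = tokenizer("Generate a JSON object:", return_tensors="pt")
#     output_ids = model.generate(
#         **inputs,
#         max_new_tokens=64,
#         logits_processor=LogitsProcessorList([grammar_processor]),
#     )
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
#
# Because the processor tracks the length of the sequence it has already seen
# (self.last_size), instantiate a fresh GrammarConstrainedLogitsProcessor for
# every new generate() call.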