Fix lookahead sample error & add update strategy (#10894)
* Fix sample error & add update strategy
* add mtl config
* fix style
* remove print
parent 94b4e96fa6
commit 015d07a58f

1 changed file with 24 additions and 16 deletions
@@ -29,6 +29,7 @@ from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteria
 from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
     _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.utils import get_xpu_device_type
 
 logger = logging.getLogger("ipex_llm.lookup")
 
				
			
@@ -119,10 +120,16 @@ class PromptLookupCandidateGenerator():
         self,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
+        device: str = "arc",
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
 
+        if device == "mtl":
+            self.max_candidates = 3
+        else:
+            self.max_candidates = 9
+
         invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
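A note on the new "mtl" config: the constructor now caps how far num_output_tokens may later grow, based on the device string. A self-contained mirror of that cap selection, only to spell it out (the helper below is ours for illustration, not part of the patch):

def pick_max_candidates(device: str) -> int:
    # Smaller candidate cap on MTL iGPUs, larger cap elsewhere ("arc" is the default string).
    return 3 if device == "mtl" else 9

assert pick_max_candidates("mtl") == 3
assert pick_max_candidates("arc") == 9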
				
			
@@ -183,25 +190,18 @@ class PromptLookupCandidateGenerator():
         # so returning None
         return candidate_input_ids, None
 
-    def update_candidate_strategy(self, input_ids: torch.LongTensor,
-                                  scores: torch.FloatTensor, num_matches: int):
+    def update_candidate_strategy(self, candidate_num: int, num_matches: int):
         """
         Updates the candidate generation strategy based on the outcomes.
 
         Args:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary.
-                [What are input IDs?](../glossary#input-ids)
-            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
-                config.vocab_size)`):
-                Prediction scores of a language modeling head. These can be logits for each
-                vocabulary when not using beam search or log softmax for each vocabulary
-                token when using beam search
             num_matches (`int`):
                 The number of matches between the candidate sequences and the model predictions.
         """
-        # Currently does nothing
-        return
+        if num_matches == self.num_output_tokens:
+            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
+        elif candidate_num > num_matches:
+            self.num_output_tokens = max(self.num_output_tokens - 1, 1)
 
 
 @torch.no_grad()
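To see what the previously no-op update_candidate_strategy now does, here is a throwaway simulation of the same rule outside the class (names mirror the patch; the acceptance numbers are made up): full acceptance grows the draft length toward max_candidates, while any rejection shrinks it toward 1.

class _Strategy:
    def __init__(self, num_output_tokens=5, max_candidates=9):
        self.num_output_tokens = num_output_tokens
        self.max_candidates = max_candidates

    def update_candidate_strategy(self, candidate_num: int, num_matches: int):
        if num_matches == self.num_output_tokens:
            # every drafted token was accepted -> draft more next step, up to the cap
            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
        elif candidate_num > num_matches:
            # some drafted tokens were rejected -> draft fewer next step, floor of 1
            self.num_output_tokens = max(self.num_output_tokens - 1, 1)

s = _Strategy()
for candidate_num, num_matches in [(5, 5), (6, 6), (7, 2), (6, 6)]:
    s.update_candidate_strategy(candidate_num, num_matches)
    print(s.num_output_tokens)   # prints 6, 7, 6, 7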
				
			
@@ -217,9 +217,12 @@ def lookup_generate(self,
        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                              **sampling_kwargs)
 
+    device_name = get_xpu_device_type(input_ids)
+
     candidates_generator = PromptLookupCandidateGenerator(
         num_output_tokens=num_output_tokens,
-        max_matching_ngram_size=max_matching_ngram_size)
+        max_matching_ngram_size=max_matching_ngram_size,
+        device=device_name)
 
     step = 0
     step_verify = 0
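The wiring above relies on get_xpu_device_type mapping the tensor's XPU device to a short name that the generator compares against "mtl". A rough usage sketch (the concrete return strings and the .to("xpu") move are assumptions, not verified against this diff):

import torch
from ipex_llm.transformers.utils import get_xpu_device_type

input_ids = torch.tensor([[1, 2, 3, 4]]).to("xpu")
device_name = get_xpu_device_type(input_ids)  # e.g. "mtl" on Meteor Lake, "arc" otherwise (assumed)
print(device_name)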
				
			
@@ -291,6 +294,7 @@ def lookup_generate(self,
                                                         top_k=generation_config.top_k,
                                                         top_p=generation_config.top_p,
                                                         temperature=generation_config.temperature)
+                output_ids = output_ids.transpose(0, 1)
             else:
                 output_ids = greedy(logits)
 
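This is the "sample error" part of the fix: the sampled ids apparently come back with the step dimension first, so without the transpose the later slicing output_ids[:, :-1] would cut along the wrong axis. A tiny shape sketch (the (steps, batch) layout of the sampler output is an assumption inferred from this fix, not checked against deepmind_sample itself):

import torch

sampled = torch.tensor([[101], [102], [103], [104]])   # assumed sampler output: 4 steps, batch of 1
output_ids = sampled.transpose(0, 1)                   # -> shape (1, 4), matching the greedy path
print(output_ids.shape)                                # torch.Size([1, 4])
print(output_ids[:, :-1])                              # now drops the last token, not a batch entry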
				
			
@@ -303,13 +307,14 @@ def lookup_generate(self,
             # Drafts start from [1, k]
             # Verified output start from [0, k - 1]
             # including the one generated by the base model
-            max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
-            max_matched = max_matched.sum(-1).item() + 1
+            n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
+                         .cumsum(-1) == 0).sum(-1).item()
+            max_matched = n_matches + 1
 
             max_of_max_matched = output_ids.size(1)
             # Accept number is max_matched, min is 1
             self.accept_num.append(max_matched)
-            self.n_matched += max_matched - 1
+            self.n_matched += n_matches
             self.n_drafted += candidate_length
 
             # Clean up target model KV cache
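The reworked bookkeeping counts the leading run of draft tokens that agree with the verified output: cumsum(...) == 0 stays True only until the first mismatch. A small worked example with dummy token ids (not from the patch):

import torch

verify_input_ids = torch.tensor([[ 7, 11, 12, 13, 99]])   # previous token + 4 drafted tokens
output_ids       = torch.tensor([[11, 12, 13, 42, 55]])   # what the model produced for those positions

mismatch = output_ids[:, :-1] != verify_input_ids[:, 1:]  # [False, False, False, True]
n_matches = (mismatch.cumsum(-1) == 0).sum(-1).item()     # leading matches -> 3
max_matched = n_matches + 1                                # plus the bonus token from the base model
print(n_matches, max_matched)                              # 3 4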
				
			
@@ -319,6 +324,9 @@ def lookup_generate(self,
                 past_key_values = _crop_past_key_values(self, past_key_values,
                                                         new_cache_size)
 
+            # Update the candidate generation strategy if needed
+            candidates_generator.update_candidate_strategy(candidate_length, n_matches)
+
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
 
     step += output_ids.size(1)