Update lookahead strategy (#11021)
* update lookahead strategy * remove lines * fix python style check
This commit is contained in:
		
							parent
							
								
									1d73fc8106
								
							
						
					
					
						commit
						93d40ab127
					
				
					 1 changed files with 19 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -24,6 +24,7 @@ from typing import Callable, List, Optional, Tuple
 | 
			
		|||
import torch
 | 
			
		||||
import time
 | 
			
		||||
import copy
 | 
			
		||||
import random
 | 
			
		||||
import logging
 | 
			
		||||
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 | 
			
		||||
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
 | 
			
		||||
| 
						 | 
				
			
			@ -127,8 +128,10 @@ class PromptLookupCandidateGenerator():
 | 
			
		|||
 | 
			
		||||
        if device == "mtl":
 | 
			
		||||
            self.max_candidates = 3
 | 
			
		||||
            self.min_candidates = 0
 | 
			
		||||
        else:
 | 
			
		||||
            self.max_candidates = 9
 | 
			
		||||
            self.min_candidates = 0
 | 
			
		||||
 | 
			
		||||
        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
 | 
			
		||||
                          "Invalid max_matching_ngram_size or num_output_tokens")
 | 
			
		||||
| 
						 | 
				
			
			@ -148,6 +151,8 @@ class PromptLookupCandidateGenerator():
 | 
			
		|||
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`:
 | 
			
		||||
            The candidate sequences to be tried.
 | 
			
		||||
        """
 | 
			
		||||
        if self.num_output_tokens == 0:
 | 
			
		||||
            return input_ids, None
 | 
			
		||||
        input_length = input_ids.size(1)
 | 
			
		||||
 | 
			
		||||
        chosen_ids = None
 | 
			
		||||
| 
						 | 
				
			
			@ -190,7 +195,7 @@ class PromptLookupCandidateGenerator():
 | 
			
		|||
        # so returning None
 | 
			
		||||
        return candidate_input_ids, None
 | 
			
		||||
 | 
			
		||||
    def update_candidate_strategy(self, candidate_num: int, num_matches: int):
 | 
			
		||||
    def update_candidate_strategy(self, candidate_num: int, num_matches: int, accept_rate: float):
 | 
			
		||||
        """
 | 
			
		||||
        Updates the candidate generation strategy based on the outcomes.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -198,10 +203,16 @@ class PromptLookupCandidateGenerator():
 | 
			
		|||
            num_matches (`int`):
 | 
			
		||||
                The number of matches between the candidate sequences and the model predictions.
 | 
			
		||||
        """
 | 
			
		||||
        if num_matches == self.num_output_tokens:
 | 
			
		||||
        if self.num_output_tokens == 0:
 | 
			
		||||
            ran = random.random() - 0.15
 | 
			
		||||
            if ran <= accept_rate:
 | 
			
		||||
                self.num_output_tokens = 1
 | 
			
		||||
        elif num_matches == self.num_output_tokens:
 | 
			
		||||
            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
 | 
			
		||||
        elif candidate_num > num_matches:
 | 
			
		||||
            self.num_output_tokens = max(self.num_output_tokens - 1, 1)
 | 
			
		||||
            ran = random.random() + 0.1 * (candidate_num - num_matches)
 | 
			
		||||
            if ran > accept_rate:
 | 
			
		||||
                self.num_output_tokens = max(self.num_output_tokens - 1, self.min_candidates)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
| 
						 | 
				
			
			@ -228,6 +239,7 @@ def lookup_generate(self,
 | 
			
		|||
    step_verify = 0
 | 
			
		||||
 | 
			
		||||
    clear_benchmarks(self)
 | 
			
		||||
    self.accept_rate = []
 | 
			
		||||
 | 
			
		||||
    past_key_values = None
 | 
			
		||||
    input_len = input_ids.shape[1]
 | 
			
		||||
| 
						 | 
				
			
			@ -324,8 +336,11 @@ def lookup_generate(self,
 | 
			
		|||
                past_key_values = _crop_past_key_values(self, past_key_values,
 | 
			
		||||
                                                        new_cache_size)
 | 
			
		||||
 | 
			
		||||
            accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
 | 
			
		||||
            self.accept_rate.append(accept_rate)
 | 
			
		||||
            # Update the candidate generation strategy if needed
 | 
			
		||||
            candidates_generator.update_candidate_strategy(candidate_length, n_matches)
 | 
			
		||||
            candidates_generator.update_candidate_strategy(candidate_length, n_matches,
 | 
			
		||||
                                                           accept_rate)
 | 
			
		||||
 | 
			
		||||
            input_ids = torch.cat((input_ids, output_ids), dim=-1)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue