parent 375174af33
commit b7948671de

2 changed files with 48 additions and 18 deletions
@@ -99,6 +99,10 @@ def generate(
 GenerationMixin.generate = generate
+
+
+def tensor2key(key_tensor: torch.LongTensor):
+    return tuple(key_tensor.tolist())
 
 
 # This class is copied from https://github.com/huggingface/transformers/blob/main/src
 # /transformers/generation/candidate_generator.py
 class PromptLookupCandidateGenerator():
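Review note on the new tensor2key helper: torch.Tensor hashes by object identity, not by value, so two tensors holding the same tokens would not hit the same dict entry; converting to a tuple of Python ints gives value-based hashing. A minimal standalone check (not part of the patch):

```python
import torch

def tensor2key(key_tensor: torch.LongTensor):
    # Tuples hash by value, so equal token sequences share one dict slot.
    return tuple(key_tensor.tolist())

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
table = {tensor2key(a): 0}
print(tensor2key(b) in table)  # True: value-based lookup
print(b in {a: 0})             # False: tensors hash by identity
```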
@@ -133,9 +137,34 @@ class PromptLookupCandidateGenerator():
             self.max_candidates = 9
             self.min_candidates = 0
 
+        self.lookup_table = {}
         invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
+    def init_look_up_table(self,
+                           input_ids: torch.LongTensor):
+        for ngram_size in range(self.max_matching_ngram_size, 0, -1):
+            # Create sliding windows of size ngram_size
+            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
+            for idx in range(windows.size(1)):
+                window = tensor2key(windows[0, idx])
+                if window not in self.lookup_table:
+                    self.lookup_table[window] = idx
+
+    def update_look_up_table(self,
+                             new_input_ids: torch.LongTensor):
+        # Maintain a look up table
+        window = tensor2key(new_input_ids[0, -self.max_matching_ngram_size:])
+        for ngram_size in range(self.max_matching_ngram_size):
+            if window[ngram_size:] not in self.lookup_table:
+                self.lookup_table[window[ngram_size:]] = \
+                    new_input_ids.size(1)-self.max_matching_ngram_size+ngram_size
+
+    def get_n_gram_idx(self,
+                       ngram_tensor: torch.LongTensor):
+        key = tensor2key(ngram_tensor)
+        return self.lookup_table[key]
+
     def get_candidates(self,
                        input_ids: torch.LongTensor)-> Tuple[torch.LongTensor,
                                                             Optional[torch.FloatTensor]]:
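For review, a self-contained sketch of how the three new methods cooperate; the class below mirrors the patch's logic (with `setdefault` standing in for the explicit `not in` checks) but is illustrative, not the patch itself:

```python
import torch

class NgramTable:
    """Minimal stand-in for the new lookup-table methods."""
    def __init__(self, max_matching_ngram_size: int = 2):
        self.max_matching_ngram_size = max_matching_ngram_size
        self.lookup_table = {}

    def init_look_up_table(self, input_ids: torch.LongTensor):
        # Index every window of every size once, keeping the first occurrence.
        for ngram_size in range(self.max_matching_ngram_size, 0, -1):
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
            for idx in range(windows.size(1)):
                self.lookup_table.setdefault(tuple(windows[0, idx].tolist()), idx)

    def update_look_up_table(self, new_input_ids: torch.LongTensor):
        # After tokens are appended, only suffixes of the trailing window are new.
        window = tuple(new_input_ids[0, -self.max_matching_ngram_size:].tolist())
        for k in range(self.max_matching_ngram_size):
            self.lookup_table.setdefault(
                window[k:], new_input_ids.size(1) - self.max_matching_ngram_size + k)

    def get_n_gram_idx(self, ngram_tensor: torch.LongTensor):
        return self.lookup_table[tuple(ngram_tensor.tolist())]

ids = torch.tensor([[5, 6, 7, 5, 6]])
t = NgramTable()
t.init_look_up_table(ids)
print(t.get_n_gram_idx(ids[0, -2:]))  # 0: first occurrence of the bigram (5, 6)
ids = torch.cat((ids, torch.tensor([[8]])), dim=-1)
t.update_look_up_table(ids)
print(t.lookup_table[(6, 8)])         # 4: the new trailing bigram's position
```

Worth noting: despite its name, `update_look_up_table` receives the full concatenated sequence (that is how the patch calls it), since stored positions are derived from its total length.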
@@ -156,31 +185,20 @@ class PromptLookupCandidateGenerator():
         input_length = input_ids.size(1)
 
         chosen_ids = None
-        match_found = False
         for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
-            # Create sliding windows of size ngram_size
-            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
-
             # Convert ngram to a tensor for comparison
             ngram_tensor = input_ids[0, -ngram_size:]
 
-            # Find where the windows match the ngram
-            matches = (windows == ngram_tensor).all(dim=2)
-
-            # Get the indices of matches
-            match_indices = matches.nonzero(as_tuple=True)[1]
+            # Get the index of the match
+            idx = self.get_n_gram_idx(ngram_tensor)
 
             # Iterate through match indices to find a valid continuation
-            for idx in match_indices:
-                start_idx = idx + ngram_size
-                end_idx = start_idx + self.num_output_tokens
-                end_idx = min(end_idx, input_length)
+            start_idx = idx + ngram_size
+            end_idx = start_idx + self.num_output_tokens
+            end_idx = min(end_idx, input_length)
 
-                if start_idx < end_idx:
-                    chosen_ids = input_ids[0, start_idx:end_idx]
-                    match_found = True
-                    break
-            if match_found:
+            if start_idx < end_idx:
+                chosen_ids = input_ids[0, start_idx:end_idx]
                 break
 
         if chosen_ids is None or len(chosen_ids) == 0:
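The payoff of the rewrite in get_candidates: the old path re-materialized all sliding windows with `unfold` and compared each against the trailing n-gram on every drafting step, roughly O(n·k) in sequence length n per step, while the table reduces the match to one O(k) dictionary probe. Both variants prefer the earliest match, since init_look_up_table stores the first occurrence of each n-gram; a toy check of that equivalence (not from the patch):

```python
import torch

def scan_match(input_ids: torch.LongTensor, ngram_size: int):
    # Old approach: rebuild windows, compare, take the first match index.
    windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
    ngram_tensor = input_ids[0, -ngram_size:]
    matches = (windows == ngram_tensor).all(dim=2)
    match_indices = matches.nonzero(as_tuple=True)[1]
    return int(match_indices[0])

ids = torch.tensor([[5, 6, 7, 5, 6]])
print(scan_match(ids, 2))  # 0, the same index the lookup table stores for (5, 6)
```

Two residual nits a reviewer might flag: the retained comment "Iterate through match indices..." is now stale, since a single index is used, and get_n_gram_idx raises KeyError on a miss, which appears safe only because the trailing n-gram always indexes itself via init_look_up_table/update_look_up_table.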
@@ -267,6 +285,9 @@ def lookup_generate(self,
             else:
                 output_ids = greedy(logits)
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
+
+            candidates_generator.init_look_up_table(input_ids)
+
             past_key_values = output['past_key_values']
             step += 1
             if self.device.type == 'xpu':
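Placement note: init_look_up_table runs once, in the first decoding iteration right after the first generated token is appended, so the table covers the whole prompt up front. The trade-off is memory that grows linearly with prompt length; a rough bound, under the patch's scheme of indexing every window of sizes 1..max_matching_ngram_size (helper name is illustrative, not from the patch):

```python
# Upper bound on distinct table entries for a given prompt length.
def table_upper_bound(seq_len: int, max_ngram: int) -> int:
    # One window per start position, for each n-gram size.
    return sum(seq_len - size + 1 for size in range(1, max_ngram + 1))

print(table_upper_bound(seq_len=1024, max_ngram=3))  # 3069 entries at most
```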
@@ -319,9 +340,13 @@ def lookup_generate(self,
             # Drafts start from [1, k]
             # Verified output start from [0, k - 1]
             # including the one generated by the base model
+
             n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
                          .cumsum(-1) == 0).sum(-1).item()
+
             max_matched = n_matches + 1
+            mot = time.time()
+            self.match_time.append(mot-toc)
 
             max_of_max_matched = output_ids.size(1)
             # Accept number is max_matched, min is 1
@@ -343,9 +368,12 @@ def lookup_generate(self,
                                                            accept_rate)
 
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
+            candidates_generator.update_look_up_table(input_ids)
 
             step += output_ids.size(1)
             step_verify += 1
+            pot = time.time()
+            self.post_time.append(pot-mot)
 
         # Stop on eos and remove content after eos
         output_ids_list = output_ids[0].tolist()
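One behavioral subtlety of the incremental update, visible in this hunk: a verify step can append several accepted tokens at once, but update_look_up_table indexes only the suffixes of the final window, so n-grams straddling the newly accepted span are not recorded until they recur at the end of the sequence. A toy illustration under the patch's update rule (names are illustrative):

```python
import torch

MAX_NGRAM = 2
table = {}

def update(ids: torch.LongTensor):
    # Patch's rule: index only suffixes of the trailing MAX_NGRAM window.
    tail = tuple(ids[0, -MAX_NGRAM:].tolist())
    for k in range(MAX_NGRAM):
        table.setdefault(tail[k:], ids.size(1) - MAX_NGRAM + k)

ids = torch.tensor([[5, 6]])
update(ids)                                    # indexes (5, 6)
ids = torch.cat((ids, torch.tensor([[7, 8, 9]])), dim=-1)
update(ids)                                    # indexes only (8, 9)
print((6, 7) in table, (8, 9) in table)        # False True: (6, 7) was skipped
```

Whether the gap matters depends on how often such straddling n-grams would have produced drafts; flagging it as a review question, not a defect.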
@@ -162,6 +162,8 @@ def clear_benchmarks(self):
     self.generate_time = []
     self.draft_time = []
     self.verify_time = []
+    self.match_time = []
+    self.post_time = []
     self.draft_num = []
     self.accept_num = []
     self.n_drafted = 0
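These two lists mirror the mot/pot timestamps added in lookup_generate (toc is set outside the hunks shown here): match_time covers verification through the n_matches computation, post_time the acceptance bookkeeping and table update. Since the loop appends unconditionally, clear_benchmarks must run before generation; keeping the two sites in sync is the point of this hunk. A hypothetical reporting helper (not part of the patch), assuming only these attribute names:

```python
def summarize_lookup_times(model) -> None:
    # Print mean per-iteration cost of the two newly instrumented phases.
    for name in ("match_time", "post_time"):
        times = getattr(model, name, [])
        if times:
            print(f"{name}: n={len(times)}, "
                  f"mean={sum(times) / len(times) * 1000:.3f} ms")
```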