diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py index 1eaaf83a..36815902 100644 --- a/python/llm/src/ipex_llm/transformers/lookup.py +++ b/python/llm/src/ipex_llm/transformers/lookup.py @@ -99,6 +99,10 @@ def generate( GenerationMixin.generate = generate +def tensor2key(key_tensor: torch.LongTensor): + return tuple(key_tensor.tolist()) + + # This class is copied from https://github.com/huggingface/transformers/blob/main/src # /transformers/generation/candidate_generator.py class PromptLookupCandidateGenerator(): @@ -133,9 +137,34 @@ class PromptLookupCandidateGenerator(): self.max_candidates = 9 self.min_candidates = 0 + self.lookup_table = {} invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0, "Invalid max_matching_ngram_size or num_output_tokens") + def init_look_up_table(self, + input_ids: torch.LongTensor): + for ngram_size in range(self.max_matching_ngram_size, 0, -1): + # Create sliding windows of size ngram_size + windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) + for idx in range(windows.size(1)): + window = tensor2key(windows[0, idx]) + if window not in self.lookup_table: + self.lookup_table[window] = idx + + def update_look_up_table(self, + new_input_ids: torch.LongTensor): + # Maintain a look up table + window = tensor2key(new_input_ids[0, -self.max_matching_ngram_size:]) + for ngram_size in range(self.max_matching_ngram_size): + if window[ngram_size:] not in self.lookup_table: + self.lookup_table[window[ngram_size:]] = \ + new_input_ids.size(1)-self.max_matching_ngram_size+ngram_size + + def get_n_gram_idx(self, + ngram_tensor: torch.LongTensor): + key = tensor2key(ngram_tensor) + return self.lookup_table[key] + def get_candidates(self, input_ids: torch.LongTensor)-> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: @@ -156,31 +185,20 @@ class PromptLookupCandidateGenerator(): input_length = input_ids.size(1) chosen_ids = None - match_found = False for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1): - # Create sliding windows of size ngram_size - windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) - # Convert ngram to a tensor for comparison ngram_tensor = input_ids[0, -ngram_size:] - # Find where the windows match the ngram - matches = (windows == ngram_tensor).all(dim=2) - - # Get the indices of matches - match_indices = matches.nonzero(as_tuple=True)[1] + # # Get the indices of matches + idx = self.get_n_gram_idx(ngram_tensor) # Iterate through match indices to find a valid continuation - for idx in match_indices: - start_idx = idx + ngram_size - end_idx = start_idx + self.num_output_tokens - end_idx = min(end_idx, input_length) + start_idx = idx + ngram_size + end_idx = start_idx + self.num_output_tokens + end_idx = min(end_idx, input_length) - if start_idx < end_idx: - chosen_ids = input_ids[0, start_idx:end_idx] - match_found = True - break - if match_found: + if start_idx < end_idx: + chosen_ids = input_ids[0, start_idx:end_idx] break if chosen_ids is None or len(chosen_ids) == 0: @@ -267,6 +285,9 @@ def lookup_generate(self, else: output_ids = greedy(logits) input_ids = torch.cat((input_ids, output_ids), dim=-1) + + candidates_generator.init_look_up_table(input_ids) + past_key_values = output['past_key_values'] step += 1 if self.device.type == 'xpu': @@ -319,9 +340,13 @@ def lookup_generate(self, # Drafts start from [1, k] # Verified output start from [0, k - 1] # including the one generated by the base model + n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:]) .cumsum(-1) == 0).sum(-1).item() + max_matched = n_matches + 1 + mot = time.time() + self.match_time.append(mot-toc) max_of_max_matched = output_ids.size(1) # Accept number is max_matched, min is 1 @@ -343,9 +368,12 @@ def lookup_generate(self, accept_rate) input_ids = torch.cat((input_ids, output_ids), dim=-1) + candidates_generator.update_look_up_table(input_ids) step += output_ids.size(1) step_verify += 1 + pot = time.time() + self.post_time.append(pot-mot) # Stop on eos and remove content after eos output_ids_list = output_ids[0].tolist() diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 2f123659..6d2e0842 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -162,6 +162,8 @@ def clear_benchmarks(self): self.generate_time = [] self.draft_time = [] self.verify_time = [] + self.match_time = [] + self.post_time = [] self.draft_num = [] self.accept_num = [] self.n_drafted = 0