diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 41d40f28..d5423848 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -29,6 +29,7 @@ from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteria
 from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
     _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.utils import get_xpu_device_type
 
 logger = logging.getLogger("ipex_llm.lookup")
 
@@ -119,10 +120,16 @@ class PromptLookupCandidateGenerator():
         self,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
+        device: str = "arc",
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
 
+        if device == "mtl":
+            self.max_candidates = 3
+        else:
+            self.max_candidates = 9
+
         invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
@@ -183,25 +190,18 @@ class PromptLookupCandidateGenerator():
         # so returning None
         return candidate_input_ids, None
 
-    def update_candidate_strategy(self, input_ids: torch.LongTensor,
-                                  scores: torch.FloatTensor, num_matches: int):
+    def update_candidate_strategy(self, candidate_num: int, num_matches: int):
         """
         Updates the candidate generation strategy based on the outcomes.
 
         Args:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary.
-                [What are input IDs?](../glossary#input-ids)
-            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
-                config.vocab_size)`):
-                Prediction scores of a language modeling head. These can be logits for each
-                vocabulary when not using beam search or log softmax for each vocabulary
-                token when using beam search
             num_matches (`int`):
                 The number of matches between the candidate sequences and the model predictions.
         """
-        # Currently does nothing
-        return
+        if num_matches == self.num_output_tokens:
+            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
+        elif candidate_num > num_matches:
+            self.num_output_tokens = max(self.num_output_tokens - 1, 1)
 
 
 @torch.no_grad()
@@ -217,9 +217,12 @@ def lookup_generate(self,
     model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                           **sampling_kwargs)
 
+    device_name = get_xpu_device_type(input_ids)
+
    candidates_generator = PromptLookupCandidateGenerator(
         num_output_tokens=num_output_tokens,
-        max_matching_ngram_size=max_matching_ngram_size)
+        max_matching_ngram_size=max_matching_ngram_size,
+        device=device_name)
 
     step = 0
     step_verify = 0
@@ -291,6 +294,7 @@ def lookup_generate(self,
                                                         top_k=generation_config.top_k,
                                                         top_p=generation_config.top_p,
                                                         temperature=generation_config.temperature)
+                output_ids = output_ids.transpose(0, 1)
             else:
                 output_ids = greedy(logits)
 
@@ -303,13 +307,14 @@ def lookup_generate(self,
             # Drafts start from [1, k]
             # Verified output start from [0, k - 1]
             # including the one generated by the base model
-            max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
-            max_matched = max_matched.sum(-1).item() + 1
+            n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
+                         .cumsum(-1) == 0).sum(-1).item()
+            max_matched = n_matches + 1
 
             max_of_max_matched = output_ids.size(1)
             # Accept number is max_matched, min is 1
             self.accept_num.append(max_matched)
-            self.n_matched += max_matched - 1
+            self.n_matched += n_matches
             self.n_drafted += candidate_length
 
             # Clean up target model KV cache
@@ -319,6 +324,9 @@ def lookup_generate(self,
             past_key_values = _crop_past_key_values(self, past_key_values,
                                                     new_cache_size)
 
+            # Update the candidate generation strategy if needed
+            candidates_generator.update_candidate_strategy(candidate_length, n_matches)
+
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
             step += output_ids.size(1)
 