diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py index c70558c7..60680faf 100644 --- a/python/llm/src/ipex_llm/transformers/lookup.py +++ b/python/llm/src/ipex_llm/transformers/lookup.py @@ -175,7 +175,7 @@ class PromptLookupCandidateGenerator(): def init_look_up_table(self, input_ids: torch.LongTensor): - for ngram_size in range(self.max_matching_ngram_size, 0, -1): + for ngram_size in range(min(self.max_matching_ngram_size, input_ids.shape[1]), 0, -1): # Create sliding windows of size ngram_size windows = input_ids.cpu().unfold(dimension=1, size=ngram_size, step=1) for idx in range(windows.size(1)): @@ -315,11 +315,9 @@ def lookup_generate(self, if step == 0: # first token use full model tic = time.time() - output = self(input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - return_dict=True, - use_cache=True) + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + output = self(**model_inputs, + return_dict=True) logits = output['logits'] logits = logits[:, -1:] logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :])