From 659d15defc61ac2d234de6f7deb3f07bc942b70e Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Mon, 2 Sep 2024 19:09:12 +0800
Subject: [PATCH] Fix wrong attention mask and garbage output for
 `inputs_embeds` inputs during lookup generation (#11989)

* Fix garbage output for `inputs_embeds` inputs during lookup generation

* Fix on sliding windows

* Simplify code
---
 python/llm/src/ipex_llm/transformers/lookup.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index c70558c7..60680faf 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -175,7 +175,7 @@ class PromptLookupCandidateGenerator():
 
     def init_look_up_table(self,
                            input_ids: torch.LongTensor):
-        for ngram_size in range(self.max_matching_ngram_size, 0, -1):
+        for ngram_size in range(min(self.max_matching_ngram_size, input_ids.shape[1]), 0, -1):
             # Create sliding windows of size ngram_size
             windows = input_ids.cpu().unfold(dimension=1, size=ngram_size, step=1)
             for idx in range(windows.size(1)):
@@ -315,11 +315,9 @@ def lookup_generate(self,
         if step == 0:
             # first token use full model
             tic = time.time()
-            output = self(input_ids=input_ids,
-                          past_key_values=past_key_values,
-                          attention_mask=attention_mask,
-                          return_dict=True,
-                          use_cache=True)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            output = self(**model_inputs,
+                          return_dict=True)
             logits = output['logits']
             logits = logits[:, -1:]
             logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :])
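
The first hunk clamps the largest n-gram size to the prompt length before building the lookup table. Below is a minimal, standalone sketch (not part of the patch; the tensor values and the `max_matching_ngram_size` setting are made up) of the failure it prevents: `torch.Tensor.unfold` raises a `RuntimeError` when the requested window size exceeds the length of the unfolded dimension, which happens whenever the prompt is shorter than `max_matching_ngram_size`.

```python
import torch

input_ids = torch.tensor([[101, 102]])  # a hypothetical 2-token prompt
max_matching_ngram_size = 3             # hypothetical configured value

# Before the fix: asking unfold for size-3 sliding windows over a length-2
# dimension raises a RuntimeError instead of producing windows.
try:
    input_ids.unfold(dimension=1, size=max_matching_ngram_size, step=1)
except RuntimeError as e:
    print(f"unclamped ngram size fails: {e}")

# After the fix: the largest ngram size is clamped to the prompt length,
# so every unfold call requests a window that actually fits.
for ngram_size in range(min(max_matching_ngram_size, input_ids.shape[1]), 0, -1):
    windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
    print(ngram_size, tuple(windows.shape))  # 2 -> (1, 1, 2), then 1 -> (1, 2, 1)
```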
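The second hunk replaces the hand-built first forward call with the model's own `prepare_inputs_for_generation`, the standard Hugging Face hook that assembles the model inputs from `model_kwargs`. The sketch below (the `MockModel` class is invented for illustration; real models carry more state) shows the relevant behavior: on the first step, when no KV cache exists yet, a user-supplied `inputs_embeds` prompt takes precedence over `input_ids`, so calling the model with `input_ids=...` directly would ignore the caller's embeddings and decode from placeholder token ids instead.

```python
import torch

class MockModel:
    """Illustrative stand-in for a Hugging Face causal LM (invented here)."""

    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None,
                                      attention_mask=None, past_key_values=None,
                                      **kwargs):
        # Mirrors the common Hugging Face pattern: on the very first step
        # (no KV cache yet) a user-supplied `inputs_embeds` prompt takes
        # precedence over `input_ids`.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update({"attention_mask": attention_mask,
                             "past_key_values": past_key_values,
                             "use_cache": True})
        return model_inputs

model = MockModel()
embeds = torch.randn(1, 4, 32)                   # user-supplied prompt embeddings
dummy_ids = torch.zeros(1, 4, dtype=torch.long)  # placeholder ids, not real text
model_kwargs = {"inputs_embeds": embeds,
                "attention_mask": torch.ones(1, 4, dtype=torch.long)}

inputs = model.prepare_inputs_for_generation(dummy_ids, **model_kwargs)
assert "inputs_embeds" in inputs and "input_ids" not in inputs
```

Delegating to this hook also addresses the attention-mask half of the patch title: the mask is now taken from `model_kwargs` alongside the embeddings, rather than from a local variable that may not match the inputs actually fed to the model.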