Fix wrong attention mask and garbage output for inputs_embeds inputs during lookup generation (#11989)

* Fix garbage output for inputs_embeds inputs during lookup generation

* Fix sliding window creation

* Simplify code
Yuwen Hu 2024-09-02 19:09:12 +08:00 committed by GitHub
parent 2f3d1bd0ec
commit 659d15defc

@@ -175,7 +175,7 @@ class PromptLookupCandidateGenerator():
     def init_look_up_table(self,
                            input_ids: torch.LongTensor):
-        for ngram_size in range(self.max_matching_ngram_size, 0, -1):
+        for ngram_size in range(min(self.max_matching_ngram_size, input_ids.shape[1]), 0, -1):
             # Create sliding windows of size ngram_size
             windows = input_ids.cpu().unfold(dimension=1, size=ngram_size, step=1)
             for idx in range(windows.size(1)):
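
The min() clamp matters because torch.Tensor.unfold raises a RuntimeError whenever the requested window size exceeds the length of the dimension being unfolded, which is exactly what happened for prompts shorter than max_matching_ngram_size. A minimal standalone sketch of the failure and the fix (the values below are hypothetical, not the library code):

import torch

input_ids = torch.tensor([[1, 2, 3]])   # a prompt of only 3 tokens
max_matching_ngram_size = 4             # hypothetical config value larger than the prompt

# Old behaviour: the first iteration asks for 4-token windows over a 3-token
# sequence, so unfold raises a RuntimeError
# ("maximum size for tensor at dimension 1 is 3 but size is 4").
# windows = input_ids.cpu().unfold(dimension=1, size=4, step=1)

# Fixed behaviour: start from the largest n-gram that actually fits.
for ngram_size in range(min(max_matching_ngram_size, input_ids.shape[1]), 0, -1):
    windows = input_ids.cpu().unfold(dimension=1, size=ngram_size, step=1)
    print(ngram_size, windows.shape)     # 3 -> (1, 1, 3), 2 -> (1, 2, 2), 1 -> (1, 3, 1)
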
@@ -315,11 +315,9 @@ def lookup_generate(self,
         if step == 0:
             # first token use full model
             tic = time.time()
-            output = self(input_ids=input_ids,
-                          past_key_values=past_key_values,
-                          attention_mask=attention_mask,
-                          return_dict=True,
-                          use_cache=True)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            output = self(**model_inputs,
+                          return_dict=True)
             logits = output['logits']
             logits = logits[:, -1:]
             logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :])
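
Routing the first forward pass through prepare_inputs_for_generation is what fixes the inputs_embeds case: when the caller starts generation from embeddings, the old hard-coded input_ids/attention_mask call ignores them and feeds the model a mask that does not match what it actually sees, which is where the garbage output came from. A minimal sketch of the intended call pattern, assuming a transformers version whose prepare_inputs_for_generation routes inputs_embeds on the first step (the model name and kwargs below are illustrative, not the library code):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")    # illustrative causal LM
tokenizer = AutoTokenizer.from_pretrained("gpt2")

enc = tokenizer("An example prompt", return_tensors="pt")
# The caller hands over embeddings instead of token ids.
inputs_embeds = model.get_input_embeddings()(enc["input_ids"])

model_kwargs = {
    "inputs_embeds": inputs_embeds,
    "attention_mask": enc["attention_mask"],
    "use_cache": True,
}
input_ids = enc["input_ids"]

# prepare_inputs_for_generation decides whether to feed inputs_embeds or
# input_ids and derives matching attention_mask / position_ids, so the first
# forward pass sees a consistent set of inputs instead of a mismatched mask.
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
output = model(**model_inputs, return_dict=True)
print(output.logits.shape)               # (1, seq_len, vocab_size)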