Fix wrong attention mask and garbage output for inputs_embeds inputs during lookup generation (#11989)
* Fix garbage output for input_embeds inputs during lookup generation * Fix on sliding windows * Simplify code
This commit is contained in:
parent
2f3d1bd0ec
commit
659d15defc
1 changed files with 4 additions and 6 deletions
|
|
@ -175,7 +175,7 @@ class PromptLookupCandidateGenerator():
|
|||
|
||||
def init_look_up_table(self,
|
||||
input_ids: torch.LongTensor):
|
||||
for ngram_size in range(self.max_matching_ngram_size, 0, -1):
|
||||
for ngram_size in range(min(self.max_matching_ngram_size, input_ids.shape[1]), 0, -1):
|
||||
# Create sliding windows of size ngram_size
|
||||
windows = input_ids.cpu().unfold(dimension=1, size=ngram_size, step=1)
|
||||
for idx in range(windows.size(1)):
|
||||
|
|
@ -315,11 +315,9 @@ def lookup_generate(self,
|
|||
if step == 0:
|
||||
# first token use full model
|
||||
tic = time.time()
|
||||
output = self(input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
use_cache=True)
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
output = self(**model_inputs,
|
||||
return_dict=True)
|
||||
logits = output['logits']
|
||||
logits = logits[:, -1:]
|
||||
logits[:, -1, :] = logits_processor(input_ids, logits[:, -1, :])
|
||||
|
|
|
|||
Loading…
Reference in a new issue