optimize minicpm-v-2_6 repetition penalty (#11763)

This commit is contained in:
Yishuo Wang 2024-08-12 14:10:10 +08:00 committed by GitHub
parent fac4c01a6e
commit 57d177738d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -19,23 +19,14 @@ import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
# todo
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be
# multiplied to reduce the token probabilities
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
# ipex llm changes start: call scatter on CPU
device = scores.device
scores = scores.to('cpu')
input_ids = input_ids.to('cpu')
score = score.to('cpu')
scores.scatter_(1, input_ids, score)
scores = scores.to(device)
# ipex llm changes end
if scores.device.type == "xpu":
import xe_addons
xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
else:
score = torch.gather(scores, 1, input_ids)
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores