optimize minicpm-v-2_6 repetition penalty (#11763)
parent fac4c01a6e
commit 57d177738d

1 changed file with 7 additions and 16 deletions
@@ -19,23 +19,14 @@ import torch
 
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
 
-# todo
 def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-    score = torch.gather(scores, 1, input_ids)
-
-    # if score < 0 then repetition penalty has to be
-    # multiplied to reduce the token probabilities
-    score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-
-    # ipex llm changes start: call scatter on CPU
-    device = scores.device
-    scores = scores.to('cpu')
-    input_ids = input_ids.to('cpu')
-    score = score.to('cpu')
-    scores.scatter_(1, input_ids, score)
-    scores = scores.to(device)
-    # ipex llm changes end
-
+    if scores.device.type == "xpu":
+        import xe_addons
+        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
+    else:
+        score = torch.gather(scores, 1, input_ids)
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        scores.scatter_(1, input_ids, score)
     return scores
 
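The change removes an ipex-llm workaround that moved scores, input_ids, and the gathered score to the CPU just to run scatter_, then copied the result back to the original device. On XPU devices the whole penalty is now computed by a single fused in-place kernel, xe_addons.repetition_penalty_logits_process_inplaced, while every other device keeps the plain PyTorch path. For reference, the fallback branch is the standard repetition penalty; below is a minimal, self-contained sketch of what it computes, with toy numbers that are illustrative only:

import torch

# Toy shapes: scores is (batch, vocab_size) logits, input_ids is
# (batch, generated_len) ids of tokens already produced.
penalty = 1.5
scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
input_ids = torch.tensor([[1, 3]])

# Pick out the logits of tokens that were already generated.
score = torch.gather(scores, 1, input_ids)          # tensor([[-1.0, 3.0]])
# A negative logit is multiplied by the penalty, a positive one divided,
# so a repeated token always becomes less likely.
score = torch.where(score < 0, score * penalty, score / penalty)
# Write the penalized logits back in place.
scores.scatter_(1, input_ids, score)
print(scores)                                       # tensor([[ 2.0, -1.5, 0.5, 2.0]])

Since penalty > 1.0, dividing shrinks a positive logit and multiplying pushes a negative logit further down; the fused XPU kernel performs these same three steps in one pass, without the intermediate tensor or the host round-trip of the old code.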
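The hunk imports RepetitionPenaltyLogitsProcessor without using it inside the function, which suggests the patched function is bound over the processor's __call__ elsewhere in ipex-llm. A minimal sketch of that wiring, assuming the rebinding below (it is not shown in this diff) and that patched_repetition_penalty_call from the hunk is in scope:

import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

# Assumption: ipex-llm replaces the stock __call__ with the patched one;
# this assignment is illustrative and does not appear in the hunk.
RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call

# Standard Hugging Face usage is unchanged; penalty > 1.0 penalizes repeats.
processor = RepetitionPenaltyLogitsProcessor(penalty=1.05)
input_ids = torch.tensor([[1, 3]])
scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
new_scores = processor(input_ids, scores)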