fix minicpm V 2.6 repeat output (#11753)
This commit is contained in:
parent
7e917d6cfb
commit
93455aac09
1 changed file with 25 additions and 2 deletions
|
|
@ -15,6 +15,30 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
|
# todo
|
||||||
|
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
    """Apply the repetition penalty to ``scores``, running the scatter on CPU.

    Written to be installed as ``RepetitionPenaltyLogitsProcessor.__call__``;
    ``self`` only needs to expose a ``penalty`` attribute.

    Args:
        self: Object providing the ``penalty`` factor.
        input_ids: Token ids used as gather/scatter indices along dim 1 of
            ``scores``.
        scores: Logits to penalize. This tensor is not modified.

    Returns:
        A new tensor on the same device as ``scores`` in which the entries
        selected by ``input_ids`` are multiplied by ``penalty`` when negative
        and divided by ``penalty`` otherwise.
    """
    score = torch.gather(scores, 1, input_ids)

    # if score < 0 then repetition penalty has to be
    # multiplied to reduce the token probabilities
    score = torch.where(score < 0, score * self.penalty, score / self.penalty)

    # ipex llm changes start: call scatter on CPU
    device = scores.device
    # Use the out-of-place ``scatter``: ``.to('cpu')`` hands back the very
    # same tensor when ``scores`` already lives on CPU, so an in-place
    # ``scatter_`` would overwrite the caller's logits in that case only.
    scores_processed = scores.to('cpu').scatter(
        1, input_ids.to('cpu'), score.to('cpu')
    )
    scores_processed = scores_processed.to(device)
    # ipex llm changes end

    return scores_processed
|
||||||
|
|
||||||
|
|
||||||
def minicpmv_generate_wrapper(origin_generate):
|
def minicpmv_generate_wrapper(origin_generate):
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -30,8 +54,7 @@ def minicpmv_generate_wrapper(origin_generate):
|
||||||
decode_text=False,
|
decode_text=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if kwargs.get("repetition_penalty", None) is not None:
|
RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
|
||||||
kwargs["repetition_penalty"] = 1
|
|
||||||
return origin_generate(
|
return origin_generate(
|
||||||
self=self,
|
self=self,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue