diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index bdf9aa3a..ebde9407 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -19,23 +19,14 @@
 import torch
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
 
-# todo
 def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-    score = torch.gather(scores, 1, input_ids)
-
-    # if score < 0 then repetition penalty has to be
-    # multiplied to reduce the token probabilities
-    score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-
-    # ipex llm changes start: call scatter on CPU
-    device = scores.device
-    scores = scores.to('cpu')
-    input_ids = input_ids.to('cpu')
-    score = score.to('cpu')
-    scores.scatter_(1, input_ids, score)
-    scores = scores.to(device)
-    # ipex llm changes end
-
+    if scores.device.type == "xpu":
+        import xe_addons
+        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
+    else:
+        score = torch.gather(scores, 1, input_ids)
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        scores.scatter_(1, input_ids, score)
     return scores
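
For context, a standalone sketch (not part of this patch) of what the two branches are expected to compute. The reference function mirrors the pure-PyTorch fallback branch, and the `xe_addons` call uses the same in-place signature as in the diff. The shapes, penalty value, and device handling below are illustrative assumptions; running the XPU branch requires an ipex-llm XPU environment where `xe_addons` is importable.

```python
# Minimal sketch (not part of the patch): compare the fused XPU kernel against
# the pure-PyTorch fallback path used in the diff.
import torch


def reference_repetition_penalty(scores, input_ids, penalty):
    # Standard Hugging Face repetition-penalty logic: gather the logits of
    # already-generated tokens, then penalize them.
    score = torch.gather(scores, 1, input_ids)
    # Negative logits are multiplied, positive ones divided, so the penalized
    # token always becomes less likely.
    score = torch.where(score < 0, score * penalty, score / penalty)
    out = scores.clone()
    out.scatter_(1, input_ids, score)
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    scores = torch.randn(2, 32000)                 # (batch, vocab) logits
    input_ids = torch.randint(0, 32000, (2, 16))   # previously generated tokens
    penalty = 1.05

    expected = reference_repetition_penalty(scores, input_ids, penalty)

    try:
        import xe_addons  # assumed available in ipex-llm XPU installs
        xpu_scores = scores.to("xpu")
        # The fused kernel updates xpu_scores in place, matching the diff.
        xe_addons.repetition_penalty_logits_process_inplaced(
            xpu_scores, input_ids.to("xpu"), penalty
        )
        print("max diff:", (xpu_scores.cpu() - expected).abs().max().item())
    except (ImportError, RuntimeError):
        print("xe_addons / XPU not available; computed reference result only.")
```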