fix minicpm V 2.6 repeat output (#11753)
This commit is contained in:
		
							parent
							
								
									7e917d6cfb
								
							
						
					
					
						commit
						93455aac09
					
				
					 1 changed files with 25 additions and 2 deletions
				
			
		| 
						 | 
					@ -15,6 +15,30 @@
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# todo
 | 
				
			||||||
 | 
					def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
 | 
				
			||||||
 | 
					    score = torch.gather(scores, 1, input_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # if score < 0 then repetition penalty has to be
 | 
				
			||||||
 | 
					    # multiplied to reduce the token probabilities
 | 
				
			||||||
 | 
					    score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ipex llm changes start: call scatter on CPU
 | 
				
			||||||
 | 
					    device = scores.device
 | 
				
			||||||
 | 
					    scores = scores.to('cpu')
 | 
				
			||||||
 | 
					    input_ids = input_ids.to('cpu')
 | 
				
			||||||
 | 
					    score = score.to('cpu')
 | 
				
			||||||
 | 
					    scores.scatter_(1, input_ids, score)
 | 
				
			||||||
 | 
					    scores = scores.to(device)
 | 
				
			||||||
 | 
					    # ipex llm changes end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return scores
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def minicpmv_generate_wrapper(origin_generate):
 | 
					def minicpmv_generate_wrapper(origin_generate):
 | 
				
			||||||
    def generate(
 | 
					    def generate(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
| 
						 | 
					@ -30,8 +54,7 @@ def minicpmv_generate_wrapper(origin_generate):
 | 
				
			||||||
        decode_text=False,
 | 
					        decode_text=False,
 | 
				
			||||||
        **kwargs
 | 
					        **kwargs
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        if kwargs.get("repetition_penalty", None) is not None:
 | 
					        RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
 | 
				
			||||||
            kwargs["repetition_penalty"] = 1
 | 
					 | 
				
			||||||
        return origin_generate(
 | 
					        return origin_generate(
 | 
				
			||||||
            self=self,
 | 
					            self=self,
 | 
				
			||||||
            input_ids=input_ids,
 | 
					            input_ids=input_ids,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue