fix minicpm-v 2.5 (#11780)
parent ec184af243
commit a184b120c9

2 changed files with 21 additions and 31 deletions
@@ -1755,19 +1755,29 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.MiniCPMModel,
                         minicpm_model_forward)
     elif model.config.model_type == "minicpmv":
-        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
-            model.llm.config.model_type = "qwen2"
-            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
-            model.llm.config.model_type = "minicpmv"
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
 
-        modeling_module_name = model.vpm.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
-        convert_forward(model, module.SiglipAttention, siglip_attention_forward)
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+            # MiniCPM-V 2.6
+            model.llm.config.model_type = "qwen2"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
+        elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
+            # MiniCPM-V 2.5
+            pass
+
+        vpm_modeling_module_name = model.vpm.__class__.__module__
+        vpm_module = importlib.import_module(vpm_modeling_module_name)
+        if model.vpm.config.model_type == "siglip":
+            # MiniCPM-V 2.6
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            convert_forward(model, vpm_module.SiglipAttention, siglip_attention_forward)
+        elif model.vpm.config.model_type == "idefics2":
+            # MiniCPM-V 2.5
+            pass
 
     return model
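The dispatch above tells the two MiniCPM-V generations apart purely from config values: hidden_size 3584 with vocab_size 151666 is MiniCPM-V 2.6 (Qwen2-based LLM, SigLIP vision tower), while hidden_size 4096 with vocab_size 128256 is MiniCPM-V 2.5 (Llama-3-based LLM, Idefics2-style vision tower), which now falls through to the pass branches instead of hitting the SigLIP-only code path. A minimal sketch of the same check, with a hypothetical helper name:

def detect_minicpmv_version(config):
    # Thresholds taken from the diff above; the helper itself is illustrative
    # and not part of the patch.
    if config.hidden_size == 3584 and config.vocab_size == 151666:
        return "2.6"   # Qwen2-based LLM, SigLIP vision encoder
    if config.hidden_size == 4096 and config.vocab_size == 128256:
        return "2.5"   # Llama-3-based LLM; no extra LLM-level conversion applied here
    return "unknown"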
@@ -69,32 +69,12 @@ def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: t
 
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
-        self,
-        input_ids=None,
-        pixel_values=None,
-        tgt_sizes=None,
-        image_bound=None,
-        attention_mask=None,
-        tokenizer=None,
-        vision_hidden_states=None,
-        return_vision_hidden_states=False,
-        stream=False,
-        decode_text=False,
+        *inputs,
         **kwargs
     ):
         RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
         return origin_generate(
-            self=self,
-            input_ids=input_ids,
-            pixel_values=pixel_values,
-            tgt_sizes=tgt_sizes,
-            image_bound=image_bound,
-            attention_mask=attention_mask,
-            tokenizer=tokenizer,
-            vision_hidden_states=vision_hidden_states,
-            return_vision_hidden_states=return_vision_hidden_states,
-            stream=stream,
-            decode_text=decode_text,
-            **kwargs
+            *inputs,
+            **kwargs,
         )
     return generate
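The old wrapper re-declared MiniCPM-V 2.6's full generate signature and forwarded every argument by name, which can break on MiniCPM-V 2.5, whose generate does not share that exact signature. The new version forwards positional and keyword arguments unchanged (self simply arrives as the first positional argument, since the wrapper is bound to the model with MethodType), so it never has to know either signature. A general form of this pattern, with illustrative names only:

import functools

def passthrough_wrapper(origin_generate):
    # Forward whatever the caller passes; because the wrapper never spells out
    # the underlying signature, it survives signature differences between
    # model versions. functools.wraps is an extra nicety not used in the patch.
    @functools.wraps(origin_generate)
    def generate(*inputs, **kwargs):
        return origin_generate(*inputs, **kwargs)
    return generate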