fix and optimize minicpm v 2 (#11799)

parent d8d887edd2
commit 9a93808fc5

2 changed files with 45 additions and 12 deletions
First changed file (the _optimize_post conversion code):

@@ -1726,6 +1726,11 @@ def _optimize_post(model, lightweight_bmm=False):
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
 
+        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
+            # MiniCPM-V 2
+            model.llm.config.model_type = "minicpm"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             # MiniCPM-V 2.6
             model.llm.config.model_type = "qwen2"
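The added branch recognizes MiniCPM-V 2 by its hidden size (2304) and vocabulary size (122753), temporarily relabels the wrapped language model as "minicpm" so the recursive _optimize_post call takes the MiniCPM path, and then restores the "minicpmv" label. Below is a minimal sketch of that relabel-recurse-restore pattern; the SimpleNamespace configs and the optimize() helper are illustrative stand-ins, not ipex_llm code.

# Sketch only: stand-in objects, not the real ipex_llm model/config classes.
from types import SimpleNamespace

def optimize(model):
    # the real _optimize_post dispatches on model.config.model_type
    print(f"optimizing for model_type={model.config.model_type!r}")

def optimize_minicpmv2_llm(model):
    # MiniCPM-V 2 wraps a MiniCPM language model, but its config says "minicpmv";
    # relabel it so the generic path picks the MiniCPM branch, then restore.
    model.llm.config.model_type = "minicpm"
    try:
        optimize(model.llm)
    finally:
        model.llm.config.model_type = "minicpmv"

model = SimpleNamespace(llm=SimpleNamespace(config=SimpleNamespace(model_type="minicpmv")))
optimize_minicpmv2_llm(model)  # prints: optimizing for model_type='minicpm'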
@@ -1739,7 +1744,11 @@ def _optimize_post(model, lightweight_bmm=False):
 
         vpm_modeling_module_name = model.vpm.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)
-        if model.vpm.config.model_type == "siglip":
+        if not hasattr(model.vpm, "config"):
+            # MiniCPM-V 2
+            from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
+            model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
+        elif model.vpm.config.model_type == "siglip":
             # MiniCPM-V 2.6
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
             convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
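The reworked branch distinguishes the two vision towers: a MiniCPM-V 2 checkpoint's vpm apparently has no config attribute (hence the hasattr check), so an optimized get_vision_embedding is bound directly onto the model instance with MethodType, while a SigLIP tower (MiniCPM-V 2.6) gets its attention forward replaced via convert_forward. Below is a hedged sketch of the per-instance MethodType patching pattern; ToyVisionModel and faster_embed are made-up names for illustration, not MiniCPM-V or ipex_llm code.

# Sketch only: binding a replacement method onto one instance, leaving the class alone.
from types import MethodType
import torch

class ToyVisionModel:
    def get_vision_embedding(self, pixel_values):
        return pixel_values.mean(dim=(-1, -2))   # original path

def faster_embed(self, pixel_values):
    # replacement; `self` is the patched instance, so it can reach its attributes
    return pixel_values.amax(dim=(-1, -2))

model = ToyVisionModel()
model.get_vision_embedding = MethodType(faster_embed, model)  # patch this instance only
print(model.get_vision_embedding(torch.ones(2, 3, 4, 4)).shape)  # torch.Size([2, 3])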
Second changed file, the ipex_llm.transformers.models.minicpmv module imported above:

@@ -15,6 +15,7 @@
 #
 
 
+import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
@@ -22,11 +23,13 @@ from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
 
+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
     merge_qkv_base(module, "Idefics2VisionAttention")
 
 
+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def siglip_attention_forward(
     self,
     hidden_states: torch.Tensor,
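The new comments mark merge_qkv and siglip_attention_forward as applying to MiniCPM-V-2_5 and 2_6 only. merge_qkv_base itself is not shown in this diff; the sketch below illustrates the general idea of fusing separate q/k/v projections into a single linear layer, under the assumption that this is what the helper does for matching attention classes. The toy module and function names are illustrative, not ipex_llm's.

# Sketch only: fuse three projections into one so a single GEMM produces q, k and v.
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

def fuse_qkv(attn):
    # concatenate the three weight/bias blocks into one fused projection
    dim = attn.q_proj.in_features
    fused = nn.Linear(dim, 3 * dim)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([attn.q_proj.weight,
                                      attn.k_proj.weight,
                                      attn.v_proj.weight], dim=0))
        fused.bias.copy_(torch.cat([attn.q_proj.bias,
                                    attn.k_proj.bias,
                                    attn.v_proj.bias], dim=0))
    attn.qkv_proj = fused
    del attn.q_proj, attn.k_proj, attn.v_proj

attn = ToyAttention()
fuse_qkv(attn)
x = torch.randn(2, 16, 64)
q, k, v = attn.qkv_proj(x).chunk(3, dim=-1)   # one matmul, then split
print(q.shape, k.shape, v.shape)              # three tensors of shape (2, 16, 64)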
				
			
@@ -58,17 +61,7 @@ def siglip_attention_forward(
     return attn_output, attn_weights
 
 
-def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-    if scores.device.type == "xpu":
-        import xe_addons
-        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
-    else:
-        score = torch.gather(scores, 1, input_ids)
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        scores.scatter_(1, input_ids, score)
-    return scores
-
-
+# MiniCPM-V-2_5
 def minicpmv_chat_wrapper(origin_chat):
     def minicpmv_chat(
         self,
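The removed patched_repetition_penalty_call is not dropped; the last hunk re-adds it further down the file, after the new MiniCPM-V 2 helper. On XPU it calls a fused xe_addons kernel; otherwise it applies the standard repetition penalty, dividing positive logits of already-generated tokens by the penalty and multiplying negative ones. Below is a small worked example of that gather/where/scatter sequence; the vocabulary size, logits and penalty are made-up values for illustration.

# Toy walk-through of the CPU fallback path on a 4-token vocabulary.
import torch

penalty = 1.5
scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])   # logits over the vocabulary
input_ids = torch.tensor([[1, 3]])               # tokens generated so far

score = torch.gather(scores, 1, input_ids)                        # [-1.0, 3.0]
score = torch.where(score < 0, score * penalty, score / penalty)  # [-1.5, 2.0]
scores.scatter_(1, input_ids, score)                              # write back in place
print(scores)  # logits become [2.0, -1.5, 0.5, 2.0]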
				
			
@@ -106,6 +99,37 @@ def minicpmv_chat_wrapper(origin_chat):
     return minicpmv_chat
 
 
+# MiniCPM-V-2
+def minicpmv_get_vision_embedding(self, pixel_values):
+    res = []
+    dtype = self.dtype
+
+    def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
+        H, W = pixel_value.shape[-2:]
+        target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
+        vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
+
+        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
+            vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
+        return resampler(vision_embedding, target_size)
+
+    for pixel_value in pixel_values:
+        result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+        res.append(result)
+    return torch.vstack(res)
+
+
+def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    if scores.device.type == "xpu":
+        import xe_addons
+        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
+    else:
+        score = torch.gather(scores, 1, input_ids)
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        scores.scatter_(1, input_ids, score)
+    return scores
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         *inputs,
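The new minicpmv_get_vision_embedding processes each image tensor separately: it derives a patch-grid target_size with math.ceil, runs the vision tower, strips any prefix (e.g. class) tokens, and hands the patch embeddings plus grid size to the resampler. Below is a self-contained sketch of that per-image flow; the toy vision tower, resampler, patch size and hidden size are stand-ins, not MiniCPM-V components.

# Sketch only: per-image embedding with a made-up vpm and resampler.
import math
import torch

patch_size = 14
num_prefix_tokens = 1          # e.g. a class token prepended by the vision tower

def toy_vpm(pixel_values):     # (1, 3, H, W) -> (1, prefix + num_patches, hidden)
    _, _, H, W = pixel_values.shape
    n = math.ceil(H / patch_size) * math.ceil(W / patch_size)
    return torch.randn(1, num_prefix_tokens + n, 32)

def toy_resampler(embedding, target_size):
    # pool the patch grid to a fixed-length representation
    return embedding.mean(dim=1, keepdim=True)

def get_vision_embedding(pixel_values):
    res = []
    for pixel_value in pixel_values:                      # one image at a time
        H, W = pixel_value.shape[-2:]
        target_size = (math.ceil(H / patch_size), math.ceil(W / patch_size))
        emb = toy_vpm(pixel_value.unsqueeze(0))
        emb = emb[:, num_prefix_tokens:]                  # drop prefix/class tokens
        res.append(toy_resampler(emb, target_size))
    return torch.vstack(res)

out = get_vision_embedding(torch.randn(2, 3, 224, 336))  # two images
print(out.shape)                                          # torch.Size([2, 1, 32])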