rewrite minicpmv optimization (#11816)
This commit is contained in:
		
							parent
							
								
									447c8ed324
								
							
						
					
					
						commit
						4e178f0c5d
					
				
					 2 changed files with 140 additions and 1064 deletions
				
			
		| 
						 | 
				
			
			@ -747,17 +747,20 @@ def _optimize_pre(model, qtype=None):
 | 
			
		|||
    if model.config.model_type == "llama":
 | 
			
		||||
        from ipex_llm.transformers.models.llama import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
    if model.config.model_type == "minicpm":
 | 
			
		||||
        from ipex_llm.transformers.models.minicpm import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
    if model.config.model_type == "minicpmv":
 | 
			
		||||
        from ipex_llm.transformers.models.minicpmv import merge_qkv
 | 
			
		||||
        model.vpm.apply(merge_qkv)
 | 
			
		||||
        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
 | 
			
		||||
        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
 | 
			
		||||
            model.llm.config.model_type = "minicpm"
 | 
			
		||||
        elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
 | 
			
		||||
            model.llm.config.model_type = "qwen2"
 | 
			
		||||
            _optimize_pre(model.llm, qtype=qtype)
 | 
			
		||||
            model.llm.config.model_type = "minicpmv"
 | 
			
		||||
        elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
 | 
			
		||||
            model.llm.config.model_type = "llama"
 | 
			
		||||
            _optimize_pre(model.llm, qtype=qtype)
 | 
			
		||||
            model.llm.config.model_type = "minicpmv"
 | 
			
		||||
        _optimize_pre(model.llm, qtype=qtype)
 | 
			
		||||
        model.llm.config.model_type = "minicpmv"
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1699,31 +1702,16 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
                        module.StableLmModel,
 | 
			
		||||
                        stablelm_model_forward
 | 
			
		||||
                        )
 | 
			
		||||
    elif model.config.model_type == 'minicpm':
 | 
			
		||||
    elif model.config.model_type == "minicpm":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        if version.parse(trans_version) >= version.parse("4.39.0"):
 | 
			
		||||
            from ipex_llm.transformers.models.minicpm import minicpm_attention_forward_4_39
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.MiniCPMAttention,
 | 
			
		||||
                            minicpm_attention_forward_4_39)
 | 
			
		||||
        else:
 | 
			
		||||
            from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.MiniCPMAttention,
 | 
			
		||||
                            minicpm_attention_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.minicpm import minicpm_model_forward
 | 
			
		||||
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.MiniCPMMLP,
 | 
			
		||||
                        llama_mlp_forward)
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.MiniCPMRMSNorm,
 | 
			
		||||
                        llama_rms_norm_forward)
 | 
			
		||||
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.MiniCPMModel,
 | 
			
		||||
                        minicpm_model_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
 | 
			
		||||
        from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
 | 
			
		||||
        convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
 | 
			
		||||
        minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
 | 
			
		||||
    elif model.config.model_type == "minicpmv":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			@ -1734,18 +1722,14 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
 | 
			
		||||
            # MiniCPM-V 2
 | 
			
		||||
            model.llm.config.model_type = "minicpm"
 | 
			
		||||
            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
 | 
			
		||||
            model.llm.config.model_type = "minicpmv"
 | 
			
		||||
        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
 | 
			
		||||
        elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
 | 
			
		||||
            # MiniCPM-V 2.6
 | 
			
		||||
            model.llm.config.model_type = "qwen2"
 | 
			
		||||
            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
 | 
			
		||||
            model.llm.config.model_type = "minicpmv"
 | 
			
		||||
        elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
 | 
			
		||||
            # MiniCPM-V 2.5
 | 
			
		||||
            model.llm.config.model_type = "llama"
 | 
			
		||||
            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
 | 
			
		||||
            model.llm.config.model_type = "minicpmv"
 | 
			
		||||
        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
 | 
			
		||||
        model.llm.config.model_type = "minicpmv"
 | 
			
		||||
 | 
			
		||||
        vpm_modeling_module_name = model.vpm.__class__.__module__
 | 
			
		||||
        vpm_module = importlib.import_module(vpm_modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							
		Loading…
	
		Reference in a new issue