support Megrez-3B-Omni (#12582)
This commit is contained in:
		
							parent
							
								
									4e7e988f70
								
							
						
					
					
						commit
						3eeb02f1be
					
				
					 2 changed files with 32 additions and 2 deletions
				
			
		| 
						 | 
				
			
			@ -1055,6 +1055,12 @@ def _optimize_pre(model, qtype=None):
 | 
			
		|||
            model.llm.config.model_type = "minicpm"
 | 
			
		||||
        _optimize_pre(model.llm, qtype=qtype)
 | 
			
		||||
        model.llm.config.model_type = "minicpmv"
 | 
			
		||||
    elif model.config.model_type == "megrezo":
 | 
			
		||||
        from ipex_llm.transformers.models.minicpmv import merge_qkv
 | 
			
		||||
        model.vision.apply(merge_qkv)
 | 
			
		||||
        model.llm.config.model_type = "llama"
 | 
			
		||||
        _optimize_pre(model.llm, qtype=qtype)
 | 
			
		||||
        model.llm.config.model_type = "megrezo"
 | 
			
		||||
    elif model.config.model_type == "chatglm":
 | 
			
		||||
        if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
 | 
			
		||||
            # chatglm2 and chatglm3
 | 
			
		||||
| 
						 | 
				
			
			@ -2202,5 +2208,29 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
 | 
			
		||||
            minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
 | 
			
		||||
            model.chat = MethodType(minicpmv_chat, model)
 | 
			
		||||
    elif model.config.model_type == "megrezo":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
 | 
			
		||||
        minicpmv_generate = minicpmv_generate_wrapper(module.MegrezO.generate)
 | 
			
		||||
        model.generate = MethodType(minicpmv_generate, model)
 | 
			
		||||
 | 
			
		||||
        # vision
 | 
			
		||||
        vpm_modeling_module_name = model.vision.vpm.__class__.__module__
 | 
			
		||||
        vpm_module = importlib.import_module(vpm_modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
 | 
			
		||||
        convert_forward(model.vision.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
 | 
			
		||||
 | 
			
		||||
        # resampler
 | 
			
		||||
        from ipex_llm.transformers.models.minicpmv import _in_projection_packed
 | 
			
		||||
        resampler_module_name = model.vision.resampler.__class__.__module__
 | 
			
		||||
        resampler_module = importlib.import_module(resampler_module_name)
 | 
			
		||||
        resampler_module._in_projection_packed = _in_projection_packed
 | 
			
		||||
 | 
			
		||||
        # llm
 | 
			
		||||
        model.llm.config.model_type = "llama"
 | 
			
		||||
        model.llm.config.rope_scaling = {"rope_type": "default"}
 | 
			
		||||
        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
 | 
			
		||||
        model.llm.config.model_type = "megrezo"
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -198,8 +198,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de
 | 
			
		|||
        elif seq_length != kv_length and seq_length <= 32:
 | 
			
		||||
            mask = None
 | 
			
		||||
        else:
 | 
			
		||||
            mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
 | 
			
		||||
                               dtype=dtype, device=device)
 | 
			
		||||
            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
 | 
			
		||||
            mask[:, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
 | 
			
		||||
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
 | 
			
		||||
    else:
 | 
			
		||||
        if seq_length != kv_length and seq_length <= 32:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue