add basic glm-edge-v support (#12533)
parent 3e0823d2ae
commit ffce86d69f

2 changed files with 18 additions and 5 deletions
@@ -1504,6 +1504,17 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.GlmAttention, glm_attention_forward)
         glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
         convert_forward(model, module.GlmModel, glm_model_forward)
+
+        if hasattr(model.model, "vision"):
+            # glm-edge-v series
+            vision_module_name = model.model.vision.__class__.__module__
+            vision_module = importlib.import_module(vision_module_name)
+            from transformers.models.siglip.modeling_siglip import SiglipAttention
+            from ipex_llm.transformers.models.chatglm4v import vision_model_forward
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            convert_forward(model, vision_module.VisionModel, vision_model_forward)
+            convert_forward(model, SiglipAttention, siglip_attention_forward)
+
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:
             modeling_module_name = model.__class__.__module__
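
This added branch is the conversion-side half of the glm-edge-v support: inside `_optimize_post`, a glm-edge-v checkpoint is recognized by the presence of a `vision` submodule, its vision module is resolved dynamically via `importlib` rather than imported from a fixed path, and the vision tower plus its `SiglipAttention` layers are rerouted through the existing chatglm4v and minicpmv forwards. `convert_forward` is ipex-llm's module-patching helper; a minimal sketch of the pattern it relies on (illustrative only, not the library's actual implementation):

```python
import types
import torch

def patch_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Illustrative stand-in for convert_forward: rebind `forward` on every
    # submodule that is an instance of `target_cls`.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
```

With a helper like this, `patch_forward(model, vision_module.VisionModel, vision_model_forward)` would mirror what the two `convert_forward` calls above do.
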
@@ -37,7 +37,6 @@ import torch
 
 from typing import Optional, Tuple
 from transformers.cache_utils import Cache
-from transformers.models.glm.modeling_glm import GlmAttention, GlmMLP
 from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.models.common import merge_qkv_base
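
Dropping the `GlmAttention` / `GlmMLP` class import pairs with the hunk below: modules are now matched by class name instead of by `isinstance`, which also catches glm-edge-v modules whose classes share these names but are not the ones defined in `transformers.models.glm`. A tiny sketch of the name-based check (illustrative; the helper name is hypothetical):

```python
import torch

def has_class_name(module: torch.nn.Module, class_name: str) -> bool:
    # Matching on the class name lets remote-code variants of GlmMLP/GlmAttention qualify too.
    return module.__class__.__name__ == class_name
```
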
@@ -46,11 +45,12 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp
 
 
 def merge_qkv(module: torch.nn.Module):
-    merge_qkv_base(module, GlmAttention)
+    merge_qkv_base(module, "GlmAttention")
+    merge_qkv_base(module, "SiglipAttention")
 
 
 def split_mlp(module: torch.nn.Module):
-    if isinstance(module, GlmMLP):
+    if module.__class__.__name__ == "GlmMLP":
         gate_weight, up_weight = module.gate_up_proj.weight.data.chunk(2, dim=0)
 
         gate_proj = torch.nn.Linear(0, 0, bias=False)
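
`merge_qkv_base` now receives class names and is additionally applied to `"SiglipAttention"`, so the same q/k/v fusion covers the glm-edge-v vision encoder. The underlying technique concatenates the separate query/key/value projections into one linear layer; a generic sketch under the assumption that the module exposes `q_proj`, `k_proj`, and `v_proj` (not ipex-llm's actual `merge_qkv_base`):

```python
import torch

def fuse_qkv(attn: torch.nn.Module) -> None:
    # Concatenate q/k/v projections into a single qkv_proj so one matmul serves all three.
    q, k, v = attn.q_proj, attn.k_proj, attn.v_proj
    qkv = torch.nn.Linear(q.in_features,
                          q.out_features + k.out_features + v.out_features,
                          bias=q.bias is not None)
    qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
    if q.bias is not None:
        qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
    attn.qkv_proj = qkv
    del attn.q_proj, attn.k_proj, attn.v_proj
```

`split_mlp` goes the other way for `GlmMLP`, chunking the fused `gate_up_proj` weight into separate gate and up projections; the `torch.nn.Linear(0, 0, bias=False)` placeholder is presumably created empty so the chunked weights can be attached to it directly.
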
@@ -157,6 +157,7 @@ def glm_model_forward_wrapper(origin_forward):
     def glm_model_forward(
         self,
         input_ids: torch.LongTensor = None,
+        images: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
@@ -166,7 +167,7 @@ def glm_model_forward_wrapper(origin_forward):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs,
+        **kwargs,
     ):
         # ipex-llm changes start
         # IPEX-LLM OPT: kv cache and quantize kv cache
@@ -187,6 +188,7 @@ def glm_model_forward_wrapper(origin_forward):
         return origin_forward(
             self=self,
             input_ids=input_ids,
+            images=images,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
@@ -196,7 +198,7 @@ def glm_model_forward_wrapper(origin_forward):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
-            **flash_attn_kwargs,
+            **kwargs,
         )
 
     return glm_model_forward
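
glm-edge-v's multimodal forward takes an `images` tensor, so the wrapper's signature gains `images` and passes it through unchanged, and the catch-all is renamed from `**flash_attn_kwargs` to `**kwargs` since the extra keywords are no longer limited to flash-attention options. The overall shape of the wrapper, as a generic sketch (`make_cache` is a hypothetical stand-in for ipex-llm's KV-cache setup, which per the imports and the "quantize kv cache" comment chooses between `DynamicNormalCache` and `DynamicFp8Cache`):

```python
def forward_wrapper(origin_forward, make_cache):
    # Generic sketch of the glm_model_forward_wrapper pattern: prepare an optimized
    # KV cache before delegating, and pass model-specific extras (here `images`) through.
    def wrapped(self, input_ids=None, images=None, past_key_values=None, **kwargs):
        if kwargs.get("use_cache", True) and past_key_values is None:
            past_key_values = make_cache()  # hypothetical cache factory
        return origin_forward(self, input_ids=input_ids, images=images,
                              past_key_values=past_key_values, **kwargs)
    return wrapped
```

Keeping the wrapper's own logic limited to cache preparation and delegating everything else is what lets the same wrapper serve both the text-only glm and the multimodal glm-edge-v forwards.
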