add basic glm-edge-v support (#12533)
This commit is contained in:
parent
3e0823d2ae
commit
ffce86d69f
2 changed files with 18 additions and 5 deletions
|
|
@ -1504,6 +1504,17 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model, module.GlmAttention, glm_attention_forward)
|
convert_forward(model, module.GlmAttention, glm_attention_forward)
|
||||||
glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
|
glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
|
||||||
convert_forward(model, module.GlmModel, glm_model_forward)
|
convert_forward(model, module.GlmModel, glm_model_forward)
|
||||||
|
|
||||||
|
if hasattr(model.model, "vision"):
|
||||||
|
# glm-edge-v series
|
||||||
|
vision_module_name = model.model.vision.__class__.__module__
|
||||||
|
vision_module = importlib.import_module(vision_module_name)
|
||||||
|
from transformers.models.siglip.modeling_siglip import SiglipAttention
|
||||||
|
from ipex_llm.transformers.models.chatglm4v import vision_model_forward
|
||||||
|
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
|
||||||
|
convert_forward(model, vision_module.VisionModel, vision_model_forward)
|
||||||
|
convert_forward(model, SiglipAttention, siglip_attention_forward)
|
||||||
|
|
||||||
elif "mpt" in model.config.model_type:
|
elif "mpt" in model.config.model_type:
|
||||||
if model.config.architectures is not None:
|
if model.config.architectures is not None:
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,6 @@ import torch
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.models.glm.modeling_glm import GlmAttention, GlmMLP
|
|
||||||
from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
|
from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
|
||||||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
|
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
|
||||||
from ipex_llm.transformers.models.common import merge_qkv_base
|
from ipex_llm.transformers.models.common import merge_qkv_base
|
||||||
|
|
@ -46,11 +45,12 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp
|
||||||
|
|
||||||
|
|
||||||
def merge_qkv(module: torch.nn.Module):
|
def merge_qkv(module: torch.nn.Module):
|
||||||
merge_qkv_base(module, GlmAttention)
|
merge_qkv_base(module, "GlmAttention")
|
||||||
|
merge_qkv_base(module, "SiglipAttention")
|
||||||
|
|
||||||
|
|
||||||
def split_mlp(module: torch.nn.Module):
|
def split_mlp(module: torch.nn.Module):
|
||||||
if isinstance(module, GlmMLP):
|
if module.__class__.__name__ == "GlmMLP":
|
||||||
gate_weight, up_weight = module.gate_up_proj.weight.data.chunk(2, dim=0)
|
gate_weight, up_weight = module.gate_up_proj.weight.data.chunk(2, dim=0)
|
||||||
|
|
||||||
gate_proj = torch.nn.Linear(0, 0, bias=False)
|
gate_proj = torch.nn.Linear(0, 0, bias=False)
|
||||||
|
|
@ -157,6 +157,7 @@ def glm_model_forward_wrapper(origin_forward):
|
||||||
def glm_model_forward(
|
def glm_model_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
images: torch.Tensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Cache] = None,
|
past_key_values: Optional[Cache] = None,
|
||||||
|
|
@ -166,7 +167,7 @@ def glm_model_forward_wrapper(origin_forward):
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**flash_attn_kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# ipex-llm changes start
|
# ipex-llm changes start
|
||||||
# IPEX-LLM OPT: kv cache and quantize kv cache
|
# IPEX-LLM OPT: kv cache and quantize kv cache
|
||||||
|
|
@ -187,6 +188,7 @@ def glm_model_forward_wrapper(origin_forward):
|
||||||
return origin_forward(
|
return origin_forward(
|
||||||
self=self,
|
self=self,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
images=images,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
|
|
@ -196,7 +198,7 @@ def glm_model_forward_wrapper(origin_forward):
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**flash_attn_kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return glm_model_forward
|
return glm_model_forward
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue