Fix Qwen2-VL again: add a patched get_dtype for the vision module (#12174)
This commit is contained in:
parent
aef1f671bd
commit
535bee5381
2 changed files with 6 additions and 0 deletions
|
|
@ -1691,11 +1691,13 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from ipex_llm.transformers.models.common import rms_norm_forward
|
from ipex_llm.transformers.models.common import rms_norm_forward
|
||||||
from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
|
from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
|
||||||
|
from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_get_dtype
|
||||||
from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_attention_forward
|
from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_attention_forward
|
||||||
from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward
|
from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward
|
||||||
from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward
|
from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward
|
||||||
convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward)
|
convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward)
|
||||||
convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
|
convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
|
||||||
|
model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
|
||||||
convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
|
convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
|
||||||
convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
|
convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
|
||||||
convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
|
convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
|
||||||
|
|
|
||||||
|
|
@ -175,6 +175,10 @@ def qwen2_vl_model_forward(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def qwen2_vision_get_dtype(self) -> torch.dtype:
    """Return the dtype of the vision module's patch-embedding projection weight.

    Intended to be bound onto the vision module as its ``get_dtype`` method
    (the optimization pass does this with ``MethodType`` on ``model.visual``).

    Returns:
        The ``dtype`` of ``self.patch_embed.proj.weight``.
    """
    # NOTE(review): presumably the patch-embedding projection is used because
    # its weight keeps a regular floating-point dtype after model optimization
    # -- confirm against the upstream ``get_dtype`` implementation.
    return self.patch_embed.proj.weight.dtype
|
||||||
|
|
||||||
|
|
||||||
def qwen2_vision_attention_forward(
|
def qwen2_vision_attention_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue