From d5344587ab2212fb01b14c323f7177e90fa4ce0c Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 15 Oct 2024 10:51:00 +0800
Subject: [PATCH] optimize internvl2 vision model's attention (#12198)

---
 .../llm/src/ipex_llm/transformers/convert.py  |  6 ++++-
 .../ipex_llm/transformers/models/internvl.py  | 24 +++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 0a584499..c3893669 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1580,8 +1580,12 @@ def _optimize_post(model, lightweight_bmm=False):
         model.batch_chat = MethodType(internvl_batch_chat, model)
         if model.vision_model.__class__.__name__ == "InternVisionModel":
             from ipex_llm.transformers.models.internvl import _get_pos_embed
-            vision_embedding = model.vision_model.embeddings
+            from ipex_llm.transformers.models.internvl import intern_attention_forward
+            vision_model = model.vision_model
+            vision_embedding = vision_model.embeddings
             vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
+            vision_module = importlib.import_module(vision_model.__class__.__module__)
+            convert_forward(vision_model, vision_module.InternAttention, intern_attention_forward)
         _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
diff --git a/python/llm/src/ipex_llm/transformers/models/internvl.py b/python/llm/src/ipex_llm/transformers/models/internvl.py
index 633f337a..1cecdd25 100644
--- a/python/llm/src/ipex_llm/transformers/models/internvl.py
+++ b/python/llm/src/ipex_llm/transformers/models/internvl.py
@@ -163,3 +163,27 @@ def internvl_batch_chat(self, tokenizer, pixel_values, questions, generation_con
     responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
     responses = [response.split(template.sep)[0].strip() for response in responses]
     return responses
+
+
+def intern_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
+    B, N, C = x.shape
+    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+    if self.qk_normalization:
+        B_, H_, N_, D_ = q.shape
+        q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+        k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+    if x.device.type == "xpu":
+        import xe_addons
+        x = xe_addons.sdp_non_causal(q.contiguous(), k.contiguous(), v.contiguous(), None)
+    else:
+        attn = ((q * self.scale) @ k.transpose(-2, -1))
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+    x = x.transpose(1, 2).reshape(B, N, C)
+    x = self.proj(x)
+    x = self.proj_drop(x)
+    return x
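
Note (not part of the patch): on XPU devices the new intern_attention_forward dispatches to IPEX-LLM's fused xe_addons.sdp_non_causal kernel, and everywhere else it falls back to the eager matmul/softmax branch shown above. The sketch below is only an illustration of what that eager branch computes, expressed with stock PyTorch's scaled_dot_product_attention; it assumes PyTorch >= 2.0, and the sdp_reference helper and the tensor shapes are hypothetical, not taken from the patch. Dropout is ignored, which matches inference-time behaviour.

import torch
import torch.nn.functional as F


def sdp_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, num_heads, tokens, head_dim), i.e. the layout produced by
    # the qkv projection + permute in intern_attention_forward.
    # Non-causal, no attention mask, no dropout: the same computation the eager
    # branch performs with explicit matmuls, here via PyTorch's fused kernel.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)


if __name__ == "__main__":
    # hypothetical InternViT-like shapes, chosen only for the numerical check
    B, H, N, D = 1, 16, 1025, 64
    q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
    scale = D ** -0.5  # same scaling as self.scale = head_dim ** -0.5 in the vision attention
    eager = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1) @ v
    print(torch.allclose(eager, sdp_reference(q, k, v), atol=1e-4))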