optimize internvl2 vision model's attention (#12198)

Yishuo Wang 2024-10-15 10:51:00 +08:00 committed by GitHub
parent f8d1adc573
commit d5344587ab
2 changed files with 29 additions and 1 deletion


@@ -1580,8 +1580,12 @@ def _optimize_post(model, lightweight_bmm=False):
         model.batch_chat = MethodType(internvl_batch_chat, model)
         if model.vision_model.__class__.__name__ == "InternVisionModel":
             from ipex_llm.transformers.models.internvl import _get_pos_embed
-            vision_embedding = model.vision_model.embeddings
+            from ipex_llm.transformers.models.internvl import intern_attention_forward
+            vision_model = model.vision_model
+            vision_embedding = vision_model.embeddings
             vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
+            vision_module = importlib.import_module(vision_model.__class__.__module__)
+            convert_forward(vision_model, vision_module.InternAttention, intern_attention_forward)
         _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
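The convert_forward helper used in the hunk above is ipex-llm's usual mechanism for patching a module class's forward method in place. As a rough illustration of what such a patcher does (a minimal sketch only, assuming it walks the module tree and rebinds forward on each matching instance; the real helper lives in ipex_llm.transformers.convert and may differ in detail):

    from types import MethodType
    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
        # Rebind `forward` on every submodule that is an instance of
        # `target_cls`; here, every InternAttention layer would pick up
        # intern_attention_forward without the model being rebuilt.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = MethodType(new_forward, module)

convert_forward_sketch is a hypothetical name; the diff calls the repo's own convert_forward with the same (model, class, new forward) argument order.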


@@ -163,3 +163,27 @@ def internvl_batch_chat(self, tokenizer, pixel_values, questions, generation_config,
     responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
     responses = [response.split(template.sep)[0].strip() for response in responses]
     return responses
+
+
+def intern_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
+    B, N, C = x.shape
+    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+    q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
+
+    if self.qk_normalization:
+        B_, H_, N_, D_ = q.shape
+        q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+        k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+    if x.device.type == "xpu":
+        import xe_addons
+        x = xe_addons.sdp_non_causal(q.contiguous(), k.contiguous(), v.contiguous(), None)
+    else:
+        attn = ((q * self.scale) @ k.transpose(-2, -1))
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+
+    x = x.transpose(1, 2).reshape(B, N, C)
+    x = self.proj(x)
+    x = self.proj_drop(x)
+    return x
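In the new intern_attention_forward, the else branch is plain eager attention, while XPU devices are routed to the fused xe_addons.sdp_non_causal kernel. As a quick standalone check (a hypothetical harness, not part of this commit), the eager branch should agree with PyTorch's built-in scaled-dot-product attention in the non-causal, unmasked, no-dropout case, assuming self.scale holds the usual 1/sqrt(head_dim):

    import torch
    import torch.nn.functional as F

    B, H, N, D = 1, 4, 16, 32           # batch, heads, tokens, head dim
    q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
    scale = D ** -0.5                   # assumed value of self.scale in InternAttention

    # Eager branch from intern_attention_forward (dropout is identity in eval mode)
    attn = (q * scale) @ k.transpose(-2, -1)
    out_eager = attn.softmax(dim=-1) @ v

    # PyTorch reference; its default scale is also 1 / sqrt(D)
    out_ref = F.scaled_dot_product_attention(q, k, v)
    torch.testing.assert_close(out_eager, out_ref, atol=1e-5, rtol=1e-4)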
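Nothing in this commit changes the public API, so the optimization is picked up transparently when an InternVL2 checkpoint is loaded through ipex-llm. A hypothetical usage sketch (the model id and flags below are illustrative, not taken from the commit):

    from ipex_llm.transformers import AutoModelForCausalLM

    # Loading through ipex-llm runs _optimize_post, which patches
    # InternAttention.forward with intern_attention_forward.
    model = AutoModelForCausalLM.from_pretrained(
        "OpenGVLab/InternVL2-4B",       # illustrative checkpoint
        load_in_4bit=True,
        trust_remote_code=True,
    )
    model = model.half().to("xpu")      # XPU device -> xe_addons.sdp_non_causal path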