optimize internvl2 vision model's attention (#12198)
parent f8d1adc573
commit d5344587ab
2 changed files with 29 additions and 1 deletion
@@ -1580,8 +1580,12 @@ def _optimize_post(model, lightweight_bmm=False):
         model.batch_chat = MethodType(internvl_batch_chat, model)
         if model.vision_model.__class__.__name__ == "InternVisionModel":
             from ipex_llm.transformers.models.internvl import _get_pos_embed
-            vision_embedding = model.vision_model.embeddings
+            from ipex_llm.transformers.models.internvl import intern_attention_forward
+            vision_model = model.vision_model
+            vision_embedding = vision_model.embeddings
             vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
+            vision_module = importlib.import_module(vision_model.__class__.__module__)
+            convert_forward(vision_model, vision_module.InternAttention, intern_attention_forward)
         _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
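Note on the hunk above: it mixes two patching styles. batch_chat and _get_pos_embed are rebound per instance with MethodType, while convert_forward swaps the forward of every InternAttention block in the vision tower for the new intern_attention_forward. The real convert_forward is defined elsewhere in ipex_llm and is not shown in this diff; the helper below is only a sketch of the assumed mechanism (convert_forward_sketch and its signature are illustrative, not the commit's code).

from types import MethodType

import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    # Walk every submodule; any instance of target_cls gets the replacement
    # function bound to it as its forward method (same idea as the MethodType
    # calls used for batch_chat and _get_pos_embed above).
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = MethodType(new_forward, module)


# Usage mirroring the diff (names taken from the hunk above):
# convert_forward_sketch(vision_model, vision_module.InternAttention, intern_attention_forward)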
@@ -163,3 +163,27 @@ def internvl_batch_chat(self, tokenizer, pixel_values, questions, generation_con
     responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
     responses = [response.split(template.sep)[0].strip() for response in responses]
     return responses
+
+
+def intern_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
+    B, N, C = x.shape
+    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+    if self.qk_normalization:
+        B_, H_, N_, D_ = q.shape
+        q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+        k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+    if x.device.type == "xpu":
+        import xe_addons
+        x = xe_addons.sdp_non_causal(q.contiguous(), k.contiguous(), v.contiguous(), None)
+    else:
+        attn = ((q * self.scale) @ k.transpose(-2, -1))
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+
+    x = x.transpose(1, 2).reshape(B, N, C)
+    x = self.proj(x)
+    x = self.proj_drop(x)
+    return x