ipex-llm/python/llm/src/ipex_llm/transformers/models/minicpmv.py

282 lines
9.6 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/openbmb/MiniCPM-V-2_6/blob/main/modeling_minicpmv.py
# which is licensed under Apache License 2.0:
#
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
#
import math
import torch
from threading import Thread
from typing import Optional, List
from torch.nn.functional import linear
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
from ipex_llm.transformers.models.common import attention_softmax
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from transformers import AutoProcessor, TextIteratorStreamer
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
# MiniCPM-V-2_5 and MiniCPM-V-2_6
def merge_qkv(module: torch.nn.Module):
    """Run ``merge_qkv_base`` on ``module`` for every supported vision attention class.

    The class names handled are ``SiglipAttention``, ``SiglipSdpaAttention`` and
    ``Idefics2VisionAttention``, in that order.
    NOTE(review): ``merge_qkv_base`` presumably fuses the separate q/k/v
    projections into a single ``qkv_proj`` -- confirm in ``models.common``.
    """
    supported_attention_classes = (
        "SiglipAttention",
        "SiglipSdpaAttention",
        "Idefics2VisionAttention",
    )
    for attention_cls_name in supported_attention_classes:
        merge_qkv_base(module, attention_cls_name)
# MiniCPM-V-2_5 and MiniCPM-V-2_6
def siglip_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
):
    """Attention forward that reads q/k/v from a single fused ``self.qkv_proj``.

    Args:
        hidden_states: 3-D input; its first two dimensions are used as
            (batch, sequence length).
        attention_mask: optional mask forwarded to the attention kernel.
        output_attentions: accepted for interface compatibility; it is not
            read by this implementation.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_weights`` is always ``None``.
    """
    batch, seq_len, _ = hidden_states.size()

    # One projection produces q, k and v stacked along the head axis.
    fused = self.qkv_proj(hidden_states)
    fused = fused.view(batch, seq_len, self.num_heads * 3, self.head_dim).transpose(1, 2)
    query_states, key_states, value_states = fused.chunk(3, dim=1)

    from ipex_llm.transformers.utils import get_xpu_device_name
    # The dedicated kernel is only selected for head_dim 72 on "arc" devices
    # with fp32/fp16 inputs.
    use_siglip_kernel = (
        self.head_dim == 72
        and get_xpu_device_name(query_states.device) == "arc"
        and query_states.dtype in [torch.float, torch.half]
    )

    # Neither path materialises the attention weights.
    attn_weights = None
    if not use_siglip_kernel:
        # NOTE(review): presumably pads a head dim of 72 up to 80 -- confirm
        # against ``padding_qkv_hd``.
        query_states, key_states, value_states = padding_qkv_hd(
            query_states, key_states, value_states, 72, 80
        )
        softmax_scale = 1 / math.sqrt(self.head_dim)
        attn_output = scaled_dot_product_attention(
            query_states,
            key_states.contiguous(),
            value_states.contiguous(),
            attention_mask,
            False,
            softmax_scale,
        )
        # Keep only the first ``head_dim`` channels of the last dimension.
        attn_output = attn_output[:, :, :, :self.head_dim]
    else:
        num_heads = query_states.size(1)
        kv_len = key_states.size(2)
        from ipex_llm.transformers.models.common import prepare_mask
        attention_mask = prepare_mask(
            attention_mask, batch, num_heads, seq_len, kv_len,
            False, query_states.dtype, query_states.device,
        )
        import xe_addons
        attn_output = xe_addons.siglip_sdp_non_causal(
            query_states, key_states, value_states, attention_mask
        )

    # (batch, heads, seq, head_dim) -> (batch, seq, embed_dim), then output projection.
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = self.out_proj(attn_output.reshape(batch, seq_len, self.embed_dim))
    return attn_output, attn_weights
# MiniCPM-V-2_6
def _in_projection_packed(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """Project q, k and v with a packed weight (and optional packed bias).

    ``w`` stacks the query, key and value weights along dim 0, and ``b`` (if
    given) stacks the biases the same way. Which path runs is decided by
    object identity of the inputs:

    * ``q is k is v``  -- one ``linear`` call, result split in three;
    * ``k is v`` only  -- one ``linear`` for q, one for the packed k/v;
    * otherwise        -- three independent ``linear`` calls.

    Returns:
        The three projected tensors, in (q, k, v) order.
    """
    embed_dim = q.size(-1)

    if k is not v:
        # Fully separate inputs: split the packed parameters in three.
        w_q, w_k, w_v = w.chunk(3)
        # ipex-llm changes start: add contiguous to workaround a ipex bug
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        w_q, w_k, w_v = w_q.contiguous(), w_k.contiguous(), w_v.contiguous()
        # ipex-llm changes end
        b_q, b_k, b_v = (None, None, None) if b is None else b.chunk(3)
        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

    if q is not k:
        # encoder-decoder attention: q has its own input, k and v share one.
        w_q, w_kv = w.split([embed_dim, embed_dim * 2])
        if b is not None:
            b_q, b_kv = b.split([embed_dim, embed_dim * 2])
        else:
            b_q = b_kv = None
        q_proj = linear(q, w_q, b_q)
        # Laid out as (2, E) rather than (E, 2) on purpose: better memory
        # coalescing and the same ordering chunk() would give.
        packed_kv = linear(k, w_kv, b_kv).unflatten(-1, (2, embed_dim))
        packed_kv = packed_kv.unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        return (q_proj, packed_kv[0], packed_kv[1])

    # self-attention: a single matmul over the packed weight.
    # Laid out as (3, E) rather than (E, 3) on purpose: better memory
    # coalescing and the same ordering chunk() would give.
    packed_qkv = linear(q, w, b).unflatten(-1, (3, embed_dim))
    packed_qkv = packed_qkv.unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
    return packed_qkv[0], packed_qkv[1], packed_qkv[2]
# for minicpm-v-2_6 benchmarking purposes
def minicpmv_decode_stream_wrapper(origin_decode_stream):
    """Wrap a decode-stream function so a caller-supplied ``streamer`` is honoured.

    When ``kwargs`` carries a non-``None`` ``streamer``, generation is started
    on a background thread via ``self.llm.generate`` and that streamer is
    returned. Otherwise the call is forwarded unchanged to
    ``origin_decode_stream``.
    """
    def minicpv_decode_stream(
        self,
        inputs_embeds,
        tokenizer,
        **kwargs
    ):
        streamer = kwargs.get('streamer', None)
        if streamer is None:
            # No external streamer: defer to the wrapped implementation.
            return origin_decode_stream(
                self=self,
                inputs_embeds=inputs_embeds,
                tokenizer=tokenizer,
                **kwargs
            )

        eos_ids = [tokenizer.convert_tokens_to_ids(token) for token in self.terminators]
        # Caller kwargs come last so they override the defaults below.
        generation_kwargs = {
            'inputs_embeds': inputs_embeds,
            'pad_token_id': 0,
            'eos_token_id': eos_ids,
            **kwargs,
        }
        worker = Thread(target=self.llm.generate, kwargs=generation_kwargs)
        worker.start()
        return streamer

    return minicpv_decode_stream
# MiniCPM-V-2
# modified from timm.models.vision_transformer.Attention.forward
def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Attention forward using the fused ``self.qkv`` projection.

    Args:
        x: 3-D input; its dimensions are read as (batch, sequence, hidden).

    Returns:
        Tensor reshaped back to ``x``'s three dimensions after the output
        projection and its dropout.
    """
    batch, seq_len, width = x.size()

    # Split the fused projection into per-head q, k, v along the head axis.
    fused = self.qkv(x).view(batch, seq_len, self.num_heads * 3, self.head_dim)
    queries, keys, values = fused.transpose(1, 2).chunk(3, dim=1)

    # Scale the queries before the dot product, then softmax and dropout.
    probs = torch.matmul(queries * self.scale, keys.transpose(2, 3))
    probs = self.attn_drop(attention_softmax(probs))

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    context = context.reshape(batch, seq_len, width)
    return self.proj_drop(self.proj(context))
# MiniCPM-V-2_5
def minicpmv_chat_wrapper(origin_chat):
    """Wrap a chat function so a processor is always supplied to it.

    If the caller passes no ``processor``, the one cached on ``self.processor``
    is used. If none is cached, it is loaded once with
    ``AutoProcessor.from_pretrained(self.config._name_or_path,
    trust_remote_code=True)`` and stored on ``self``. All other arguments are
    forwarded to ``origin_chat`` by keyword.
    """
    def minicpmv_chat(
        self,
        image,
        msgs,
        tokenizer,
        processor=None,
        vision_hidden_states=None,
        max_new_tokens=1024,
        sampling=True,
        max_inp_length=2048,
        system_prompt='',
        stream=False,
        **kwargs
    ):
        if processor is None:
            cached = getattr(self, "processor", None)
            if cached is None:
                # Load lazily and keep it on the model for later calls.
                self.processor = AutoProcessor.from_pretrained(
                    self.config._name_or_path, trust_remote_code=True
                )
            processor = self.processor

        forwarded = {
            'self': self,
            'image': image,
            'msgs': msgs,
            'tokenizer': tokenizer,
            'processor': processor,
            'vision_hidden_states': vision_hidden_states,
            'max_new_tokens': max_new_tokens,
            'sampling': sampling,
            'max_inp_length': max_inp_length,
            'system_prompt': system_prompt,
            'stream': stream,
        }
        return origin_chat(**forwarded, **kwargs)

    return minicpmv_chat
# MiniCPM-V-2
def minicpmv_get_vision_embedding(self, pixel_values):
    """Compute resampled vision embeddings for each entry of ``pixel_values``.

    Each entry is processed on its own: it is given a leading dimension and
    cast to ``self.dtype``, passed through ``self.vpm_forward_features``,
    stripped of ``self.vpm.num_prefix_tokens`` leading tokens (when that
    attribute exists and is positive), then passed to ``self.resampler`` along
    with the patch-grid size derived from its last two dimensions.

    Returns:
        ``torch.vstack`` of the per-entry resampler outputs.
    """
    dtype = self.dtype
    embeddings = []
    for image in pixel_values:
        config, vpm, resampler = self.config, self.vpm, self.resampler

        # Patch grid size, rounded up along both of the last two dimensions.
        height, width = image.shape[-2:]
        grid_size = (
            math.ceil(height / config.patch_size),
            math.ceil(width / config.patch_size),
        )

        features = self.vpm_forward_features(image.unsqueeze(0).type(dtype))
        # Drop the leading prefix tokens, if the vision model declares any.
        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
            features = features[:, vpm.num_prefix_tokens:]

        embeddings.append(resampler(features, grid_size))
    return torch.vstack(embeddings)
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
    """Apply the repetition penalty to ``scores`` in place and return it.

    On ``xpu`` tensors the work is delegated to
    ``xe_addons.repetition_penalty_logits_process_inplaced``. Elsewhere, the
    scores at the positions listed in ``input_ids`` (gathered along dim 1)
    are multiplied by ``self.penalty`` when negative and divided by it
    otherwise, then written back into ``scores``.
    """
    if scores.device.type != "xpu":
        selected = torch.gather(scores, 1, input_ids)
        penalised = torch.where(
            selected < 0, selected * self.penalty, selected / self.penalty
        )
        scores.scatter_(1, input_ids, penalised)
        return scores

    import xe_addons
    xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
    return scores
def minicpmv_generate_wrapper(origin_generate):
    """Wrap a ``generate`` function with two adjustments applied on every call.

    * ``RepetitionPenaltyLogitsProcessor.__call__`` is replaced, class-wide,
      by ``patched_repetition_penalty_call``.
    * If the ``stream`` keyword is a ``TextIteratorStreamer`` instance, it is
      also passed on as ``streamer`` (``stream`` itself is left in place).
    """
    def generate(*inputs, **kwargs):
        # Re-applied on each call; this mutates the transformers class itself.
        RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call

        # for minicpm-v-2_6 benchmarking purposes
        maybe_streamer = kwargs.get("stream", False)
        if isinstance(maybe_streamer, TextIteratorStreamer):
            kwargs['streamer'] = maybe_streamer

        return origin_generate(*inputs, **kwargs)

    return generate