282 lines
9.6 KiB
Python
282 lines
9.6 KiB
Python
#
|
|
# Copyright 2016 The BigDL Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# Some parts of this file are adapted from
|
|
# https://huggingface.co/openbmb/MiniCPM-V-2_6/blob/main/modeling_minicpmv.py
|
|
# which is licensed under Apache License 2.0:
|
|
#
|
|
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
|
|
#
|
|
|
|
|
|
import math
|
|
import torch
|
|
from threading import Thread
|
|
from typing import Optional, List
|
|
from torch.nn.functional import linear
|
|
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
|
|
from ipex_llm.transformers.models.common import attention_softmax
|
|
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
|
from transformers import AutoProcessor, TextIteratorStreamer
|
|
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
|
|
|
|
|
# MiniCPM-V-2_5 and MiniCPM-V-2_6
|
|
def merge_qkv(module: torch.nn.Module):
    """Apply ``merge_qkv_base`` to ``module`` for each supported attention class.

    The helper is invoked once per class name, in this order:
    ``SiglipAttention``, ``SiglipSdpaAttention``, ``Idefics2VisionAttention``.
    Presumably each call fuses the separate q/k/v projections of matching
    modules into one ``qkv_proj`` (verify against ``merge_qkv_base``).
    """
    supported_attention_classes = (
        "SiglipAttention",
        "SiglipSdpaAttention",
        "Idefics2VisionAttention",
    )
    for attention_class_name in supported_attention_classes:
        merge_qkv_base(module, attention_class_name)
|
|
|
|
|
|
# MiniCPM-V-2_5 and MiniCPM-V-2_6
|
|
def siglip_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
):
    """Replacement attention forward that reads q/k/v from ``self.qkv_proj``.

    Args:
        hidden_states: 3-D input; its first two sizes are used as batch size
            and sequence length.
        attention_mask: optional mask forwarded to the attention kernel.
        output_attentions: accepted for signature compatibility; not used.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_weights`` is always
        ``None`` because neither code path materialises the probabilities.
    """
    batch, seq_len, _ = hidden_states.size()

    # Single fused projection, then split the head axis into q / k / v.
    fused = self.qkv_proj(hidden_states)
    fused = fused.view(batch, seq_len, self.num_heads * 3, self.head_dim)
    fused = fused.transpose(1, 2)
    query_states, key_states, value_states = fused.chunk(3, dim=1)

    # Neither branch below produces attention weights.
    attn_weights = None

    from ipex_llm.transformers.utils import get_xpu_device_name
    use_fused_siglip_kernel = (
        self.head_dim == 72
        and get_xpu_device_name(query_states.device) == "arc"
        and query_states.dtype in [torch.float, torch.half]
    )

    if use_fused_siglip_kernel:
        from ipex_llm.transformers.models.common import prepare_mask
        import xe_addons

        head_count = query_states.size(1)
        kv_len = key_states.size(2)
        attention_mask = prepare_mask(
            attention_mask, batch, head_count, seq_len, kv_len,
            False, query_states.dtype, query_states.device,
        )
        attn_output = xe_addons.siglip_sdp_non_causal(
            query_states, key_states, value_states, attention_mask
        )
    else:
        # padding_qkv_hd is called with (72, 80); presumably it pads a head
        # dim of 72 up to 80 for the generic kernel -- confirm in the helper.
        query_states, key_states, value_states = padding_qkv_hd(
            query_states, key_states, value_states, 72, 80
        )
        softmax_scale = 1 / math.sqrt(self.head_dim)
        attn_output = scaled_dot_product_attention(
            query_states,
            key_states.contiguous(),
            value_states.contiguous(),
            attention_mask,
            False,
            softmax_scale,
        )
        # Keep only the first head_dim channels of the last axis.
        attn_output = attn_output[:, :, :, :self.head_dim]

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(batch, seq_len, self.embed_dim)
    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights
|
|
|
|
|
|
# MiniCPM-V-2_6
|
|
def _in_projection_packed(
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
v: torch.Tensor,
|
|
w: torch.Tensor,
|
|
b: Optional[torch.Tensor] = None,
|
|
) -> List[torch.Tensor]:
|
|
E = q.size(-1)
|
|
if k is v:
|
|
if q is k:
|
|
# self-attention
|
|
proj = linear(q, w, b)
|
|
# reshape to 3, E and not E, 3 is deliberate for
|
|
# better memory coalescing and keeping same order as chunk()
|
|
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2)
|
|
proj = proj.contiguous()
|
|
return proj[0], proj[1], proj[2]
|
|
else:
|
|
# encoder-decoder attention
|
|
w_q, w_kv = w.split([E, E * 2])
|
|
if b is None:
|
|
b_q = b_kv = None
|
|
else:
|
|
b_q, b_kv = b.split([E, E * 2])
|
|
q_proj = linear(q, w_q, b_q)
|
|
kv_proj = linear(k, w_kv, b_kv)
|
|
# reshape to 2, E and not E, 2 is deliberate for
|
|
# better memory coalescing and keeping same order as chunk()
|
|
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2)
|
|
kv_proj = kv_proj.contiguous()
|
|
return (q_proj, kv_proj[0], kv_proj[1])
|
|
else:
|
|
w_q, w_k, w_v = w.chunk(3)
|
|
# ipex-llm changes start: add contiguous to workaround a ipex bug
|
|
q = q.contiguous()
|
|
k = k.contiguous()
|
|
v = v.contiguous()
|
|
w_q = w_q.contiguous()
|
|
w_k = w_k.contiguous()
|
|
w_v = w_v.contiguous()
|
|
# ipex-llm changes end
|
|
if b is None:
|
|
b_q = b_k = b_v = None
|
|
else:
|
|
b_q, b_k, b_v = b.chunk(3)
|
|
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
|
|
|
|
|
|
# for minicpm-v-2_6 benchmarking purposes
|
|
def minicpmv_decode_stream_wrapper(origin_decode_stream):
    """Wrap a decode-stream function so a caller-supplied streamer is honoured.

    When ``streamer`` is present (and not ``None``) in the keyword arguments,
    ``self.llm.generate`` is launched on a new thread with that streamer and
    the streamer is returned straight away. Otherwise the call is forwarded
    unchanged to ``origin_decode_stream``.
    """
    def minicpv_decode_stream(self, inputs_embeds, tokenizer, **kwargs):
        streamer = kwargs.get('streamer', None)

        # No external streamer: defer entirely to the wrapped implementation.
        if streamer is None:
            return origin_decode_stream(
                self=self,
                inputs_embeds=inputs_embeds,
                tokenizer=tokenizer,
                **kwargs
            )

        terminators = [
            tokenizer.convert_tokens_to_ids(token) for token in self.terminators
        ]
        # Caller kwargs are merged last so they override these defaults.
        generation_kwargs = {
            'inputs_embeds': inputs_embeds,
            'pad_token_id': 0,
            'eos_token_id': terminators,
            **kwargs,
        }

        # Generation runs in the background; the caller consumes the streamer.
        worker = Thread(target=self.llm.generate, kwargs=generation_kwargs)
        worker.start()
        return streamer

    return minicpv_decode_stream
|
|
|
|
|
|
# MiniCPM-V-2
|
|
# modified from timm.models.vision_transformer.Attention.forward
|
|
def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Attention forward for a module exposing ``qkv``, ``proj`` and dropouts.

    Computes softmax attention explicitly, scaling the queries by
    ``self.scale`` and using ``attention_softmax`` for the normalisation.

    Args:
        x: 3-D input; its sizes are read as (batch, tokens, hidden).

    Returns:
        A tensor reshaped to the same three sizes as ``x`` after the output
        projection and its dropout.
    """
    batch, tokens, hidden = x.size()

    # Fused qkv projection -> (batch, 3 * heads, tokens, head_dim) -> q, k, v.
    fused = self.qkv(x).view(batch, tokens, self.num_heads * 3, self.head_dim)
    query_states, key_states, value_states = fused.transpose(1, 2).chunk(3, dim=1)

    scaled_query = query_states * self.scale
    attn_weights = torch.matmul(scaled_query, key_states.transpose(2, 3))
    attn_weights = attention_softmax(attn_weights)
    attn_weights = self.attn_drop(attn_weights)

    context = torch.matmul(attn_weights, value_states)
    # Move heads back next to head_dim and merge them into the hidden axis.
    context = context.transpose(1, 2).contiguous().reshape(batch, tokens, hidden)

    projected = self.proj(context)
    return self.proj_drop(projected)
|
|
|
|
|
|
# MiniCPM-V-2_5
|
|
def minicpmv_chat_wrapper(origin_chat):
    """Wrap a ``chat`` function so that a processor is always supplied.

    If the caller passes no ``processor``, the wrapper uses ``self.processor``,
    creating it once with ``AutoProcessor.from_pretrained`` (from
    ``self.config._name_or_path``, ``trust_remote_code=True``) when it is
    missing or ``None``. All arguments are then forwarded to ``origin_chat``
    as keywords.
    """
    def minicpmv_chat(
        self,
        image,
        msgs,
        tokenizer,
        processor=None,
        vision_hidden_states=None,
        max_new_tokens=1024,
        sampling=True,
        max_inp_length=2048,
        system_prompt='',
        stream=False,
        **kwargs
    ):
        if processor is None:
            # Build the processor lazily and cache it on the model instance.
            if getattr(self, "processor", None) is None:
                self.processor = AutoProcessor.from_pretrained(
                    self.config._name_or_path,
                    trust_remote_code=True,
                )
            processor = self.processor

        forwarded = {
            "self": self,
            "image": image,
            "msgs": msgs,
            "tokenizer": tokenizer,
            "processor": processor,
            "vision_hidden_states": vision_hidden_states,
            "max_new_tokens": max_new_tokens,
            "sampling": sampling,
            "max_inp_length": max_inp_length,
            "system_prompt": system_prompt,
            "stream": stream,
        }
        # The extra kwargs cannot collide with the named parameters above.
        forwarded.update(kwargs)
        return origin_chat(**forwarded)

    return minicpmv_chat
|
|
|
|
|
|
# MiniCPM-V-2
|
|
def minicpmv_get_vision_embedding(self, pixel_values):
    """Encode each image in ``pixel_values`` and stack the resampled results.

    For every image the method:
      1. derives a target grid from its last two dimensions and
         ``self.config.patch_size`` (ceil division);
      2. runs ``self.vpm_forward_features`` on the image with a leading
         dimension added and cast to ``self.dtype``;
      3. drops the first ``self.vpm.num_prefix_tokens`` entries along dim 1
         when that attribute exists and is positive;
      4. feeds the result and the target grid to ``self.resampler``.

    Returns:
        ``torch.vstack`` of the per-image resampler outputs.
    """
    target_dtype = self.dtype
    embeddings = []

    for image in pixel_values:
        config, vpm, resampler = self.config, self.vpm, self.resampler

        height, width = image.shape[-2:]
        grid_size = (
            math.ceil(height / config.patch_size),
            math.ceil(width / config.patch_size),
        )

        batched_image = image.unsqueeze(0).type(target_dtype)
        features = self.vpm_forward_features(batched_image)

        # Strip prefix tokens, if the vision model declares any.
        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
            features = features[:, vpm.num_prefix_tokens:]

        embeddings.append(resampler(features, grid_size))

    return torch.vstack(embeddings)
|
|
|
|
|
|
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
    """Apply the repetition penalty to ``scores`` in place and return it.

    On non-XPU devices the scores at the ``input_ids`` positions are
    multiplied by ``self.penalty`` when negative and divided by it otherwise,
    then written back with ``scatter_``. On XPU the same arguments are handed
    to ``xe_addons.repetition_penalty_logits_process_inplaced``.

    Returns:
        The ``scores`` tensor that was passed in.
    """
    if scores.device.type != "xpu":
        seen_scores = torch.gather(scores, 1, input_ids)
        penalised = torch.where(
            seen_scores < 0,
            seen_scores * self.penalty,
            seen_scores / self.penalty,
        )
        scores.scatter_(1, input_ids, penalised)
        return scores

    import xe_addons
    xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
    return scores
|
|
|
|
|
|
def minicpmv_generate_wrapper(origin_generate):
    """Wrap ``generate`` to install the patched repetition penalty first.

    Each call rebinds ``RepetitionPenaltyLogitsProcessor.__call__`` to
    ``patched_repetition_penalty_call``. In addition, when the ``stream``
    keyword holds a ``TextIteratorStreamer`` instance, it is also passed on
    as ``streamer`` (used for minicpm-v-2_6 benchmarking).
    """
    def generate(*inputs, **kwargs):
        RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call

        # for minicpm-v-2_6 benchmarking purposes
        stream_arg = kwargs.get("stream", False)
        if isinstance(stream_arg, TextIteratorStreamer):
            kwargs['streamer'] = stream_arg

        return origin_generate(*inputs, **kwargs)

    return generate
|