ipex-llm/python/llm/src/ipex_llm/transformers/models/minicpmv.py
2024-08-13 16:14:00 +08:00

80 lines
2.6 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from typing import Optional
from ipex_llm.transformers.models.common import merge_qkv_base
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
def merge_qkv(module: torch.nn.Module):
return merge_qkv_base(module, "SiglipAttention")
def siglip_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
qkv = qkv.transpose(1, 2)
query_states, key_states, value_states = qkv.chunk(3, dim=1)
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
import xe_addons
xe_addons.attn_softmax_inplaced(attn_weights)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
if scores.device.type == "xpu":
import xe_addons
xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
else:
score = torch.gather(scores, 1, input_ids)
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores
def minicpmv_generate_wrapper(origin_generate):
def generate(
*inputs,
**kwargs
):
RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
return origin_generate(
*inputs,
**kwargs,
)
return generate