add support and optimization for minicpmo audio part (#12716)
This commit is contained in:
parent
53aae24616
commit
bda87c21eb
2 changed files with 118 additions and 9 deletions
|
|
@ -1030,9 +1030,9 @@ def _optimize_pre(model, qtype=None):
|
|||
model.llm.config.model_type = "minicpmv"
|
||||
elif model.config.model_type == "minicpmo":
|
||||
# vpm opt
|
||||
from ipex_llm.transformers.models.minicpmv import merge_qkv
|
||||
model.vpm.apply(merge_qkv)
|
||||
|
||||
if hasattr(model, "vpm"):
|
||||
from ipex_llm.transformers.models.minicpmv import merge_qkv
|
||||
model.vpm.apply(merge_qkv)
|
||||
# llm opt
|
||||
model.llm.config.model_type = "qwen2"
|
||||
_optimize_pre(model.llm, qtype=qtype)
|
||||
|
|
@ -1955,12 +1955,18 @@ def _optimize_post(model):
|
|||
model.chat = MethodType(minicpmv_chat, model)
|
||||
elif model.config.model_type == "minicpmo":
|
||||
# vpm opt
|
||||
vpm_modeling_module_name = model.vpm.__class__.__module__
|
||||
vpm_module = importlib.import_module(vpm_modeling_module_name)
|
||||
|
||||
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
|
||||
convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
|
||||
|
||||
if hasattr(model, "vpm"):
|
||||
vpm_modeling_module_name = model.vpm.__class__.__module__
|
||||
vpm_module = importlib.import_module(vpm_modeling_module_name)
|
||||
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
|
||||
convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
|
||||
# apm opt
|
||||
if hasattr(model, "apm"):
|
||||
apm_modeling_module_name = model.apm.__class__.__module__
|
||||
apm_module = importlib.import_module(apm_modeling_module_name)
|
||||
from transformers.models.whisper.modeling_whisper import WhisperSdpaAttention
|
||||
from ipex_llm.transformers.models.whisper import whisper_attention_forward
|
||||
convert_forward(model.apm, WhisperSdpaAttention, whisper_attention_forward)
|
||||
# llm opt
|
||||
model.llm.config.model_type = "qwen2"
|
||||
_optimize_post(model.llm)
|
||||
|
|
|
|||
103
python/llm/src/ipex_llm/transformers/models/whisper.py
Normal file
103
python/llm/src/ipex_llm/transformers/models/whisper.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
|
||||
# which is licensed under Apache License 2.0:
|
||||
#
|
||||
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from transformers.cache_utils import EncoderDecoderCache
|
||||
|
||||
from ipex_llm.transformers.utils import invalidInputError
|
||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||
|
||||
|
||||
def whisper_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
invalidInputError(not output_attentions and layer_head_mask is None,
|
||||
"`output_attentions` and `layer_head_mask` are not supported")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
||||
|
||||
if past_key_value is not None:
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
||||
if past_key_value is not None:
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
# IPEX-LLM OPT: sdpa
|
||||
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
||||
attn_output = scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states.contiguous(),
|
||||
value_states.contiguous(),
|
||||
attention_mask,
|
||||
is_causal
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
Loading…
Reference in a new issue