add support and optimization for minicpmo audio part (#12716)
parent 53aae24616
commit bda87c21eb

2 changed files with 118 additions and 9 deletions
@@ -1030,9 +1030,9 @@ def _optimize_pre(model, qtype=None):
         model.llm.config.model_type = "minicpmv"
     elif model.config.model_type == "minicpmo":
         # vpm opt
-        from ipex_llm.transformers.models.minicpmv import merge_qkv
-        model.vpm.apply(merge_qkv)
-
+        if hasattr(model, "vpm"):
+            from ipex_llm.transformers.models.minicpmv import merge_qkv
+            model.vpm.apply(merge_qkv)
         # llm opt
         model.llm.config.model_type = "qwen2"
         _optimize_pre(model.llm, qtype=qtype)
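
For background on the merge_qkv pass that the hunk above now guards with hasattr(model, "vpm"): the idea is to fuse an attention block's separate q/k/v projections into a single linear so one matmul serves all three. The sketch below is illustrative only and is not ipex_llm's actual merge_qkv; the merge_qkv_generic helper and the q_proj/k_proj/v_proj attribute names are assumptions about a typical HF-style attention module, and it assumes q, k and v share the same bias setting.

import torch
from torch import nn


def merge_qkv_generic(module: nn.Module) -> None:
    # Illustrative only: fuse separate q/k/v projections into one qkv_proj so a
    # single matmul produces queries, keys and values. A matching forward (e.g.
    # the siglip_attention_forward swap in the next hunk) must split the fused output.
    if not all(hasattr(module, n) for n in ("q_proj", "k_proj", "v_proj")):
        return  # not an attention block; .apply() visits every submodule
    q, k, v = module.q_proj, module.k_proj, module.v_proj
    qkv = nn.Linear(q.in_features,
                    q.out_features + k.out_features + v.out_features,
                    bias=q.bias is not None)
    qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
    if q.bias is not None:
        qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
    module.qkv_proj = qkv
    del module.q_proj, module.k_proj, module.v_proj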
				
			
@@ -1955,12 +1955,18 @@ def _optimize_post(model):
             model.chat = MethodType(minicpmv_chat, model)
     elif model.config.model_type == "minicpmo":
         # vpm opt
-        vpm_modeling_module_name = model.vpm.__class__.__module__
-        vpm_module = importlib.import_module(vpm_modeling_module_name)
-
-        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
-        convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
-
+        if hasattr(model, "vpm"):
+            vpm_modeling_module_name = model.vpm.__class__.__module__
+            vpm_module = importlib.import_module(vpm_modeling_module_name)
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
+        # apm opt
+        if hasattr(model, "apm"):
+            apm_modeling_module_name = model.apm.__class__.__module__
+            apm_module = importlib.import_module(apm_modeling_module_name)
+            from transformers.models.whisper.modeling_whisper import WhisperSdpaAttention
+            from ipex_llm.transformers.models.whisper import whisper_attention_forward
+            convert_forward(model.apm, WhisperSdpaAttention, whisper_attention_forward)
         # llm opt
         model.llm.config.model_type = "qwen2"
         _optimize_post(model.llm)
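
Both the vpm and apm branches above route through convert_forward to install the optimized forwards. A minimal sketch of that rebinding pattern follows, assuming only the call shape used in the diff (convert_forward(model, TargetClass, new_forward)); replace_forward is a hypothetical stand-in, not ipex_llm's real implementation.

from types import MethodType

import torch


def replace_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Hypothetical stand-in for ipex_llm's convert_forward: rebind `forward`
    # on every instance of target_cls found in the module tree.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = MethodType(new_forward, module)

With the whisper.py file added below, the apm branch points WhisperSdpaAttention modules at whisper_attention_forward, which routes the audio encoder's attention through ipex_llm's scaled_dot_product_attention helper.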
python/llm/src/ipex_llm/transformers/models/whisper.py	103	Normal file
@@ -0,0 +1,103 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch

from typing import Optional, Tuple
from transformers.cache_utils import EncoderDecoderCache

from ipex_llm.transformers.utils import invalidInputError
from ipex_llm.transformers.models.common import scaled_dot_product_attention


def whisper_attention_forward(
    self,
    hidden_states: torch.Tensor,
    key_value_states: Optional[torch.Tensor] = None,
    past_key_value: Optional[EncoderDecoderCache] = None,
    attention_mask: Optional[torch.Tensor] = None,
    layer_head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    invalidInputError(not output_attentions and layer_head_mask is None,
                      "`output_attentions` and `layer_head_mask` are not supported")

    # if key_value_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = key_value_states is not None
    bsz, tgt_len, _ = hidden_states.size()

    # get query proj
    query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)

    if past_key_value is not None:
        is_updated = past_key_value.is_updated.get(self.layer_idx)
        if is_cross_attention:
            past_key_value.is_updated[self.layer_idx] = True
            past_key_value = past_key_value.cross_attention_cache
        else:
            past_key_value = past_key_value.self_attention_cache

    # use key_value_states if cross attention
    current_states = key_value_states if key_value_states is not None else hidden_states
    if is_cross_attention and past_key_value and is_updated:
        # reuse k,v, cross_attentions
        key_states = past_key_value.key_cache[self.layer_idx]
        value_states = past_key_value.value_cache[self.layer_idx]
    else:
        key_states = self._shape(self.k_proj(current_states), -1, bsz)
        value_states = self._shape(self.v_proj(current_states), -1, bsz)
        if past_key_value is not None:
            cache_position = cache_position if not is_cross_attention else None
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )

    # IPEX-LLM OPT: sdpa
    is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
    attn_output = scaled_dot_product_attention(
        query_states,
        key_states.contiguous(),
        value_states.contiguous(),
        attention_mask,
        is_causal
    )

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, None, past_key_value
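
For completeness, a minimal sketch of how these conversions are typically triggered when loading MiniCPM-o through ipex_llm; the checkpoint id, low-bit format, and XPU placement below are placeholder assumptions rather than anything shipped in this commit.

from ipex_llm.transformers import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-o-2_6",       # placeholder checkpoint id
    trust_remote_code=True,
    load_in_low_bit="sym_int4",    # placeholder low-bit format
    optimize_model=True,           # runs _optimize_pre/_optimize_post, including the apm/Whisper path above
)
model = model.half().to("xpu")     # assumes an Intel GPU (XPU) device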