add audio optimization for qwen2.5-omni (#13037)
parent 7548c12b2c
commit ef852dcb4a
2 changed files with 182 additions and 4 deletions
@@ -2072,12 +2072,31 @@ def _optimize_post(model):
         convert_forward(model.thinker.visual, module.Qwen2_5OmniVisionSdpaAttention,
                         qwen2_5_omni_vision_attention_forward)
 
+        # audio opt
+        from ipex_llm.transformers.models.qwen2_5_omni import qwen2_5_omni_audio_attention_forward
+        convert_forward(model.thinker.audio_tower, module.Qwen2_5OmniAudioAttention,
+                        qwen2_5_omni_audio_attention_forward)
+        convert_forward(model.thinker.audio_tower, module.Qwen2_5OmniAudioSdpaAttention,
+                        qwen2_5_omni_audio_attention_forward)
+
         # tts opt
-        if hasattr(model, "talker"):
-            convert_forward(model.talker, module.Qwen2_5OmniAttention,
+        if model.has_talker:
+            # talker part
+            convert_forward(model.talker.model, module.Qwen2_5OmniAttention,
                             qwen2_5_omni_attention_forward)
-            convert_forward(model.talker, module.Qwen2_5OmniThinkerModel,
+            convert_forward(model.talker.model, module.Qwen2_5OmniSdpaAttention,
+                            qwen2_5_omni_attention_forward)
+            convert_forward(model.talker.model, module.Qwen2_5OmniTalkerModel,
                             qwen2_5_omni_thinker_model_forward)
+            convert_forward(model.talker.model, module.Qwen2MLP, qwen2_mlp_forward)
+
+            # token2wav part
+            from ipex_llm.transformers.models.qwen2_5_omni import dit_attention_forward
+            from ipex_llm.transformers.models.qwen2_5_omni import _create_block_diff
+            convert_forward(model.token2wav, module.DiTAttention, dit_attention_forward)
+            dit_model = model.token2wav.code2wav_dit_model
+            dit_model._create_block_diff = MethodType(_create_block_diff, dit_model)
 
     return model
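The hunk above wires the optimized forwards into the model: the thinker's audio tower, the talker decoder and the token2wav DiT each get their attention (and MLP) forwards swapped, and _create_block_diff is rebound onto code2wav_dit_model via MethodType. As a rough sketch of the mechanism only (the real convert_forward helper lives in ipex-llm; the name and body below are illustrative), such a replacement can be applied by walking the module tree and rebinding forward on every instance of the target class:

# Illustrative sketch of a convert_forward-style helper, not the ipex-llm implementation.
# It rebinds `new_forward` onto every submodule whose class matches `target_cls`,
# using the same MethodType trick the hunk applies to `_create_block_diff`.
from types import MethodType

import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    for module in model.modules():
        if module.__class__ == target_cls:
            module.forward = MethodType(new_forward, module)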
@@ -20,9 +20,11 @@
 import math
 import torch
 from typing import Optional, Tuple, List, Union
-from transformers.cache_utils import Cache
+from transformers.cache_utils import Cache, EncoderDecoderCache
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniAttention
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb_vision
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_multimodal_rotary_pos_emb
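The widened imports match the new code below: EncoderDecoderCache is the cache type accepted by the audio attention signature, and ALL_ATTENTION_FUNCTIONS is the fallback path in dit_attention_forward. A minimal sketch of the cache structure those code paths expect, assuming a recent transformers release (generation normally builds this object itself):

# Minimal sketch, for illustration only; generate() constructs this cache internally.
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# One cache for self-attention, one for cross-attention.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# The patched audio attention reads these fields:
#   past_key_values.self_attention_cache    -> updated every step
#   past_key_values.cross_attention_cache   -> filled once, then reused
#   past_key_values.is_updated[layer_idx]   -> marks cross-attention K/V as cached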
@@ -284,3 +286,160 @@ def qwen2_5_omni_vision_attention_forward(
     attn_output = attn_output.reshape(seq_length, -1)
     attn_output = self.proj(attn_output)
     return attn_output
+
+
+def qwen2_5_omni_audio_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_value_states: Optional[torch.Tensor] = None,
+    past_key_value: Optional[EncoderDecoderCache] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    layer_head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel"""
+
+    # if key_value_states are provided this layer is used as a cross-attention layer
+    # for the decoder
+    is_cross_attention = key_value_states is not None
+    seq_length, _ = hidden_states.size()
+
+    # get query proj
+    query_states = self.q_proj(hidden_states)
+    query_states = query_states.reshape(seq_length, self.num_heads, -1)
+
+    seq_lens = cu_seqlens.tolist()
+    invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length,
+                      "unexpected input")
+
+    if past_key_value is not None:
+        is_updated = past_key_value.is_updated.get(self.layer_idx)
+        if is_cross_attention:
+            # after the first generated id,
+            # we can subsequently re-use all key/value_states from cache
+            past_key_value.is_updated[self.layer_idx] = True
+            past_key_value = past_key_value.cross_attention_cache
+        else:
+            past_key_value = past_key_value.self_attention_cache
+
+    # use key_value_states if cross attention
+    current_states = key_value_states if key_value_states is not None else hidden_states
+    if is_cross_attention and past_key_value and is_updated:
+        # reuse k,v, cross_attentions
+        key_states = past_key_value.key_cache[self.layer_idx]
+        value_states = past_key_value.value_cache[self.layer_idx]
+    else:
+        key_states = self.k_proj(current_states).reshape(seq_length, self.num_heads, -1)
+        value_states = self.v_proj(current_states).reshape(seq_length, self.num_heads, -1)
+        if past_key_value is not None:
+            # save all key/value_states to cache to be re-used for fast auto-regressive generation
+            cache_position = cache_position if not is_cross_attention else None
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+            )
+
+    if layer_head_mask is None and use_sdp_non_causal(query_states.size(-1),
+                                                      query_states.device, query_states.dtype):
+        kv_length = key_states.size(0)
+        padding_kv_length = (kv_length + 128 - 1) // 128 * 128
+        attention_mask = torch.full(
+            [1, 1, seq_length, padding_kv_length], torch.finfo(query_states.dtype).min,
+            device=query_states.device, dtype=query_states.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0
+
+        q = query_states.transpose(0, 1).unsqueeze(0)
+        k = key_states.transpose(0, 1).unsqueeze(0).contiguous()
+        v = value_states.transpose(0, 1).unsqueeze(0).contiguous()
+        # q, k, v: [1, num_heads, seq_length, head_dim]
+
+        attn_weights = None
+        attn_output = scaled_dot_product_attention(q, k, v, attention_mask, False)
+        attn_output = attn_output.permute(0, 2, 1, 3).squeeze(0)
+        # attn_output: [seq_length, num_heads, head_dim]
+    else:
+        attention_mask = torch.full(
+            [1, seq_length, key_states.size(0)], torch.finfo(query_states.dtype).min,
+            device=query_states.device, dtype=query_states.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0
+
+        query_states = query_states.transpose(0, 1)
+        key_states = key_states.transpose(0, 1)
+        value_states = value_states.transpose(0, 1)
+
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = attention_softmax(attn_weights)
+
+        if layer_head_mask is not None:
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
+
+        attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1)
+
+    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state`
+    # because `attn_output` can be partitioned across GPUs when using tensor-parallelism.
+    attn_output = attn_output.reshape(seq_length, self.embed_dim)
+    attn_output = self.out_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
+def dit_attention_forward(
+    self,
+    x,
+    rope=None,
+    mask=None,
+) -> torch.Tensor:
+    batch_size = x.shape[0]
+
+    # `sample` projections.
+    query = self.to_q(x)
+    key = self.to_k(x)
+    value = self.to_v(x)
+
+    # attention
+    inner_dim = key.shape[-1]
+    head_dim = inner_dim // self.heads
+    query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+    key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+    value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+    # apply rotary position embedding
+    # Due to training process, only first head is applied with RoPE, will be fixed at next release
+    cos, sin = rope
+    query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)
+
+    if use_sdp_non_causal(head_dim, query.device, query.dtype):
+        mask = torch.where(mask, 0, torch.finfo(query.dtype).min)
+        x = scaled_dot_product_attention(query, key.contiguous(), value.contiguous(), mask, False)
+        x = x.transpose(1, 2)
+    else:
+        attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation]
+        x, _ = attention_interface(self, query, key, value, attention_mask=mask, is_causal=False)
+
+    # mask
+    x = x.reshape(batch_size, -1, self.heads * head_dim)
+    x = x.to(query.dtype)
+
+    # linear proj
+    x = self.to_out[0](x)
+    # dropout
+    x = self.to_out[1](x)
+
+    return x
+
+
+def _create_block_diff(self, x):
+    batch, seq_len = x.shape[0], x.shape[1]
+    block_indices = torch.arange(seq_len, device=x.device) // self.block_size
+
+    block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
+    block_j = block_indices.unsqueeze(0)  # [1, seq_length]
+
+    block_diff = block_j - block_i  # (n, n)
+    return block_diff.unsqueeze(0).unsqueeze(0)
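Both branches of qwen2_5_omni_audio_attention_forward rebuild the audio tower's segment-wise attention pattern from cu_seqlens: the mask starts fully masked and the block-diagonal regions are zeroed, so each token only attends within its own audio segment, and the SDP branch additionally pads the key length up to a multiple of 128 for the fused kernel. A small standalone check of that masking pattern, with made-up segment boundaries:

# Standalone illustration of the cu_seqlens -> block-diagonal mask built above.
# The boundaries are made up; real values come from the audio feature lengths.
import torch

cu_seqlens = torch.tensor([0, 3, 7, 10])   # three segments of lengths 3, 4 and 3
seq_length = int(cu_seqlens[-1])
seq_lens = cu_seqlens.tolist()

attention_mask = torch.full([1, seq_length, seq_length], torch.finfo(torch.float32).min)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0

# Zeros mark the allowed positions: a block-diagonal layout, one block per segment.
print((attention_mask[0] == 0).int())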
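_create_block_diff, which the first hunk rebinds onto code2wav_dit_model, only computes how many blocks apart each query/key pair sits; downstream code can use that offset matrix to build block-wise attention masks. A toy run of the same arithmetic (block_size=4 is an arbitrary value for the demo; the DiT model stores its own block_size attribute):

# Toy illustration of the block-index difference computed by _create_block_diff.
import torch

block_size = 4          # arbitrary demo value
seq_len = 10
block_indices = torch.arange(seq_len) // block_size   # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]

block_i = block_indices.unsqueeze(1)   # block index of each query position
block_j = block_indices.unsqueeze(0)   # block index of each key position
block_diff = block_j - block_i         # (seq_len, seq_len) matrix of block offsets

# Entry [i, j] is 0 when i and j share a block, positive when j lies in a later
# block, and negative when it lies in an earlier one.
print(block_diff)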