add audio optimization for qwen2.5-omni (#13037)

Yishuo Wang 2025-04-07 17:20:26 +08:00 committed by GitHub
parent 7548c12b2c
commit ef852dcb4a
2 changed files with 182 additions and 4 deletions
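
For context on where this change takes effect: ipex-llm runs `_optimize_post` when a model is optimized, so a usage sketch along the following lines (hypothetical and not part of this commit; the exact model class depends on the installed transformers version, and `Qwen/Qwen2.5-Omni-7B` is only an example checkpoint) would pick up the patched forwards below:

# Hypothetical usage sketch; not part of this commit.
import torch
from transformers import Qwen2_5OmniForConditionalGeneration  # assumption: a transformers release with Qwen2.5-Omni support
from ipex_llm import optimize_model

model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float16
)
model = optimize_model(model)  # runs _optimize_post, installing the patched forwards below
model = model.to("xpu")        # assumption: an Intel GPU (XPU) device is available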


@@ -2072,12 +2072,31 @@ def _optimize_post(model):
         convert_forward(model.thinker.visual, module.Qwen2_5OmniVisionSdpaAttention,
                         qwen2_5_omni_vision_attention_forward)
 
+        # audio opt
+        from ipex_llm.transformers.models.qwen2_5_omni import qwen2_5_omni_audio_attention_forward
+        convert_forward(model.thinker.audio_tower, module.Qwen2_5OmniAudioAttention,
+                        qwen2_5_omni_audio_attention_forward)
+        convert_forward(model.thinker.audio_tower, module.Qwen2_5OmniAudioSdpaAttention,
+                        qwen2_5_omni_audio_attention_forward)
+
         # tts opt
-        if hasattr(model, "talker"):
-            convert_forward(model.talker, module.Qwen2_5OmniAttention,
+        if model.has_talker:
+            # talker part
+            convert_forward(model.talker.model, module.Qwen2_5OmniAttention,
                             qwen2_5_omni_attention_forward)
-            convert_forward(model.talker, module.Qwen2_5OmniThinkerModel,
+            convert_forward(model.talker.model, module.Qwen2_5OmniSdpaAttention,
+                            qwen2_5_omni_attention_forward)
+            convert_forward(model.talker.model, module.Qwen2_5OmniTalkerModel,
                             qwen2_5_omni_thinker_model_forward)
+            convert_forward(model.talker.model, module.Qwen2MLP, qwen2_mlp_forward)
+
+            # token2wav part
+            from ipex_llm.transformers.models.qwen2_5_omni import dit_attention_forward
+            from ipex_llm.transformers.models.qwen2_5_omni import _create_block_diff
+            convert_forward(model.token2wav, module.DiTAttention, dit_attention_forward)
+            dit_model = model.token2wav.code2wav_dit_model
+            dit_model._create_block_diff = MethodType(_create_block_diff, dit_model)
 
     return model
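
As background (illustrative only, not part of the commit): `convert_forward` rebinds the `forward` method on every submodule of a given class, which is how the optimized attention and MLP implementations registered above take effect. A minimal sketch of that pattern, with simplified names:

from types import MethodType
import torch.nn as nn

def convert_forward(model: nn.Module, target_class: type, new_forward) -> None:
    # Simplified sketch of ipex-llm's patching helper: walk all submodules and
    # rebind `forward` on instances of `target_class` to the optimized function.
    for module in model.modules():
        if module.__class__ == target_class:
            module.forward = MethodType(new_forward, module)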


@@ -20,9 +20,11 @@
 import math
 import torch
 from typing import Optional, Tuple, List, Union
-from transformers.cache_utils import Cache
+from transformers.cache_utils import Cache, EncoderDecoderCache
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniAttention
+from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb_vision
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_multimodal_rotary_pos_emb
@@ -284,3 +286,160 @@ def qwen2_5_omni_vision_attention_forward(
     attn_output = attn_output.reshape(seq_length, -1)
     attn_output = self.proj(attn_output)
     return attn_output
+
+
+def qwen2_5_omni_audio_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_value_states: Optional[torch.Tensor] = None,
+    past_key_value: Optional[EncoderDecoderCache] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    layer_head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel"""
+
+    # if key_value_states are provided this layer is used as a cross-attention layer
+    # for the decoder
+    is_cross_attention = key_value_states is not None
+
+    seq_length, _ = hidden_states.size()
+
+    # get query proj
+    query_states = self.q_proj(hidden_states)
+    query_states = query_states.reshape(seq_length, self.num_heads, -1)
+
+    seq_lens = cu_seqlens.tolist()
+    invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length,
+                      "unexpected input")
+
+    if past_key_value is not None:
+        is_updated = past_key_value.is_updated.get(self.layer_idx)
+        if is_cross_attention:
+            # after the first generated id,
+            # we can subsequently re-use all key/value_states from cache
+            past_key_value.is_updated[self.layer_idx] = True
+            past_key_value = past_key_value.cross_attention_cache
+        else:
+            past_key_value = past_key_value.self_attention_cache
+
+    # use key_value_states if cross attention
+    current_states = key_value_states if key_value_states is not None else hidden_states
+    if is_cross_attention and past_key_value and is_updated:
+        # reuse k,v, cross_attentions
+        key_states = past_key_value.key_cache[self.layer_idx]
+        value_states = past_key_value.value_cache[self.layer_idx]
+    else:
+        key_states = self.k_proj(current_states).reshape(seq_length, self.num_heads, -1)
+        value_states = self.v_proj(current_states).reshape(seq_length, self.num_heads, -1)
+        if past_key_value is not None:
+            # save all key/value_states to cache to be re-used for fast auto-regressive generation
+            cache_position = cache_position if not is_cross_attention else None
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+            )
+
+    if layer_head_mask is None and use_sdp_non_causal(query_states.size(-1),
+                                                      query_states.device, query_states.dtype):
+        kv_length = key_states.size(0)
+        padding_kv_length = (kv_length + 128 - 1) // 128 * 128
+        attention_mask = torch.full(
+            [1, 1, seq_length, padding_kv_length], torch.finfo(query_states.dtype).min,
+            device=query_states.device, dtype=query_states.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0
+
+        q = query_states.transpose(0, 1).unsqueeze(0)
+        k = key_states.transpose(0, 1).unsqueeze(0).contiguous()
+        v = value_states.transpose(0, 1).unsqueeze(0).contiguous()
+        # q, k, v: [1, num_heads, seq_length, head_dim]
+
+        attn_weights = None
+        attn_output = scaled_dot_product_attention(q, k, v, attention_mask, False)
+        attn_output = attn_output.permute(0, 2, 1, 3).squeeze(0)
+        # attn_output: [seq_length, num_heads, head_dim]
+    else:
+        attention_mask = torch.full(
+            [1, seq_length, key_states.size(0)], torch.finfo(query_states.dtype).min,
+            device=query_states.device, dtype=query_states.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0
+
+        query_states = query_states.transpose(0, 1)
+        key_states = key_states.transpose(0, 1)
+        value_states = value_states.transpose(0, 1)
+
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = attention_softmax(attn_weights)
+
+        if layer_head_mask is not None:
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
+
+        attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1)
+
+    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state`s
+    # because `attn_output` can be partitioned across GPUs when using tensor-parallelism.
+    attn_output = attn_output.reshape(seq_length, self.embed_dim)
+    attn_output = self.out_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
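
Aside (illustrative, not part of the commit): the audio encoder packs several variable-length segments into one flat sequence, and `cu_seqlens` holds the cumulative segment boundaries; both branches above expand it into a block-diagonal additive mask so tokens attend only within their own segment. In isolation:

import torch

def block_diagonal_bias(cu_seqlens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # cu_seqlens = [0, len_0, len_0 + len_1, ...]; positions outside a token's own
    # segment receive a large negative bias, positions inside receive 0.
    seq_len = int(cu_seqlens[-1])
    bias = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        bias[start:end, start:end] = 0
    return bias

# e.g. block_diagonal_bias(torch.tensor([0, 3, 7])) is a 7x7 bias with zero 3x3 and
# 4x4 blocks on the diagonal and the dtype's minimum value everywhere else.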
+
+
+def dit_attention_forward(
+    self,
+    x,
+    rope=None,
+    mask=None,
+) -> torch.Tensor:
+    batch_size = x.shape[0]
+
+    # `sample` projections.
+    query = self.to_q(x)
+    key = self.to_k(x)
+    value = self.to_v(x)
+
+    # attention
+    inner_dim = key.shape[-1]
+    head_dim = inner_dim // self.heads
+    query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+    key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+    value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+    # apply rotary position embedding
+    # Due to training process, only first head is applied with RoPE, will be fixed at next release
+    cos, sin = rope
+    query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)
+
+    if use_sdp_non_causal(head_dim, query.device, query.dtype):
+        mask = torch.where(mask, 0, torch.finfo(query.dtype).min)
+        x = scaled_dot_product_attention(query, key.contiguous(), value.contiguous(), mask, False)
+        x = x.transpose(1, 2)
+    else:
+        attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation]
+        x, _ = attention_interface(self, query, key, value, attention_mask=mask, is_causal=False)
+
+    # mask
+    x = x.reshape(batch_size, -1, self.heads * head_dim)
+    x = x.to(query.dtype)
+
+    # linear proj
+    x = self.to_out[0](x)
+    # dropout
+    x = self.to_out[1](x)
+
+    return x
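
Aside (illustrative, not part of the commit): the DiT path receives a boolean mask where True means "may attend", while the SDPA kernel used above expects an additive bias, hence the `torch.where` conversion. As a standalone helper:

import torch

def bool_mask_to_bias(mask: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # True -> 0.0 (kept), False -> the dtype's minimum value (masked out), matching
    # the additive-bias convention of scaled-dot-product attention.
    zero = torch.zeros((), dtype=dtype)
    neg = torch.full((), torch.finfo(dtype).min, dtype=dtype)
    return torch.where(mask, zero, neg)

# bool_mask_to_bias(torch.tensor([[True, True, False]]))
# -> tensor([[ 0., 0., -65504.]], dtype=torch.float16)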
+
+
+def _create_block_diff(self, x):
+    batch, seq_len = x.shape[0], x.shape[1]
+    block_indices = torch.arange(seq_len, device=x.device) // self.block_size
+
+    block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
+    block_j = block_indices.unsqueeze(0)  # [1, seq_length]
+    block_diff = block_j - block_i  # (n, n)
+
+    return block_diff.unsqueeze(0).unsqueeze(0)
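
Finally, a quick worked example of the block-difference helper (a hypothetical standalone rewrite of `_create_block_diff`, with `block_size` passed explicitly): entry [i, j] is the block index of position j minus the block index of position i, which the token2wav DiT can then threshold to build block-local attention patterns.

import torch

def block_diff(seq_len: int, block_size: int) -> torch.Tensor:
    # Same arithmetic as _create_block_diff above, without the batch/head dims.
    idx = torch.arange(seq_len) // block_size
    return idx.unsqueeze(0) - idx.unsqueeze(1)  # [i, j] = block(j) - block(i)

print(block_diff(6, 2))
# tensor([[ 0,  0,  1,  1,  2,  2],
#         [ 0,  0,  1,  1,  2,  2],
#         [-1, -1,  0,  0,  1,  1],
#         [-1, -1,  0,  0,  1,  1],
#         [-2, -2, -1, -1,  0,  0],
#         [-2, -2, -1, -1,  0,  0]])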