support and optimize qwen2-audio (#11809)
This commit is contained in:
parent
3ac83f8396
commit
07b7f13982
2 changed files with 153 additions and 11 deletions
|
|
@ -1308,9 +1308,6 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
|
||||
convert_forward(model,
|
||||
module.Qwen2Model,
|
||||
qwen2_model_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2ForCausalLM,
|
||||
qwen2_causal_lm_forward)
|
||||
|
|
@ -1326,6 +1323,12 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model,
|
||||
module.Qwen2SdpaAttention,
|
||||
qwen2_attention_forward)
|
||||
if version.parse(trans_version) >= version.parse("4.42"):
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_model_forward_4_42
|
||||
convert_forward(model, module.Qwen2Model, qwen2_model_forward_4_42)
|
||||
else:
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
|
||||
convert_forward(model, module.Qwen2Model, qwen2_model_forward)
|
||||
elif model.config.model_type == "qwen2_moe":
|
||||
# for Qwen1.5-MOE-A2.7B
|
||||
modeling_module_name = model.__class__.__module__
|
||||
|
|
@ -1356,6 +1359,8 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model,
|
||||
module.Qwen2MoeSdpaAttention,
|
||||
qwen2_attention_forward)
|
||||
elif model.config.model_type == "qwen2_audio":
|
||||
_optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
|
||||
elif model.config.model_type == "cohere":
|
||||
# for CohereForAI/c4ai-command-r-v01
|
||||
invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
|
||||
|
|
|
|||
|
|
@ -55,8 +55,6 @@ from ipex_llm.utils.common import invalidInputError
|
|||
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
|
||||
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
|
||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers import logging
|
||||
|
|
@ -76,12 +74,15 @@ def qwen2_model_forward(
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None, # for transformers >= 4.42
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else \
|
||||
self.config.output_attentions
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else
|
||||
self.config.output_hidden_states
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
|
|
@ -90,8 +91,7 @@ def qwen2_model_forward(
|
|||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
invalidInputError(False,
|
||||
"You cannot specify both decoder_input_ids and "
|
||||
"decoder_inputs_embeds at the same time")
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
|
|
@ -159,6 +159,9 @@ def qwen2_model_forward(
|
|||
"the input. "
|
||||
)
|
||||
|
||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
|
||||
|
||||
# ipex-llm changes start: don't generate `attention_mask` in specific cases
|
||||
if seq_length == 1 or batch_size == 1 and use_sdp_causal(
|
||||
seq_length, seq_length + past_key_values_length,
|
||||
|
|
@ -259,6 +262,138 @@ def qwen2_model_forward(
|
|||
)
|
||||
|
||||
|
||||
def qwen2_model_forward_4_42(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
invalidInputError(
|
||||
(input_ids is None) ^ (inputs_embeds is None),
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. "
|
||||
"Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# ipex-llm changes start
|
||||
# IPEX-LLM OPT: kv cache and quantize kv cache
|
||||
use_quantize_kv = (
|
||||
self.config.hidden_size != 3584 # disable quantize kv in specific model
|
||||
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
|
||||
self.config.num_attention_heads//self.config.num_key_value_heads)
|
||||
)
|
||||
if use_cache:
|
||||
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
|
||||
# ipex-llm changes end
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
# ipex-llm changes start: remove `to_legacy_cache`
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache
|
||||
# ipex-llm changes end
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache,
|
||||
all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def qwen2_causal_lm_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
|
@ -271,6 +406,7 @@ def qwen2_causal_lm_forward(
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None, # for transformers >= 4.42
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None
|
||||
|
|
@ -293,6 +429,7 @@ def qwen2_causal_lm_forward(
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
|
|
|||
Loading…
Reference in a new issue