optimize minicpm-o's tts part (#12833)

This commit is contained in:
Yishuo Wang 2025-02-17 14:53:37 +08:00 committed by GitHub
parent f7b5a093a7
commit 8418450300
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1032,6 +1032,9 @@ def _optimize_pre(model, qtype=None):
if hasattr(model, "vpm"):
from ipex_llm.transformers.models.minicpmv import merge_qkv
model.vpm.apply(merge_qkv)
# tts opt
if hasattr(model, "tts"):
_optimize_pre(model.tts.model, qtype=qtype)
# llm opt
model.llm.config.model_type = "qwen2"
_optimize_pre(model.llm, qtype=qtype)
@ -1971,6 +1974,9 @@ def _optimize_post(model):
from transformers.models.whisper.modeling_whisper import WhisperSdpaAttention
from ipex_llm.transformers.models.whisper import whisper_attention_forward
convert_forward(model.apm, WhisperSdpaAttention, whisper_attention_forward)
# tts opt
if hasattr(model, "tts"):
_optimize_post(model.tts.model)
# llm opt
model.llm.config.model_type = "qwen2"
_optimize_post(model.llm)