optimize minicpm-o's tts part (#12833)
This commit is contained in:
		
							parent
							
								
									f7b5a093a7
								
							
						
					
					
						commit
						8418450300
					
				
					 1 changed files with 6 additions and 0 deletions
				
			
		| 
						 | 
					@ -1032,6 +1032,9 @@ def _optimize_pre(model, qtype=None):
 | 
				
			||||||
        if hasattr(model, "vpm"):
 | 
					        if hasattr(model, "vpm"):
 | 
				
			||||||
            from ipex_llm.transformers.models.minicpmv import merge_qkv
 | 
					            from ipex_llm.transformers.models.minicpmv import merge_qkv
 | 
				
			||||||
            model.vpm.apply(merge_qkv)
 | 
					            model.vpm.apply(merge_qkv)
 | 
				
			||||||
 | 
					        # tts opt
 | 
				
			||||||
 | 
					        if hasattr(model, "tts"):
 | 
				
			||||||
 | 
					            _optimize_pre(model.tts.model, qtype=qtype)
 | 
				
			||||||
        # llm opt
 | 
					        # llm opt
 | 
				
			||||||
        model.llm.config.model_type = "qwen2"
 | 
					        model.llm.config.model_type = "qwen2"
 | 
				
			||||||
        _optimize_pre(model.llm, qtype=qtype)
 | 
					        _optimize_pre(model.llm, qtype=qtype)
 | 
				
			||||||
| 
						 | 
					@ -1971,6 +1974,9 @@ def _optimize_post(model):
 | 
				
			||||||
            from transformers.models.whisper.modeling_whisper import WhisperSdpaAttention
 | 
					            from transformers.models.whisper.modeling_whisper import WhisperSdpaAttention
 | 
				
			||||||
            from ipex_llm.transformers.models.whisper import whisper_attention_forward
 | 
					            from ipex_llm.transformers.models.whisper import whisper_attention_forward
 | 
				
			||||||
            convert_forward(model.apm, WhisperSdpaAttention, whisper_attention_forward)
 | 
					            convert_forward(model.apm, WhisperSdpaAttention, whisper_attention_forward)
 | 
				
			||||||
 | 
					        # tts opt
 | 
				
			||||||
 | 
					        if hasattr(model, "tts"):
 | 
				
			||||||
 | 
					            _optimize_post(model.tts.model)
 | 
				
			||||||
        # llm opt
 | 
					        # llm opt
 | 
				
			||||||
        model.llm.config.model_type = "qwen2"
 | 
					        model.llm.config.model_type = "qwen2"
 | 
				
			||||||
        _optimize_post(model.llm)
 | 
					        _optimize_post(model.llm)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue