Enable phi-4 with vision and audio (#13203)

* add phi4

* update

* enable audio

* update and add readme
Wang, Jian4 authored 2025-06-05 10:15:20 +08:00, committed by GitHub
parent e032156518
commit 45864790f7
4 changed files with 183 additions and 0 deletions


@@ -0,0 +1,131 @@
import os
from dataclasses import asdict
from typing import NamedTuple, Optional

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, EngineArgs, SamplingParams
# IPEXLLMClass shadows vLLM's LLM so weights can be loaded in low-bit format on Intel XPU
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
from vllm.assets.audio import AudioAsset
from vllm.utils import FlexibleArgumentParser

audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]

# Questions paired with 0, 1, or 2 audio clips per prompt
question_per_audio_count = {
    0: "What is 1+1?",
    1: "What is recited in the audio?",
    2: "What sport and what nursery rhyme are referenced?"
}

model_path = "/llm/models/whisper-large-v3-turbo"
# model_path = "/llm/models/whisper-medium"
# model_path = "/llm/models/Phi-4-multimodal-instruct"


# Phi-4-multimodal-instruct
def run_phi4mm(question: str, audio_count: int):
    placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
    return prompt


# Whisper
def run_whisper(question: str, audio_count: int):
    assert audio_count == 1, (
        "Whisper only supports a single audio input per prompt")
    prompt = "<|startoftranscript|>"
    return prompt


model_example_map = {
    "phi4mm": run_phi4mm,
    "whisper": run_whisper,
}

# Context length and low-bit precision depend on the selected model
if "whisper" in model_path:
    model_len = 448
    low_bit = "fp16"
else:
    model_len = 5500
    low_bit = "sym_int4"


def main(args):
    audio_count = args.num_audios
    llm = LLM(
        model=model_path,
        device="xpu",
        dtype="float16",
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
        mm_processor_kwargs=None,
        load_in_low_bit=low_bit,
        tensor_parallel_size=1,
        max_num_seqs=8,
        gpu_memory_utilization=0.95,
        disable_async_output_proc=True,
        distributed_executor_backend="ray",
        max_model_len=model_len,
        trust_remote_code=True,
        block_size=8,
        max_num_batched_tokens=model_len)

    model = llm.llm_engine.model_config.hf_config.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    prompt = model_example_map[model](question_per_audio_count[audio_count],
                                      audio_count)

    sampling_params = SamplingParams(temperature=0.1,
                                     top_p=0.001,
                                     repetition_penalty=1.05,
                                     max_tokens=128,
                                     skip_special_tokens=False)

    mm_data = {}
    if audio_count > 0:
        mm_data = {
            "audio": [
                asset.audio_and_sample_rate
                for asset in audio_assets[:audio_count]
            ]
        }

    assert args.num_prompts > 0
    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
    if args.num_prompts > 1:
        # Batch inference
        inputs = [inputs] * args.num_prompts

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
                    'audio language models')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')
    parser.add_argument("--num-audios",
                        type=int,
                        default=1,
                        choices=[0, 1, 2],
                        help="Number of audio items per prompt.")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    args = parser.parse_args()
    main(args)


@@ -10,6 +10,7 @@ model_path = "/llm/models/glm-4v-9b"
model_path = "/llm/models/InternVL2-8B"
model_path = "/llm/models/gemma-3-12b-it"
model_path = "/llm/models/Qwen2.5-VL-7B-Instruct"
model_path = "/llm/models/Phi-4-multimodal-instruct"
prompt = "What is in the image?"
@@ -77,6 +78,18 @@ def run_qwen2_vl(question, modality):
    stop_token_ids = None
    return prompt, stop_token_ids


# Phi-4-multimodal-instruct
def run_phi4mm(question, modality):
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process image inputs.
    """
    assert modality == "image"

    prompt = f"<|user|><|image_1|>{question}<|end|><|assistant|>"
    stop_token_ids = None
    return prompt, stop_token_ids


model_example_map = {
    "minicpmv": run_minicpmv,
    "qwen2_vl": run_qwen2_vl,
@@ -85,6 +98,7 @@ model_example_map = {
    "chatglm": run_glm4v,
    "internvl_chat": run_internvl,
    "gemma3": run_gemma3,
    "phi4mm": run_phi4mm,
}
if "glm-4v" in model_path:


@@ -438,6 +438,8 @@ docker logs CONTAINER_NAME
## 8. Advanced Features
#### Multi-modal Model
##### Vision model
<details>
vLLM serving with IPEX-LLM supports multi-modal models, such as [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6), which can accept image and text input at the same time and respond.
@@ -478,6 +480,40 @@ curl http://localhost:8000/v1/chat/completions \
```
</details>
##### Audio model
<details>
vLLM serving with IPEX-LLM supports multi-modal models such as [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) (offline mode only for now) and Whisper-series models ([whisper-large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) and [whisper-medium](https://huggingface.co/openai/whisper-medium)), which accept audio input and respond with text output.
Offline test:
```bash
export VLLM_USE_V1=0
python3 audio_language.py
```
Online test:
1. Start the vLLM service: change the `model` and `served_model_name` values in `/llm/start-vllm-service.sh`.
2. Download or prepare an audio file first.
```python
# python3 load.py
from vllm.assets.audio import AudioAsset
import soundfile as sf
audio, sr = AudioAsset("winning_call").audio_and_sample_rate
sf.write("output.wav", audio, sr)
```
3. Send a request with the audio file and, optionally, prompt text.
```bash
curl http://localhost:8000/v1/audio/transcriptions \
-H "Content-Type: multipart/form-data" \
-F file="@/llm/models/test/output.wav" \
-F model="whisper-large-v3-turbo"
```
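The same transcription request can also be sent with the OpenAI Python client. A minimal sketch, assuming the `openai` package is installed on the client side, the service is reachable at `http://localhost:8000/v1`, and the served model name matches step 1:
```python
from openai import OpenAI

# api_key can be any placeholder unless the server was started with --api-key
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("/llm/models/test/output.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="whisper-large-v3-turbo",  # must match served_model_name
        file=f,
    )
print(transcription.text)
```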
</details>
#### Prefix Caching
<details>
Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part.
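For offline inference, APC can be enabled when constructing the engine. A minimal sketch, assuming the `IPEXLLMClass` wrapper used in the examples above forwards standard vLLM engine arguments such as `enable_prefix_caching`, and with `/llm/models/YOUR_MODEL` as a placeholder path:
```python
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

# A long prefix shared by many queries (e.g. a system prompt or document) benefits most from APC.
shared_prefix = "You are a helpful assistant. Here is the document to answer questions about: ..."

llm = LLM(model="/llm/models/YOUR_MODEL",  # placeholder: point this at a local model
          device="xpu",
          dtype="float16",
          load_in_low_bit="sym_int4",
          enable_prefix_caching=True)  # reuse the KV cache of the shared prefix

params = SamplingParams(temperature=0, max_tokens=64)
# The second query can reuse the KV cache computed for the shared prefix by the first one.
for question in ["What is the title?", "Who is the author?"]:
    output = llm.generate([shared_prefix + "\nQuestion: " + question], params)[0]
    print(output.outputs[0].text)
```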


@@ -129,6 +129,8 @@ def get_load_function(low_bit):
if "glm-4v" in self.vllm_config.model_config.model.lower() and \
low_bit in ("sym_int4", "woq_int4"):
modules = ["dense_4h_to_h"]
if "phi4mm" in self.vllm_config.model_config.hf_config.model_type:
modules = ["vision_encoder", "embed_tokens_extend"]
if low_bit == "fp16":
# to fix qwen2.5-vl and glm-4v
modules = ["vision", "visual"]