Add lightweight-serving whisper asr example (#11847)
* add asr init
* update for pp
* update style
* update readme
* update readme
parent a8e2573421
commit 5c4ed00593

6 changed files with 177 additions and 54 deletions
@ -22,6 +22,10 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
# for internlm-xcomposer2-vl-7b
pip install transformers==4.31.0
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops

# for whisper-large-v3
pip install transformers==4.36.2
pip install datasets soundfile librosa # required by audio processing
```

#### 1.2 Installation on Windows

@ -35,6 +39,14 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
pip install fastapi uvicorn openai
pip install gradio # for gradio web UI
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc

# for internlm-xcomposer2-vl-7b
pip install transformers==4.31.0
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops

# for whisper-large-v3
pip install transformers==4.36.2
pip install datasets soundfile librosa # required by audio processing
```

### 2. Configures OneAPI environment variables for Linux

@ -180,7 +192,7 @@ curl http://localhost:8000/v1/chat/completions \

Image input currently only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), and transformers==4.31.0 must be installed to run it.
```bash
wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{

@ -219,6 +231,17 @@ curl http://localhost:8000/v1/completions \
  }'
```

#### v1/audio/transcriptions

ASR currently only supports [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3), and `whisper-large-v3` can only be used to transcribe audio. The audio file type must be one that `librosa.load` can read.
```bash
curl http://localhost:8000/v1/audio/transcriptions \
  -H "Content-Type: multipart/form-data" \
  -F file="@/llm/test.mp3" \
  -F model="whisper-large-v3" \
  -F language="zh"
```
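For reference, the same request can also be issued from Python with the `openai` client installed in step 1.1. This is a minimal sketch, assuming the server is running locally on port 8000 and `test.mp3` stands in for your own audio file:

```python
# Minimal sketch: call the local /v1/audio/transcriptions endpoint with the openai client.
# Assumes the lightweight-serving server is reachable at localhost:8000; the API key is unused locally.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

with open("test.mp3", "rb") as audio_file:  # any format readable by librosa.load
    transcription = client.audio.transcriptions.create(
        model="whisper-large-v3",
        file=audio_file,
        language="zh",
    )
print(transcription.text)
```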

### 6. Benchmark with wrk

Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#4-benchmark-with-wrk) for more details

@ -39,12 +39,19 @@ async def main():
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit

    processor = None
    if "whisper" not in model_path.lower():
        local_model = ModelWorker(model_path, low_bit)
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    myapp = FastApp(local_model, tokenizer)
    else:
        local_model = ModelWorker(model_path, low_bit, "audio", torch_dtype=torch.float32)
        from transformers import WhisperProcessor
        processor = WhisperProcessor.from_pretrained(model_path)
        tokenizer = processor.tokenizer
    myapp = FastApp(local_model, tokenizer, processor)
    config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
    server = uvicorn.Server(config)
    await server.serve()

@ -27,6 +27,8 @@ import uuid
from typing import List, Optional, Union, Dict
from fastapi.middleware.cors import CORSMiddleware
from .tgi_protocol import Parameters
from typing_extensions import Literal
from fastapi import File, UploadFile, Form
from .openai_protocol import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,

@ -38,6 +40,8 @@ from .openai_protocol import (
    CompletionResponse,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    TranscriptionRequest,
    TranscriptionResponse,
)

result_dict: Dict[str, str] = {}

@ -50,6 +54,7 @@ class InputsRequest(BaseModel):
    image_list: Optional[list] = None
    stream: Optional[bool] = False
    req_type: str = 'completion'
    transcription_request: Optional[TranscriptionRequest] = None


class ChatCompletionRequest(BaseModel):

@ -92,20 +97,27 @@ app.add_middleware(

global tokenizer
global local_model
global processor


class FastApp():
    def __init__(self, model, mytokenizer):
    def __init__(self, model, mytokenizer, myprocessor=None):
        global tokenizer
        global local_model
        global processor
        local_model = model
        tokenizer = mytokenizer
        processor = myprocessor
        self.app = app


def get_queue_next_token(delta_text_queue):
    timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
    delta_text = delta_text_queue.text_queue.get(timeout=timeout)
    if "whisper" in local_model.model_name.lower():
        if delta_text is not None and "<|" in delta_text and "|>" in delta_text:
            import re
            delta_text = re.sub(r'<\|.*?\|>', '', delta_text)
    if delta_text is None:
        remain = 0
    else:
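As an aside, the `re.sub` call in `get_queue_next_token` above is what keeps Whisper's control tokens out of the streamed text; a quick illustration with a made-up chunk:

```python
import re

# Illustrative chunk: Whisper prepends markers such as <|startoftranscript|> to decoded text.
chunk = "<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>你好，世界"
print(re.sub(r'<\|.*?\|>', '', chunk))  # -> 你好，世界
```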
@ -385,6 +397,32 @@ async def create_completion(request: CompletionRequest):
    return result


@app.post("/v1/audio/transcriptions")
async def transcriptions(
    file: UploadFile=File(...),
    model: Optional[str]=Form("default_model"),
    language: Optional[str]=Form("zh"),
    prompt: Optional[str]=Form(None),
    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]]=Form(None),
    temperature: Optional[float]=Form(None),
    timestamp_granularities: Optional[List[Literal["word", "segment"]]]=Form(None)
):
    file_path = "./" + file.filename
    if not os.path.exists(file_path):
        with open(file_path, "wb") as f:
            f.write(await file.read())
    inputs_request = InputsRequest(
        inputs="transcriptions",
        parameters=None,
        stream=False,
        req_type="completion",
        transcription_request=TranscriptionRequest(file=file_path, model=model, language=language)
    )
    request_id, result = await generate(inputs_request)
    rsp = TranscriptionResponse(text=result)
    return rsp


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(process_requests(local_model, result_dict))

@ -393,4 +431,4 @@ async def startup_event():
async def process_requests(local_model, result_dict):
    while True:
        await asyncio.sleep(0)
        await local_model.process_step(tokenizer, result_dict)
        await local_model.process_step(tokenizer, result_dict, processor)

@ -23,9 +23,12 @@ logger = logging.get_logger(__name__)


class ModelWorker:
    def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
    def __init__(self, checkpoint, low_bit, model_type="normal", torch_dtype=torch.float16):
        self.dtype = torch_dtype
        start = time.perf_counter()
        if model_type == "audio":
            self.model = self.load_model(checkpoint, low_bit, "audio")
        else:
            model = self.load_model(checkpoint, low_bit)
            from ipex_llm.utils import BenchmarkWrapper
            self.model = BenchmarkWrapper(model, do_print=True)

@ -35,7 +38,16 @@ class ModelWorker:
        self.streamer = {}
        self.model_name = checkpoint

    def load_model(self, model_path, low_bit='sym_int4'):
    def load_model(self, model_path, low_bit='sym_int4', model_type="normal"):
        if model_type == "audio":
            from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
            model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
                                                              load_in_low_bit=low_bit,
                                                              torch_dtype=self.dtype,
                                                              optimize_model=True,
                                                              trust_remote_code=True,
                                                              use_cache=True)
        else:
            from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
            try:
                model = AutoModelForCausalLM.from_pretrained(model_path,

@ -54,6 +66,26 @@ class ModelWorker:
        model = model.eval().to("xpu")
        return model

    async def add_asr_request(self, processor):
        if self.waiting_requests.empty():
            return
        tmp_result = await self.waiting_requests.get()
        request_id, request = tmp_result
        transcription_request = request.transcription_request
        forced_decoder_ids = processor.get_decoder_prompt_ids(
            language=transcription_request.language, task="transcribe")
        audio_path = transcription_request.file
        import librosa
        raw_speech, sampling_rate = librosa.load(audio_path,
                                                 sr=processor.feature_extractor.sampling_rate)
        input_features = processor(
            raw_speech,
            sampling_rate=sampling_rate,
            return_tensors="pt",
            return_attention_mask=True,
        ).input_features.to('xpu')
        return input_features, forced_decoder_ids, request_id

    async def add_request(self, tokenizer):
        if self.waiting_requests.empty():
            return

@ -91,8 +123,17 @@ class ModelWorker:
        return input_ids, parameters, request_id, inputs_embeds

    @torch.no_grad()
    async def process_step(self, tokenizer, result_dict):
    async def process_step(self, tokenizer, result_dict, processor=None):
        if not self.waiting_requests.empty():
            if processor is not None and "whisper" in self.model_name.lower():
                input_features, decoder_ids, request_id = await self.add_asr_request(processor)
                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

                def model_generate():
                    self.model.generate(input_features,
                                        streamer=self.streamer[request_id],
                                        forced_decoder_ids=decoder_ids)
            else:
                input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

@ -117,7 +158,6 @@ class ModelWorker:
                                            streamer=self.streamer[request_id], **generate_kwargs)
            torch.xpu.empty_cache()
            torch.xpu.synchronize()

            from threading import Thread
            t1 = Thread(target=model_generate)
            t1.start()
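For orientation, the ASR branch above (`add_asr_request` plus the whisper path in `process_step`) follows the standard ipex-llm Whisper recipe. A self-contained, non-streaming sketch of that flow, with an illustrative model path and audio file name:

```python
# Standalone sketch of the Whisper flow used by the worker (paths and file names are illustrative).
import librosa
import torch
from transformers import WhisperProcessor
from ipex_llm.transformers import AutoModelForSpeechSeq2Seq

model_path = "openai/whisper-large-v3"
processor = WhisperProcessor.from_pretrained(model_path)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
                                                  load_in_low_bit="sym_int4",
                                                  torch_dtype=torch.float32,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
                                                  use_cache=True).eval().to("xpu")

# Load the audio at Whisper's expected sampling rate and build input features on the XPU.
raw_speech, sampling_rate = librosa.load("test.mp3", sr=processor.feature_extractor.sampling_rate)
input_features = processor(raw_speech,
                           sampling_rate=sampling_rate,
                           return_tensors="pt",
                           return_attention_mask=True).input_features.to("xpu")

# Force the decoder to transcribe in the requested language, then decode the result.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
with torch.no_grad():
    predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```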
@ -24,6 +24,7 @@ from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from ipex_llm.utils.common import invalidInputError
from typing_extensions import Literal


# from vllm.sampling_params import SamplingParams

@ -31,6 +32,20 @@ def random_uuid() -> str:
    return str(uuid.uuid4().hex)


class TranscriptionRequest(BaseModel):
    file: str = None
    model: Optional[str] = "default_model"
    language: Optional[str] = "zh"
    prompt: Optional[str] = None
    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]] = None
    temperature: Optional[float] = None
    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None


class TranscriptionResponse(BaseModel):
    text: str


class OpenAIBaseModel(BaseModel):
    # OpenAI API does not allow extra fields
    model_config = ConfigDict(extra="forbid")

@ -800,7 +800,7 @@ class PPModelWorker:
                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
        await asyncio.gather(*_stream_tasks)

    async def process_step(self, tokenizer, result_dict):
    async def process_step(self, tokenizer, result_dict, processor=None):
        cur_batch = None
        torch.xpu.synchronize(self.device)
        if self.rank == 0: