Add lightweight-serving whisper asr example (#11847)
* add asr init
* update for pp
* update style
* update readme
* update readme
This commit is contained in:
parent
a8e2573421
commit
5c4ed00593
6 changed files with 177 additions and 54 deletions
@@ -22,6 +22,10 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 # for internlm-xcomposer2-vl-7b
 pip install transformers==4.31.0
 pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
+
+# for whisper-large-v3
+pip install transformers==4.36.2
+pip install datasets soundfile librosa # required by audio processing
 ```
 
 #### 1.2 Installation on Windows
@@ -35,6 +39,14 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
 pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+
+# for internlm-xcomposer2-vl-7b
+pip install transformers==4.31.0
+pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
+
+# for whisper-large-v3
+pip install transformers==4.36.2
+pip install datasets soundfile librosa # required by audio processing
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
@@ -180,7 +192,7 @@ curl http://localhost:8000/v1/chat/completions \
 
 image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
 ```bash
-wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
+wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
 curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
@@ -219,6 +231,17 @@ curl http://localhost:8000/v1/completions \
 }'
 ```
+
+#### v1/audio/transcriptions
+
+ASR currently only supports [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3), and `whisper-large-v3` can only be used for audio transcription. The audio file type must be one that `librosa.load` supports.
+```bash
+curl http://localhost:8000/v1/audio/transcriptions \
+  -H "Content-Type: multipart/form-data" \
+  -F file="@/llm/test.mp3" \
+  -F model="whisper-large-v3" \
+  -F language="zh"
+```
 
 ### 6. Benchmark with wrk
 
 Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#4-benchmark-with-wrk) for more details
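For reference, the same request as the curl example above can also be issued from Python with `requests`; a minimal sketch, assuming the server is listening on localhost:8000 and a local `test.mp3` exists (URL and file name are placeholders):

```python
import requests

# Hypothetical client-side call mirroring the curl example above.
with open("test.mp3", "rb") as audio_file:
    response = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": ("test.mp3", audio_file, "audio/mpeg")},
        data={"model": "whisper-large-v3", "language": "zh"},
    )

print(response.json())  # expected to contain the transcribed text
```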
@@ -39,12 +39,19 @@ async def main():
     model_path = args.repo_id_or_model_path
     low_bit = args.low_bit
 
-    local_model = ModelWorker(model_path, low_bit)
-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    myapp = FastApp(local_model, tokenizer)
+    processor = None
+    if "whisper" not in model_path.lower():
+        local_model = ModelWorker(model_path, low_bit)
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+    else:
+        local_model = ModelWorker(model_path, low_bit, "audio", torch_dtype=torch.float32)
+        from transformers import WhisperProcessor
+        processor = WhisperProcessor.from_pretrained(model_path)
+        tokenizer = processor.tokenizer
+    myapp = FastApp(local_model, tokenizer, processor)
     config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
     server = uvicorn.Server(config)
     await server.serve()
@@ -27,6 +27,8 @@ import uuid
 from typing import List, Optional, Union, Dict
 from fastapi.middleware.cors import CORSMiddleware
 from .tgi_protocol import Parameters
+from typing_extensions import Literal
+from fastapi import File, UploadFile, Form
 from .openai_protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
@@ -38,6 +40,8 @@ from .openai_protocol import (
     CompletionResponse,
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
+    TranscriptionRequest,
+    TranscriptionResponse,
 )
 
 result_dict: Dict[str, str] = {}
@@ -50,6 +54,7 @@ class InputsRequest(BaseModel):
     image_list: Optional[list] = None
     stream: Optional[bool] = False
     req_type: str = 'completion'
+    transcription_request: Optional[TranscriptionRequest] = None
 
 
 class ChatCompletionRequest(BaseModel):
@@ -92,20 +97,27 @@ app.add_middleware(
 
 global tokenizer
 global local_model
+global processor
 
 
 class FastApp():
-    def __init__(self, model, mytokenizer):
+    def __init__(self, model, mytokenizer, myprocessor=None):
         global tokenizer
         global local_model
+        global processor
         local_model = model
         tokenizer = mytokenizer
+        processor = myprocessor
         self.app = app
 
 
 def get_queue_next_token(delta_text_queue):
     timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
     delta_text = delta_text_queue.text_queue.get(timeout=timeout)
+    if "whisper" in local_model.model_name.lower():
+        if delta_text is not None and "<|" in delta_text and "|>" in delta_text:
+            import re
+            delta_text = re.sub(r'<\|.*?\|>', '', delta_text)
     if delta_text is None:
         remain = 0
     else:
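The whisper branch above strips Whisper's special tokens (for example `<|startoftranscript|>` or `<|zh|>`) out of each streamed chunk before it is returned to the client. A small standalone sketch of that regex; the chunk strings are illustrative, not literal server output:

```python
import re

# Example streamed chunks as Whisper might emit them.
chunks = ["<|startoftranscript|><|zh|><|transcribe|>", "今天天气", "很好<|endoftext|>"]

cleaned = [re.sub(r'<\|.*?\|>', '', chunk) for chunk in chunks]
print(cleaned)  # ['', '今天天气', '很好']
```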
@@ -385,6 +397,32 @@ async def create_completion(request: CompletionRequest):
     return result
 
 
+@app.post("/v1/audio/transcriptions")
+async def transcriptions(
+    file: UploadFile=File(...),
+    model: Optional[str]=Form("default_model"),
+    language: Optional[str]=Form("zh"),
+    prompt: Optional[str]=Form(None),
+    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]]=Form(None),
+    temperature: Optional[float]=Form(None),
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]]=Form(None)
+):
+    file_path = "./" + file.filename
+    if not os.path.exists(file_path):
+        with open(file_path, "wb") as f:
+            f.write(await file.read())
+    inputs_request = InputsRequest(
+        inputs="transcriptions",
+        parameters=None,
+        stream=False,
+        req_type="completion",
+        transcription_request=TranscriptionRequest(file=file_path, model=model, language=language)
+    )
+    request_id, result = await generate(inputs_request)
+    rsp = TranscriptionResponse(text=result)
+    return rsp
+
+
 @app.on_event("startup")
 async def startup_event():
     asyncio.create_task(process_requests(local_model, result_dict))
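Because the new route matches OpenAI's `/v1/audio/transcriptions` path and form fields, it should also be reachable through the official `openai` Python client; a rough sketch, assuming an `openai>=1.0` client with a placeholder base URL, API key, and file name:

```python
from openai import OpenAI

# The key is not checked by this server, so any placeholder string works.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

with open("test.mp3", "rb") as audio_file:
    transcript = client.audio.transcriptions.create(
        model="whisper-large-v3",
        file=audio_file,
        language="zh",
    )

print(transcript.text)
```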
@@ -393,4 +431,4 @@ async def startup_event():
 async def process_requests(local_model, result_dict):
     while True:
         await asyncio.sleep(0)
-        await local_model.process_step(tokenizer, result_dict)
+        await local_model.process_step(tokenizer, result_dict, processor)
@@ -23,9 +23,12 @@ logger = logging.get_logger(__name__)
 
 
 class ModelWorker:
-    def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
+    def __init__(self, checkpoint, low_bit, model_type="normal", torch_dtype=torch.float16):
         self.dtype = torch_dtype
         start = time.perf_counter()
-        model = self.load_model(checkpoint, low_bit)
-        from ipex_llm.utils import BenchmarkWrapper
-        self.model = BenchmarkWrapper(model, do_print=True)
+        if model_type == "audio":
+            self.model = self.load_model(checkpoint, low_bit, "audio")
+        else:
+            model = self.load_model(checkpoint, low_bit)
+            from ipex_llm.utils import BenchmarkWrapper
+            self.model = BenchmarkWrapper(model, do_print=True)
@@ -35,7 +38,16 @@ class ModelWorker:
         self.streamer = {}
         self.model_name = checkpoint
 
-    def load_model(self, model_path, low_bit='sym_int4'):
-        from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
-        try:
-            model = AutoModelForCausalLM.from_pretrained(model_path,
+    def load_model(self, model_path, low_bit='sym_int4', model_type="normal"):
+        if model_type == "audio":
+            from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
+            model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
+                                                              load_in_low_bit=low_bit,
+                                                              torch_dtype=self.dtype,
+                                                              optimize_model=True,
+                                                              trust_remote_code=True,
+                                                              use_cache=True)
+        else:
+            from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+            try:
+                model = AutoModelForCausalLM.from_pretrained(model_path,
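For reference, the audio branch above can also be exercised outside the server; a minimal sketch of the same low-bit Whisper load, assuming `ipex-llm` with XPU support is installed and using `openai/whisper-large-v3` as an illustrative model path:

```python
import torch
from ipex_llm.transformers import AutoModelForSpeechSeq2Seq

# Load whisper-large-v3 with low-bit weights, mirroring the "audio" branch above.
# The model path and low_bit value are illustrative.
model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3",
                                                  load_in_low_bit="sym_int4",
                                                  torch_dtype=torch.float32,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
                                                  use_cache=True)
model = model.eval().to("xpu")  # same device placement as load_model() uses
```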
@@ -54,6 +66,26 @@ class ModelWorker:
         model = model.eval().to("xpu")
         return model
 
+    async def add_asr_request(self, processor):
+        if self.waiting_requests.empty():
+            return
+        tmp_result = await self.waiting_requests.get()
+        request_id, request = tmp_result
+        transcription_request = request.transcription_request
+        forced_decoder_ids = processor.get_decoder_prompt_ids(
+            language=transcription_request.language, task="transcribe")
+        audio_path = transcription_request.file
+        import librosa
+        raw_speech, sampling_rate = librosa.load(audio_path,
+                                                 sr=processor.feature_extractor.sampling_rate)
+        input_features = processor(
+            raw_speech,
+            sampling_rate=sampling_rate,
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).input_features.to('xpu')
+        return input_features, forced_decoder_ids, request_id
+
     async def add_request(self, tokenizer):
         if self.waiting_requests.empty():
             return
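Combined with the load sketch above, the feature extraction and decoding performed in `add_asr_request` and `process_step` reduce to the standard Whisper flow; a rough standalone sketch reusing that `model`, with the audio path and language as placeholders:

```python
import librosa
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")

# Resample the audio to Whisper's expected sampling rate, as add_asr_request does.
raw_speech, sampling_rate = librosa.load("test.mp3",
                                         sr=processor.feature_extractor.sampling_rate)
input_features = processor(raw_speech,
                           sampling_rate=sampling_rate,
                           return_tensors="pt").input_features.to("xpu")

# Force Chinese transcription, matching TranscriptionRequest's default language.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```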
@@ -91,8 +123,17 @@ class ModelWorker:
         return input_ids, parameters, request_id, inputs_embeds
 
     @torch.no_grad()
-    async def process_step(self, tokenizer, result_dict):
+    async def process_step(self, tokenizer, result_dict, processor=None):
         if not self.waiting_requests.empty():
-            input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
-            self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
+            if processor is not None and "whisper" in self.model_name.lower():
+                input_features, decoder_ids, request_id = await self.add_asr_request(processor)
+                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
+
+                def model_generate():
+                    self.model.generate(input_features,
+                                        streamer=self.streamer[request_id],
+                                        forced_decoder_ids=decoder_ids)
+            else:
+                input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
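The whisper branch streams partial text through `transformers`' `TextIteratorStreamer`, with `generate` running in a background thread, mirroring the existing text path; a condensed sketch of that pattern, reusing `model`, `processor`, `input_features`, and `forced_decoder_ids` from the sketches above:

```python
from threading import Thread
from transformers import TextIteratorStreamer

# Stream decoded text while generation runs in a worker thread,
# as process_step does for the whisper branch.
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True)

def run_generate():
    model.generate(input_features,
                   streamer=streamer,
                   forced_decoder_ids=forced_decoder_ids)

thread = Thread(target=run_generate)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```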
@@ -117,7 +158,6 @@ class ModelWorker:
                                         streamer=self.streamer[request_id], **generate_kwargs)
                    torch.xpu.empty_cache()
                    torch.xpu.synchronize()
-
            from threading import Thread
            t1 = Thread(target=model_generate)
            t1.start()
@@ -24,6 +24,7 @@ from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated
 from ipex_llm.utils.common import invalidInputError
+from typing_extensions import Literal
 
 
 # from vllm.sampling_params import SamplingParams
@@ -31,6 +32,20 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
+class TranscriptionRequest(BaseModel):
+    file: str = None
+    model: Optional[str] = "default_model"
+    language: Optional[str] = "zh"
+    prompt: Optional[str] = None
+    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]] = None
+    temperature: Optional[float] = None
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None
+
+
+class TranscriptionResponse(BaseModel):
+    text: str
+
+
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
     model_config = ConfigDict(extra="forbid")
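These two models carry the transcription request through the existing `InputsRequest` queue and shape the JSON response. A small illustrative round-trip, assuming pydantic v2 (which the file's `ConfigDict` and `model_validator` imports imply); the values are placeholders:

```python
# Hypothetical round-trip, just to show the field defaults and JSON shape.
req = TranscriptionRequest(file="/llm/test.mp3", model="whisper-large-v3")
print(req.language)                       # "zh" (default)
print(req.model_dump(exclude_none=True))  # only the fields that were set

resp = TranscriptionResponse(text="今天天气很好")
print(resp.model_dump_json())             # {"text":"今天天气很好"}
```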
@@ -800,7 +800,7 @@ class PPModelWorker:
                 _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
             await asyncio.gather(*_stream_tasks)
 
-    async def process_step(self, tokenizer, result_dict):
+    async def process_step(self, tokenizer, result_dict, processor=None):
         cur_batch = None
         torch.xpu.synchronize(self.device)
         if self.rank == 0: