[LLM]Reopen autotp generate_stream (#11120)
* reopen autotp generate_stream
* fix style error
* update
This commit is contained in:
parent
1dc680341b
commit
63e95698eb
4 changed files with 414 additions and 69 deletions
@@ -58,6 +58,8 @@ If you successfully run the serving, you can get output like this:

We can use `curl` to test the serving API.

#### generate()

```bash
# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
export http_proxy=
@@ -77,10 +79,68 @@ And you should get output like this:

```json
{
  "generated_text": "What is AI? Artificial intelligence (AI) refers to the development of computer systems able to perform tasks that would normally require human intelligence, such as visual perception, speech",
  "generate_time": "0.45149803161621094s"
  "index": 0,
  "message": {
    "role": "assistant",
    "content": "\n\nArtificial intelligence (AI) is a branch of computer science that deals with the creation of intelligent machines that can perform tasks that typically "
  },
  "finish_reason": "stop"
}
```
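
For reference (this snippet is not part of the original example), a minimal Python client for the `/generate/` endpoint could look like the sketch below. It assumes the server from this example is listening on `127.0.0.1:8000` and accepts the same `prompt`/`n_predict` payload used by `generate_stream()`.

```python
# Minimal /generate/ client sketch (illustrative only, not part of this commit).
import requests

def generate(prompt: str, n_predict: int = 32) -> dict:
    resp = requests.post(
        "http://127.0.0.1:8000/generate/",
        json={"prompt": prompt, "n_predict": n_predict},
        timeout=600,
    )
    resp.raise_for_status()
    # Expected shape: {"index": 0, "message": {...}, "finish_reason": "stop"}
    return resp.json()

if __name__ == "__main__":
    print(generate("What is AI?")["message"]["content"])
```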
#### generate_stream()

```bash
# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
export http_proxy=
export https_proxy=

curl -X 'POST' \
  'http://127.0.0.1:8000/generate_stream/' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "prompt": "What is AI?",
  "n_predict": 32
}'
```

And you should get output like this:

```json
{"index": 0, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
{"index": 1, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
{"index": 2, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 3, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 4, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 5, "message": {"role": "assistant", "content": "Artificial "}, "finish_reason": null}
{"index": 6, "message": {"role": "assistant", "content": "intelligence "}, "finish_reason": null}
{"index": 7, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 8, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 9, "message": {"role": "assistant", "content": "(AI) "}, "finish_reason": null}
{"index": 10, "message": {"role": "assistant", "content": "is "}, "finish_reason": null}
{"index": 11, "message": {"role": "assistant", "content": "a "}, "finish_reason": null}
{"index": 12, "message": {"role": "assistant", "content": "branch "}, "finish_reason": null}
{"index": 13, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
{"index": 14, "message": {"role": "assistant", "content": "computer "}, "finish_reason": null}
{"index": 15, "message": {"role": "assistant", "content": "science "}, "finish_reason": null}
{"index": 16, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
{"index": 17, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 18, "message": {"role": "assistant", "content": "deals "}, "finish_reason": null}
{"index": 19, "message": {"role": "assistant", "content": "with "}, "finish_reason": null}
{"index": 20, "message": {"role": "assistant", "content": "the "}, "finish_reason": null}
{"index": 21, "message": {"role": "assistant", "content": "creation "}, "finish_reason": null}
{"index": 22, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
{"index": 23, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 24, "message": {"role": "assistant", "content": "intelligent "}, "finish_reason": null}
{"index": 25, "message": {"role": "assistant", "content": "machines "}, "finish_reason": null}
{"index": 26, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
{"index": 27, "message": {"role": "assistant", "content": "can "}, "finish_reason": null}
{"index": 28, "message": {"role": "assistant", "content": "perform "}, "finish_reason": null}
{"index": 29, "message": {"role": "assistant", "content": "tasks "}, "finish_reason": null}
{"index": 30, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
{"index": 31, "message": {"role": "assistant", "content": "typically "}, "finish_reason": null}
{"index": 32, "message": {"role": "assistant", "content": null}, "finish_reason": "length"}
```

**Important**: The first token latency is much larger than the rest token latency; you can use [our benchmark tool](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/dev/benchmark/README.md) to obtain more details about first and rest token latency.
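
For reference (again, not part of the original README), the `/generate_stream/` endpoint can be consumed from Python roughly as sketched below, assuming each streamed chunk is a JSON object on its own line, as in the sample output above.

```python
# Minimal streaming client sketch (illustrative only, not part of this commit).
import json
import requests

def stream_generate(prompt: str, n_predict: int = 32):
    with requests.post(
        "http://127.0.0.1:8000/generate_stream/",
        json={"prompt": prompt, "n_predict": n_predict},
        stream=True,
        timeout=600,
    ) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line:
                continue
            # finish_reason is null until the final chunk.
            yield json.loads(line)

if __name__ == "__main__":
    for chunk in stream_generate("What is AI?"):
        token = chunk["message"]["content"]
        if token:
            print(token, end="", flush=True)
    print()
```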
@@ -18,17 +18,25 @@ import os

import torch
import transformers
import time
import json
import argparse
import torch.distributed as dist

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse

from pydantic import BaseModel
import uvicorn
from threading import Thread
from ipex_llm.transformers.streamer import BatchTextIteratorStreamer


import asyncio, uuid
from collections import deque
from typing import Dict, List, Optional

from transformers.utils import logging

logger = logging.get_logger(__name__)

from ipex_llm.utils.benchmark_util import BenchmarkWrapper
@@ -42,17 +50,31 @@ def get_int_from_env(env_keys, default):
            return val
    return int(default)


global max_num_seqs
global max_num_batched_tokens

local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
local_rank = get_int_from_env(["LOCAL_RANK", "PMI_RANK"], "0")
world_size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE"], "1")
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")

global model, tokenizer


class PromptRequest(BaseModel):
    prompt: str
    n_predict: int = 32


rest_req_deque = deque(maxlen=128)
request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}
streamer_dict = {}
empty_req = PromptRequest(prompt="", n_predict=0)


def load_model(model_path, low_bit):

    from ipex_llm import optimize_model
@@ -61,7 +83,9 @@ def load_model(model_path, low_bit):
    import time
    import argparse

    from transformers import AutoModelForCausalLM  # export AutoModelForCausalLM from transformers so that deepspeed use it
    from transformers import (
        AutoModelForCausalLM,
    )  # export AutoModelForCausalLM from transformers so that deepspeed use it
    from transformers import LlamaTokenizer, AutoTokenizer
    import deepspeed
    from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
@@ -73,12 +97,14 @@ def load_model(model_path, low_bit):
    current_accel = CPU_Accelerator()
    set_accelerator(current_accel)
    global model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 device_map={"": "cpu"},
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        use_cache=True,
    )

    model = deepspeed.init_inference(
        model,
@@ -89,70 +115,171 @@ def load_model(model_path, low_bit):

    # Use IPEX-LLM `optimize_model` to convert the model into optimized low bit format
    # Convert the rest of the model into float16 to reduce allreduce traffic
    model = optimize_model(model.module.to(f'cpu'), low_bit=low_bit).to(torch.float16)
    model = optimize_model(model.module.to(f"cpu"), low_bit=low_bit).to(torch.float16)

    # Next, use XPU as accelerator to speed up inference
    current_accel = XPU_Accelerator()
    set_accelerator(current_accel)

    # Move model back to xpu
    model = model.to(f'xpu:{local_rank}')
    model = model.to(f"xpu:{local_rank}")
    model = BenchmarkWrapper(model)

    # Modify backend related settings
    # Modify backend related settings
    if world_size > 1:
        get_accelerator().set_device(local_rank)
        dist_backend = get_accelerator().communication_backend_name()
        import deepspeed.comm.comm

        deepspeed.comm.comm.cdb = None
        from deepspeed.comm.comm import init_distributed

        init_distributed()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True, padding_side="left"
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

def generate_text(prompt: List[str], n_predict = 32):

async def generate_stream_gate(prompt: List[str], n_predict=32, request_ids=[]):
    while prompt[-1] == "":
        prompt = prompt[:-1]
    if isinstance(n_predict, list):
        n_predict = max(n_predict)


    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
    # print(input_ids)
    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
    output = model.generate(input_ids,
                            attention_mask=attention_mask,
                            max_new_tokens=n_predict,
                            use_cache=True)
    torch.xpu.synchronize()
    return output
    input_ids = inputs.input_ids.to(f"xpu:{local_rank}")
    attention_mask = inputs.attention_mask.to(f"xpu:{local_rank}")

    for request_id in request_ids:
        if request_id not in streamer_dict:
            streamer_dict[request_id] = asyncio.Queue()

class PromptRequest(BaseModel):
    prompt: str
    n_predict: int = 32
    streamer = BatchTextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=600,
        skip_prompt=True,
        skip_special_tokens=True,
        batch_size=len(prompt),
    )

    generated_kwargs = dict(
        max_new_tokens=n_predict,
        min_new_tokens=n_predict,
        streamer=streamer,
        attention_mask=attention_mask,
        do_sample=False,
    )

    def model_generate():
        model.generate(input_ids, **generated_kwargs)
        torch.xpu.empty_cache()
        torch.xpu.synchronize()

    t1 = Thread(target=model_generate)
    t1.start()

    stopped = False

    async def put_item(queue, item):
        await queue.put(item)

    for i in range(n_predict):
        tasks = []
        try:
            output_token = next(streamer)
        except StopIteration:
            stopped = True
        for index, request_id in enumerate(request_ids):
            task = asyncio.create_task(
                put_item(
                    streamer_dict[request_id],
                    (0 if stopped else n_predict - 1 - i, output_token[index]),
                )
            )
            tasks.append(task)
        await asyncio.gather(*tasks)
        if stopped:
            break

empty_req = PromptRequest(prompt="", n_predict=0)

app = FastAPI()

from collections import deque
rest_req_deque = deque(maxlen=128)
request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}

async def stream_generator(token_queue, request_id):
    index = 0
    while True:
        if not token_queue.empty():
            remain, token = await token_queue.get()
            response = {
                "index": index,
                "message": {"role": "assistant", "content": token},
                "finish_reason": None,
            }
            yield json.dumps(response) + "\n"
            index = index + 1
            if remain == 0:
                response = {
                    "index": index,
                    "message": {"role": "assistant", "content": None},
                    "finish_reason": "length",
                }
                yield json.dumps(response) + "\n"
                break
        else:
            await asyncio.sleep(0)
    streamer_dict.pop(request_id, None)


async def generator(token_queue, request_id):
    while True:
        if not token_queue.empty():
            remain, token = await token_queue.get()
            yield token
            if remain == 0:
                break
        else:
            await asyncio.sleep(0)
    streamer_dict.pop(request_id, None)


@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4())
    await request_queue.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0.1)
        if request_id in result_dict:
            output_str = result_dict.pop(request_id)
            return {"generated_text": output_str}
        await asyncio.sleep(0)
        if request_id in streamer_dict:
            output_str = []
            token_queue = streamer_dict[request_id]
            async for item in generator(token_queue, request_id):
                output_str.append(item)

            return {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "".join(output_str),
                },
                "finish_reason": "stop",
            }


@app.post("/generate_stream/")
async def generate_stream(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4()) + "stream"
    await request_queue.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0)
        if request_id in streamer_dict:
            token_queue = streamer_dict[request_id]

            return StreamingResponse(
                stream_generator(token_queue, request_id), media_type="application/json"
            )


async def process_requests():
@@ -164,7 +291,9 @@ async def process_requests():
            while rest_req_deque:
                request_id, rest_request = rest_req_deque.popleft()
                prompt = rest_request.prompt
                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
                cur_prompt_len = tokenizer(
                    prompt_request.prompt, return_tensors="pt"
                ).input_ids.size(1)
                cur_batched_tokens += cur_prompt_len
                if cur_batched_tokens > max_num_batched_tokens:
                    cur_batched_tokens -= cur_prompt_len
@@ -179,9 +308,9 @@ async def process_requests():
                if request_queue.empty():
                    break
                request_id, prompt_request = await request_queue.get()
                # import pdb
                # pdb.set_trace()
                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
                cur_prompt_len = tokenizer(
                    prompt_request.prompt, return_tensors="pt"
                ).input_ids.size(1)
                cur_batched_tokens += cur_prompt_len
                if cur_batched_tokens > max_num_batched_tokens:
                    cur_batched_tokens -= cur_prompt_len
@@ -193,21 +322,28 @@ async def process_requests():
        if local_rank == 0 and prompt_requests:
            object_list = prompt_requests
            if len(object_list) < max_num_seqs:
                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
                object_list = object_list + [empty_req] * (
                    max_num_seqs - len(object_list)
                )
            logger.info(
                f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}"
            )
            dist.broadcast_object_list(object_list, src=0)

            start_time = time.time()
            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
            await generate_stream_gate(
                [req.prompt for req in object_list],
                [req.n_predict for req in object_list],
                request_ids,
            )

            generate_time = time.time() - start_time
            outputs = outputs.cpu()
            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            output_strs = output_strs[:len(prompt_requests)]

            for request_id, output_str in zip(request_ids, output_strs):
                result_dict[request_id] = output_str
            logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}")
            logger.info(
                f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}"
            )

        await asyncio.sleep(0.1)
        await asyncio.sleep(0)


@app.on_event("startup")
@@ -216,19 +352,44 @@ async def startup_event():
    asyncio.create_task(process_requests())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default='sym_int4',
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')
    parser.add_argument('--max-num-batched-tokens', type=int, default=4096,
                        help='Max tokens can be batched by this service.')
    parser.add_argument('--max-num-seqs', type=int, default=8,
                        help='Max requests can be batched by this service.')
async def main():
    parser = argparse.ArgumentParser(
        description="Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP"
    )
    parser.add_argument(
        "--repo-id-or-model-path",
        type=str,
        default="meta-llama/Llama-2-7b-chat-hf",
        help="The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded"
        ", or the path to the huggingface checkpoint folder",
    )
    parser.add_argument(
        "--low-bit",
        type=str,
        default="sym_int4",
        help="The quantization type the model will convert to.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="The port number on which the server will run.",
    )
    parser.add_argument(
        "--max-num-batched-tokens",
        type=int,
        default=4096,
        help="Max tokens can be batched by this service.",
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=8,
        help="Max requests can be batched by this service.",
    )

    global max_num_seqs
    global max_num_batched_tokens

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
@@ -236,10 +397,21 @@ if __name__ == "__main__":
    max_num_seqs = args.max_num_seqs
    max_num_batched_tokens = args.max_num_batched_tokens
    load_model(model_path, low_bit)

    config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
    server = uvicorn.Server(config)

    if local_rank == 0:
        uvicorn.run(app, host="0.0.0.0", port=args.port)
        await server.serve()
    else:
        while True:
            object_list = [None] * max_num_seqs
            dist.broadcast_object_list(object_list, src=0)
            output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
            await generate_stream_gate(
                [req.prompt for req in object_list],
                [req.n_predict for req in object_list],
            )


if __name__ == "__main__":
    asyncio.run(main())

@@ -33,5 +33,4 @@ export TORCH_LLM_ALLREDUCE=0
export WORLD_SIZE=2

mpirun -np $NUM_GPUS --prepend-rank \
  python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
  python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'fp8' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192

python/llm/src/ipex_llm/transformers/streamer.py (new file, 114 lines)
@@ -0,0 +1,114 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
#

from typing import Optional, List

import torch
from transformers import TextIteratorStreamer


class BatchTextIteratorStreamer(TextIteratorStreamer):
    """
    A specialized version of TextIteratorStreamer that handles text streams in batches, providing
    an efficient way to process large volumes of text data asynchronously. This class is designed
    to aggregate multiple texts into batches, making it ideal for applications that need to
    perform batch operations on streamed text data, such as bulk text processing or machine
    learning model inference in an interactive environment.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will
            block indefinitely. Useful to handle exceptions
            in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.
        batch_size (`int`):
            The size of the batches to process. This parameter must be specified and
            determines how many texts are processed together as a single batch.
    """

    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size = batch_size
        self.token_cache = [[] for _ in range(batch_size)]
        self.print_len = [0 for _ in range(batch_size)]
        self.generate_exception = None

    def put(self, value):
        if len(value.shape) != 2:
            value = torch.reshape(
                value, (self.batch_size, value.shape[0] // self.batch_size)
            )

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        printable_texts = list()
        for idx in range(self.batch_size):
            self.token_cache[idx].extend(value[idx].tolist())
            text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)

            if text.endswith("\n"):
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            # If the last token is a CJK character, we print the characters.
            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
                printable_text = text[self.print_len[idx]:]
                self.print_len[idx] += len(printable_text)
            else:
                printable_text = text[self.print_len[idx]:text.rfind(" ") + 1]
                self.print_len[idx] += len(printable_text)
            printable_texts.append(printable_text)

        self.on_finalized_text(printable_texts)

    def end(self):
        printable_texts = list()
        for idx in range(self.batch_size):
            if len(self.token_cache[idx]) > 0:
                text = self.tokenizer.decode(
                    self.token_cache[idx], **self.decode_kwargs
                )
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            else:
                printable_text = ""
            printable_texts.append(printable_text)

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_texts, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        self.text_queue.put(texts, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)
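
As a rough usage sketch (not part of the diff), `BatchTextIteratorStreamer` can be driven the same way the upstream `TextIteratorStreamer` is: pass it to `model.generate()` running in a background thread and iterate over it on the main thread, receiving one list of decoded chunks per step (one entry per sequence in the batch). The model path and prompts below are placeholders.

```python
# Hypothetical usage sketch for BatchTextIteratorStreamer (not from the diff).
# Assumes a Hugging Face causal LM and tokenizer can be loaded from model_path.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer
from ipex_llm.transformers.streamer import BatchTextIteratorStreamer

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_path)

prompts = ["What is AI?", "Define machine learning."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

streamer = BatchTextIteratorStreamer(
    batch_size=len(prompts),
    tokenizer=tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

# generate() pushes decoded chunks into the streamer; iterate on the main thread.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=32, streamer=streamer),
)
thread.start()

for chunk_batch in streamer:  # each item is a List[str], one chunk per prompt
    print(chunk_batch)
thread.join()
```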