[LLM] Reopen autotp generate_stream (#11120)

* reopen autotp generate_stream
* fix style error
* update

parent 1dc680341b
commit 63e95698eb

4 changed files with 414 additions and 69 deletions
@@ -58,6 +58,8 @@ If you successfully run the serving, you can get output like this:
 We can use `curl` to test serving api
 
+#### generate()
+
 ```bash
 # Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
 export http_proxy=
@@ -77,10 +79,68 @@ And you should get output like this:
 ```json
 {
-  "generated_text": "What is AI? Artificial intelligence (AI) refers to the development of computer systems able to perform tasks that would normally require human intelligence, such as visual perception, speech",
-  "generate_time": "0.45149803161621094s"
+  "index": 0,
+  "message": {
+    "role": "assistant",
+    "content": "\n\nArtificial intelligence (AI) is a branch of computer science that deals with the creation of intelligent machines that can perform tasks that typically "
+  },
+  "finish_reason": "stop"
 }
+```
+
+#### generate_stream()
+
+```bash
+# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
+export http_proxy=
+export https_proxy=
+
+curl -X 'POST' \
+  'http://127.0.0.1:8000/generate_stream/' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "prompt": "What is AI?",
+  "n_predict": 32
+}'
+```
+
+And you should get output like this:
+
+```json
+{"index": 0, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
+{"index": 1, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
+{"index": 2, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 3, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 4, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 5, "message": {"role": "assistant", "content": "Artificial "}, "finish_reason": null}
+{"index": 6, "message": {"role": "assistant", "content": "intelligence "}, "finish_reason": null}
+{"index": 7, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 8, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 9, "message": {"role": "assistant", "content": "(AI) "}, "finish_reason": null}
+{"index": 10, "message": {"role": "assistant", "content": "is "}, "finish_reason": null}
+{"index": 11, "message": {"role": "assistant", "content": "a "}, "finish_reason": null}
+{"index": 12, "message": {"role": "assistant", "content": "branch "}, "finish_reason": null}
+{"index": 13, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
+{"index": 14, "message": {"role": "assistant", "content": "computer "}, "finish_reason": null}
+{"index": 15, "message": {"role": "assistant", "content": "science "}, "finish_reason": null}
+{"index": 16, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 17, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 18, "message": {"role": "assistant", "content": "deals "}, "finish_reason": null}
+{"index": 19, "message": {"role": "assistant", "content": "with "}, "finish_reason": null}
+{"index": 20, "message": {"role": "assistant", "content": "the "}, "finish_reason": null}
+{"index": 21, "message": {"role": "assistant", "content": "creation "}, "finish_reason": null}
+{"index": 22, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
+{"index": 23, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 24, "message": {"role": "assistant", "content": "intelligent "}, "finish_reason": null}
+{"index": 25, "message": {"role": "assistant", "content": "machines "}, "finish_reason": null}
+{"index": 26, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 27, "message": {"role": "assistant", "content": "can "}, "finish_reason": null}
+{"index": 28, "message": {"role": "assistant", "content": "perform "}, "finish_reason": null}
+{"index": 29, "message": {"role": "assistant", "content": "tasks "}, "finish_reason": null}
+{"index": 30, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 31, "message": {"role": "assistant", "content": "typically "}, "finish_reason": null}
+{"index": 32, "message": {"role": "assistant", "content": null}, "finish_reason": "length"}
+
 ```
 
 **Important**: The first token latency is much larger than rest token latency, you could use [our benchmark tool](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/dev/benchmark/README.md) to obtain more details about first and rest token latency.
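As an illustrative aside (not part of this commit's diff), the new `/generate_stream/` endpoint documented above can also be consumed from Python instead of `curl`. The snippet below is a minimal client sketch, assuming the server from this example is running locally on port 8000 and that the third-party `requests` package is installed; it simply reads the one-JSON-object-per-line stream shown in the sample output.

```python
# Minimal streaming client for the /generate_stream/ endpoint shown above (illustrative sketch).
import json

import requests

payload = {"prompt": "What is AI?", "n_predict": 32}

# stream=True keeps the HTTP connection open so tokens can be read as they arrive
with requests.post(
    "http://127.0.0.1:8000/generate_stream/", json=payload, stream=True, timeout=600
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)  # one JSON object per line, as in the sample output above
        if chunk["message"]["content"] is not None:
            print(chunk["message"]["content"], end="", flush=True)
        if chunk["finish_reason"] is not None:
            break
```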
@@ -18,17 +18,25 @@ import os
 import torch
 import transformers
 import time
+import json
 import argparse
 import torch.distributed as dist
 
 from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+
 from pydantic import BaseModel
 import uvicorn
+from threading import Thread
+from ipex_llm.transformers.streamer import BatchTextIteratorStreamer
 
 
 import asyncio, uuid
+from collections import deque
 from typing import Dict, List, Optional
 
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
 from ipex_llm.utils.benchmark_util import BenchmarkWrapper
@@ -42,17 +50,31 @@ def get_int_from_env(env_keys, default):
             return val
     return int(default)
 
 
 global max_num_seqs
 global max_num_batched_tokens
 
-local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
-world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
+local_rank = get_int_from_env(["LOCAL_RANK", "PMI_RANK"], "0")
+world_size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE"], "1")
 os.environ["RANK"] = str(local_rank)
 os.environ["WORLD_SIZE"] = str(world_size)
 os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
 
 global model, tokenizer
 
 
+class PromptRequest(BaseModel):
+    prompt: str
+    n_predict: int = 32
+
+
+rest_req_deque = deque(maxlen=128)
+request_queue: asyncio.Queue = asyncio.Queue()
+result_dict: Dict[str, str] = {}
+streamer_dict = {}
+empty_req = PromptRequest(prompt="", n_predict=0)
+
+
 def load_model(model_path, low_bit):
 
     from ipex_llm import optimize_model
@@ -61,7 +83,9 @@ def load_model(model_path, low_bit):
     import time
     import argparse
 
-    from transformers import AutoModelForCausalLM  # export AutoModelForCausalLM from transformers so that deepspeed use it
+    from transformers import (
+        AutoModelForCausalLM,
+    )  # export AutoModelForCausalLM from transformers so that deepspeed use it
     from transformers import LlamaTokenizer, AutoTokenizer
     import deepspeed
     from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
@@ -73,12 +97,14 @@ def load_model(model_path, low_bit):
     current_accel = CPU_Accelerator()
     set_accelerator(current_accel)
     global model, tokenizer
-    model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                 device_map={"": "cpu"},
-                                                 low_cpu_mem_usage=True,
-                                                 torch_dtype=torch.float16,
-                                                 trust_remote_code=True,
-                                                 use_cache=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        device_map={"": "cpu"},
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.float16,
+        trust_remote_code=True,
+        use_cache=True,
+    )
 
     model = deepspeed.init_inference(
         model,
@@ -89,70 +115,171 @@ def load_model(model_path, low_bit):
 
     # Use IPEX-LLM `optimize_model` to convert the model into optimized low bit format
     # Convert the rest of the model into float16 to reduce allreduce traffic
-    model = optimize_model(model.module.to(f'cpu'), low_bit=low_bit).to(torch.float16)
+    model = optimize_model(model.module.to(f"cpu"), low_bit=low_bit).to(torch.float16)
 
     # Next, use XPU as accelerator to speed up inference
     current_accel = XPU_Accelerator()
     set_accelerator(current_accel)
 
     # Move model back to xpu
-    model = model.to(f'xpu:{local_rank}')
+    model = model.to(f"xpu:{local_rank}")
     model = BenchmarkWrapper(model)
 
     # Modify backend related settings
     if world_size > 1:
         get_accelerator().set_device(local_rank)
         dist_backend = get_accelerator().communication_backend_name()
         import deepspeed.comm.comm
 
         deepspeed.comm.comm.cdb = None
         from deepspeed.comm.comm import init_distributed
 
         init_distributed()
 
     # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path, trust_remote_code=True, padding_side="left"
+    )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-def generate_text(prompt: List[str], n_predict = 32):
+
+async def generate_stream_gate(prompt: List[str], n_predict=32, request_ids=[]):
     while prompt[-1] == "":
         prompt = prompt[:-1]
     if isinstance(n_predict, list):
         n_predict = max(n_predict)
 
     inputs = tokenizer(prompt, return_tensors="pt", padding=True)
-    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
-    # print(input_ids)
-    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
-    output = model.generate(input_ids,
-                            attention_mask=attention_mask,
-                            max_new_tokens=n_predict,
-                            use_cache=True)
-    torch.xpu.synchronize()
-    return output
-
-
-class PromptRequest(BaseModel):
-    prompt: str
-    n_predict: int = 32
-
-empty_req = PromptRequest(prompt="", n_predict=0)
+    input_ids = inputs.input_ids.to(f"xpu:{local_rank}")
+    attention_mask = inputs.attention_mask.to(f"xpu:{local_rank}")
+
+    for request_id in request_ids:
+        if request_id not in streamer_dict:
+            streamer_dict[request_id] = asyncio.Queue()
+
+    streamer = BatchTextIteratorStreamer(
+        tokenizer=tokenizer,
+        timeout=600,
+        skip_prompt=True,
+        skip_special_tokens=True,
+        batch_size=len(prompt),
+    )
+
+    generated_kwargs = dict(
+        max_new_tokens=n_predict,
+        min_new_tokens=n_predict,
+        streamer=streamer,
+        attention_mask=attention_mask,
+        do_sample=False,
+    )
+
+    def model_generate():
+        model.generate(input_ids, **generated_kwargs)
+        torch.xpu.empty_cache()
+        torch.xpu.synchronize()
+
+    t1 = Thread(target=model_generate)
+    t1.start()
+
+    stopped = False
+
+    async def put_item(queue, item):
+        await queue.put(item)
+
+    for i in range(n_predict):
+        tasks = []
+        try:
+            output_token = next(streamer)
+        except StopIteration:
+            stopped = True
+        for index, request_id in enumerate(request_ids):
+            task = asyncio.create_task(
+                put_item(
+                    streamer_dict[request_id],
+                    (0 if stopped else n_predict - 1 - i, output_token[index]),
+                )
+            )
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+        if stopped:
+            break
+
 
 app = FastAPI()
 
-from collections import deque
-rest_req_deque = deque(maxlen=128)
-request_queue: asyncio.Queue = asyncio.Queue()
-result_dict: Dict[str, str] = {}
+
+async def stream_generator(token_queue, request_id):
+    index = 0
+    while True:
+        if not token_queue.empty():
+            remain, token = await token_queue.get()
+            response = {
+                "index": index,
+                "message": {"role": "assistant", "content": token},
+                "finish_reason": None,
+            }
+            yield json.dumps(response) + "\n"
+            index = index + 1
+            if remain == 0:
+                response = {
+                    "index": index,
+                    "message": {"role": "assistant", "content": None},
+                    "finish_reason": "length",
+                }
+                yield json.dumps(response) + "\n"
+                break
+        else:
+            await asyncio.sleep(0)
+    streamer_dict.pop(request_id, None)
+
+
+async def generator(token_queue, request_id):
+    while True:
+        if not token_queue.empty():
+            remain, token = await token_queue.get()
+            yield token
+            if remain == 0:
+                break
+        else:
+            await asyncio.sleep(0)
+    streamer_dict.pop(request_id, None)
 
 
 @app.post("/generate/")
 async def generate(prompt_request: PromptRequest):
     request_id = str(uuid.uuid4())
     await request_queue.put((request_id, prompt_request))
     while True:
-        await asyncio.sleep(0.1)
-        if request_id in result_dict:
-            output_str = result_dict.pop(request_id)
-            return {"generated_text": output_str}
+        await asyncio.sleep(0)
+        if request_id in streamer_dict:
+            output_str = []
+            token_queue = streamer_dict[request_id]
+            async for item in generator(token_queue, request_id):
+                output_str.append(item)
+
+            return {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "".join(output_str),
+                },
+                "finish_reason": "stop",
+            }
+
+
+@app.post("/generate_stream/")
+async def generate_stream(prompt_request: PromptRequest):
+    request_id = str(uuid.uuid4()) + "stream"
+    await request_queue.put((request_id, prompt_request))
+    while True:
+        await asyncio.sleep(0)
+        if request_id in streamer_dict:
+            token_queue = streamer_dict[request_id]
+
+            return StreamingResponse(
+                stream_generator(token_queue, request_id), media_type="application/json"
+            )
 
 
 async def process_requests():
@@ -164,7 +291,9 @@ async def process_requests():
         while rest_req_deque:
             request_id, rest_request = rest_req_deque.popleft()
             prompt = rest_request.prompt
-            cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
+            cur_prompt_len = tokenizer(
+                prompt_request.prompt, return_tensors="pt"
+            ).input_ids.size(1)
             cur_batched_tokens += cur_prompt_len
             if cur_batched_tokens > max_num_batched_tokens:
                 cur_batched_tokens -= cur_prompt_len
@@ -179,9 +308,9 @@ async def process_requests():
             if request_queue.empty():
                 break
             request_id, prompt_request = await request_queue.get()
-            # import pdb
-            # pdb.set_trace()
-            cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
+            cur_prompt_len = tokenizer(
+                prompt_request.prompt, return_tensors="pt"
+            ).input_ids.size(1)
             cur_batched_tokens += cur_prompt_len
             if cur_batched_tokens > max_num_batched_tokens:
                 cur_batched_tokens -= cur_prompt_len
@@ -193,21 +322,28 @@ async def process_requests():
         if local_rank == 0 and prompt_requests:
             object_list = prompt_requests
             if len(object_list) < max_num_seqs:
-                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
-            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
+                object_list = object_list + [empty_req] * (
+                    max_num_seqs - len(object_list)
+                )
+            logger.info(
+                f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}"
+            )
             dist.broadcast_object_list(object_list, src=0)
 
             start_time = time.time()
-            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            await generate_stream_gate(
+                [req.prompt for req in object_list],
+                [req.n_predict for req in object_list],
+                request_ids,
+            )
+
             generate_time = time.time() - start_time
-            outputs = outputs.cpu()
-            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-            output_strs = output_strs[:len(prompt_requests)]
 
-            for request_id, output_str in zip(request_ids, output_strs):
-                result_dict[request_id] = output_str
-            logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}")
+            logger.info(
+                f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}"
+            )
 
-        await asyncio.sleep(0.1)
+        await asyncio.sleep(0)
 
 
 @app.on_event("startup")
@@ -216,19 +352,44 @@ async def startup_event():
     asyncio.create_task(process_requests())
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
-    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
-                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
-                             ', or the path to the huggingface checkpoint folder')
-    parser.add_argument('--low-bit', type=str, default='sym_int4',
-                        help='The quantization type the model will convert to.')
-    parser.add_argument('--port', type=int, default=8000,
-                        help='The port number on which the server will run.')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=4096,
-                        help='Max tokens can be batched by this service.')
-    parser.add_argument('--max-num-seqs', type=int, default=8,
-                        help='Max requests can be batched by this service.')
+async def main():
+    parser = argparse.ArgumentParser(
+        description="Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP"
+    )
+    parser.add_argument(
+        "--repo-id-or-model-path",
+        type=str,
+        default="meta-llama/Llama-2-7b-chat-hf",
+        help="The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded"
+        ", or the path to the huggingface checkpoint folder",
+    )
+    parser.add_argument(
+        "--low-bit",
+        type=str,
+        default="sym_int4",
+        help="The quantization type the model will convert to.",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="The port number on which the server will run.",
+    )
+    parser.add_argument(
+        "--max-num-batched-tokens",
+        type=int,
+        default=4096,
+        help="Max tokens can be batched by this service.",
+    )
+    parser.add_argument(
+        "--max-num-seqs",
+        type=int,
+        default=8,
+        help="Max requests can be batched by this service.",
+    )
+
+    global max_num_seqs
+    global max_num_batched_tokens
 
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
@@ -236,10 +397,21 @@ if __name__ == "__main__":
     max_num_seqs = args.max_num_seqs
     max_num_batched_tokens = args.max_num_batched_tokens
     load_model(model_path, low_bit)
+
+    config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
+    server = uvicorn.Server(config)
+
     if local_rank == 0:
-        uvicorn.run(app, host="0.0.0.0", port=args.port)
+        await server.serve()
     else:
         while True:
             object_list = [None] * max_num_seqs
             dist.broadcast_object_list(object_list, src=0)
-            output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            await generate_stream_gate(
+                [req.prompt for req in object_list],
+                [req.n_predict for req in object_list],
+            )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
@@ -33,5 +33,4 @@ export TORCH_LLM_ALLREDUCE=0
 export WORLD_SIZE=2
 
 mpirun -np $NUM_GPUS --prepend-rank \
-  python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
+  python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'fp8' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
python/llm/src/ipex_llm/transformers/streamer.py (new file, 114 lines)

@@ -0,0 +1,114 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
+#
+
+from typing import Optional, List
+
+import torch
+from transformers import TextIteratorStreamer
+
+
+class BatchTextIteratorStreamer(TextIteratorStreamer):
+    """
+    A specialized version of TextIteratorStreamer that handles text streams in batches, providing
+    an efficient way to process large volumes of text data asynchronously. This class is designed
+    to aggregate multiple texts into batches, making it ideal for applications that need to
+    perform batch operations on streamed text data, such as bulk text processing or machine
+    learning model inference in an interactive environment.
+
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not.
+        timeout (`float`, *optional*):
+            The timeout for the text queue. If `None`, the queue will
+            block indefinitely. Useful to handle exceptions
+            in `.generate()`, when it is called in a separate thread.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.
+        batch_size (`int`):
+            The size of the batches to process. This parameter must be specified and
+            determines how many texts are processed together as a single batch.
+    """
+
+    def __init__(
+        self,
+        batch_size: int,
+        tokenizer: "AutoTokenizer",
+        skip_prompt: bool = False,
+        timeout: Optional[float] = None,
+        **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
+        self.batch_size = batch_size
+        self.token_cache = [[] for _ in range(batch_size)]
+        self.print_len = [0 for _ in range(batch_size)]
+        self.generate_exception = None
+
+    def put(self, value):
+        if len(value.shape) != 2:
+            value = torch.reshape(
+                value, (self.batch_size, value.shape[0] // self.batch_size)
+            )
+
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+
+        printable_texts = list()
+        for idx in range(self.batch_size):
+            self.token_cache[idx].extend(value[idx].tolist())
+            text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)
+
+            if text.endswith("\n"):
+                printable_text = text[self.print_len[idx]:]
+                self.token_cache[idx] = []
+                self.print_len[idx] = 0
+            # If the last token is a CJK character, we print the characters.
+            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
+                printable_text = text[self.print_len[idx]:]
+                self.print_len[idx] += len(printable_text)
+            else:
+                printable_text = text[self.print_len[idx]:text.rfind(" ") + 1]
+                self.print_len[idx] += len(printable_text)
+            printable_texts.append(printable_text)
+
+        self.on_finalized_text(printable_texts)
+
+    def end(self):
+        printable_texts = list()
+        for idx in range(self.batch_size):
+            if len(self.token_cache[idx]) > 0:
+                text = self.tokenizer.decode(
+                    self.token_cache[idx], **self.decode_kwargs
+                )
+                printable_text = text[self.print_len[idx]:]
+                self.token_cache[idx] = []
+                self.print_len[idx] = 0
+            else:
+                printable_text = ""
+            printable_texts.append(printable_text)
+
+        self.next_tokens_are_prompt = True
+        self.on_finalized_text(printable_texts, stream_end=True)
+
+    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
+        self.text_queue.put(texts, timeout=self.timeout)
+        if stream_end:
+            self.text_queue.put(self.stop_signal, timeout=self.timeout)
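To make the batched streaming contract of the class above concrete, here is a small usage sketch (not part of the commit). It assumes any Hugging Face causal LM checkpoint is available; the model id below is just the example default from serving.py and serves as a placeholder. Each iteration of the streamer yields a list with one decoded text fragment per prompt in the batch, which is how `generate_stream_gate` in serving.py fans tokens out to per-request queues.

```python
# Illustrative sketch: streaming a batch of two prompts through BatchTextIteratorStreamer.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer

from ipex_llm.transformers.streamer import BatchTextIteratorStreamer

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder; any causal LM checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # needed so the batch can be padded
model = AutoModelForCausalLM.from_pretrained(model_path)

prompts = ["What is AI?", "What is deep learning?"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

streamer = BatchTextIteratorStreamer(
    batch_size=len(prompts),
    tokenizer=tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

# generate() runs in a background thread; the streamer receives decoded fragments as they appear
Thread(
    target=model.generate,
    kwargs=dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=32,
        streamer=streamer,
    ),
).start()

# Each iteration yields a list of text fragments, one entry per prompt in the batch
for fragments in streamer:
    for i, fragment in enumerate(fragments):
        print(f"[prompt {i}] {fragment!r}")
```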