[LLM] Reopen autotp generate_stream (#11120)

* reopen autotp generate_stream
* fix style error
* update
parent 1dc680341b · commit 63e95698eb

4 changed files with 414 additions and 69 deletions
@@ -58,6 +58,8 @@ If you successfully run the serving, you can get output like this:

We can use `curl` to test the serving API.

+#### generate()
+
```bash
# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
export http_proxy=
@@ -77,10 +79,68 @@ And you should get output like this:

```json
{
-  "generated_text": "What is AI? Artificial intelligence (AI) refers to the development of computer systems able to perform tasks that would normally require human intelligence, such as visual perception, speech",
-  "generate_time": "0.45149803161621094s"
+  "index": 0,
+  "message": {
+    "role": "assistant",
+    "content": "\n\nArtificial intelligence (AI) is a branch of computer science that deals with the creation of intelligent machines that can perform tasks that typically "
+  },
+  "finish_reason": "stop"
}

```
+
+#### generate_stream()
+
+```bash
+# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
+export http_proxy=
+export https_proxy=
+
+curl -X 'POST' \
+  'http://127.0.0.1:8000/generate_stream/' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "prompt": "What is AI?",
+  "n_predict": 32
+}'
+```
+
+And you should get output like this:
+
+```json
+{"index": 0, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
+{"index": 1, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
+{"index": 2, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 3, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 4, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 5, "message": {"role": "assistant", "content": "Artificial "}, "finish_reason": null}
+{"index": 6, "message": {"role": "assistant", "content": "intelligence "}, "finish_reason": null}
+{"index": 7, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 8, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 9, "message": {"role": "assistant", "content": "(AI) "}, "finish_reason": null}
+{"index": 10, "message": {"role": "assistant", "content": "is "}, "finish_reason": null}
+{"index": 11, "message": {"role": "assistant", "content": "a "}, "finish_reason": null}
+{"index": 12, "message": {"role": "assistant", "content": "branch "}, "finish_reason": null}
+{"index": 13, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
+{"index": 14, "message": {"role": "assistant", "content": "computer "}, "finish_reason": null}
+{"index": 15, "message": {"role": "assistant", "content": "science "}, "finish_reason": null}
+{"index": 16, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 17, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 18, "message": {"role": "assistant", "content": "deals "}, "finish_reason": null}
+{"index": 19, "message": {"role": "assistant", "content": "with "}, "finish_reason": null}
+{"index": 20, "message": {"role": "assistant", "content": "the "}, "finish_reason": null}
+{"index": 21, "message": {"role": "assistant", "content": "creation "}, "finish_reason": null}
+{"index": 22, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
+{"index": 23, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 24, "message": {"role": "assistant", "content": "intelligent "}, "finish_reason": null}
+{"index": 25, "message": {"role": "assistant", "content": "machines "}, "finish_reason": null}
+{"index": 26, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 27, "message": {"role": "assistant", "content": "can "}, "finish_reason": null}
+{"index": 28, "message": {"role": "assistant", "content": "perform "}, "finish_reason": null}
+{"index": 29, "message": {"role": "assistant", "content": "tasks "}, "finish_reason": null}
+{"index": 30, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 31, "message": {"role": "assistant", "content": "typically "}, "finish_reason": null}
+{"index": 32, "message": {"role": "assistant", "content": null}, "finish_reason": "length"}
+```

**Important**: The first token latency is much larger than the latency of the rest of the tokens; you can use [our benchmark tool](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/dev/benchmark/README.md) to obtain more details about first and rest token latency.
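For readers trying the new endpoint from Python rather than `curl`: the stream is newline-delimited JSON, one object per generated chunk, with the final chunk carrying `"finish_reason": "length"`. The sketch below is illustrative only and not part of this commit; it assumes the example server above is reachable at `http://127.0.0.1:8000` and that the third-party `requests` package is installed.

```python
import json

import requests  # third-party HTTP client, assumed to be available

payload = {"prompt": "What is AI?", "n_predict": 32}

# stream=True lets us consume the newline-delimited JSON objects as they arrive.
with requests.post(
    "http://127.0.0.1:8000/generate_stream/", json=payload, stream=True, timeout=600
) as response:
    response.raise_for_status()
    for line in response.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        # The last chunk has content == None and finish_reason == "length".
        if chunk["finish_reason"] is None:
            print(chunk["message"]["content"], end="", flush=True)
print()
```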
@@ -18,17 +18,25 @@ import os
import torch
import transformers
import time
+import json
import argparse
import torch.distributed as dist

from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+
from pydantic import BaseModel
import uvicorn
+from threading import Thread
+from ipex_llm.transformers.streamer import BatchTextIteratorStreamer


import asyncio, uuid
from collections import deque
from typing import Dict, List, Optional

from transformers.utils import logging

logger = logging.get_logger(__name__)

from ipex_llm.utils.benchmark_util import BenchmarkWrapper
@@ -42,6 +50,7 @@ def get_int_from_env(env_keys, default):
            return val
    return int(default)


global max_num_seqs
global max_num_batched_tokens

@@ -53,6 +62,19 @@ os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")

global model, tokenizer

+
+class PromptRequest(BaseModel):
+    prompt: str
+    n_predict: int = 32
+
+
+rest_req_deque = deque(maxlen=128)
+request_queue: asyncio.Queue = asyncio.Queue()
+result_dict: Dict[str, str] = {}
+streamer_dict = {}
+empty_req = PromptRequest(prompt="", n_predict=0)
+
+
def load_model(model_path, low_bit):

    from ipex_llm import optimize_model
@@ -61,7 +83,9 @@ def load_model(model_path, low_bit):
    import time
    import argparse

-    from transformers import AutoModelForCausalLM  # export AutoModelForCausalLM from transformers so that deepspeed use it
+    from transformers import (
+        AutoModelForCausalLM,
+    )  # export AutoModelForCausalLM from transformers so that deepspeed use it
    from transformers import LlamaTokenizer, AutoTokenizer
    import deepspeed
    from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
@@ -73,12 +97,14 @@ def load_model(model_path, low_bit):
    current_accel = CPU_Accelerator()
    set_accelerator(current_accel)
    global model, tokenizer
-    model = AutoModelForCausalLM.from_pretrained(model_path,
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        trust_remote_code=True,
-                                                 use_cache=True)
+        use_cache=True,
+    )

    model = deepspeed.init_inference(
        model,
@@ -89,14 +115,14 @@ def load_model(model_path, low_bit):

    # Use IPEX-LLM `optimize_model` to convert the model into optimized low bit format
    # Convert the rest of the model into float16 to reduce allreduce traffic
-    model = optimize_model(model.module.to(f'cpu'), low_bit=low_bit).to(torch.float16)
+    model = optimize_model(model.module.to(f"cpu"), low_bit=low_bit).to(torch.float16)

    # Next, use XPU as accelerator to speed up inference
    current_accel = XPU_Accelerator()
    set_accelerator(current_accel)

    # Move model back to xpu
-    model = model.to(f'xpu:{local_rank}')
+    model = model.to(f"xpu:{local_rank}")
    model = BenchmarkWrapper(model)

    # Modify backend related settings
@@ -104,55 +130,156 @@ def load_model(model_path, low_bit):
        get_accelerator().set_device(local_rank)
    dist_backend = get_accelerator().communication_backend_name()
    import deepspeed.comm.comm

    deepspeed.comm.comm.cdb = None
    from deepspeed.comm.comm import init_distributed

    init_distributed()

    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path, trust_remote_code=True, padding_side="left"
+    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


-def generate_text(prompt: List[str], n_predict = 32):
+async def generate_stream_gate(prompt: List[str], n_predict=32, request_ids=[]):
    while prompt[-1] == "":
        prompt = prompt[:-1]
    if isinstance(n_predict, list):
        n_predict = max(n_predict)

    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
-    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
-    # print(input_ids)
-    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
-    output = model.generate(input_ids,
-                            attention_mask=attention_mask,
-                            use_cache=True)
-    return output
+    input_ids = inputs.input_ids.to(f"xpu:{local_rank}")
+    attention_mask = inputs.attention_mask.to(f"xpu:{local_rank}")
+
+    for request_id in request_ids:
+        if request_id not in streamer_dict:
+            streamer_dict[request_id] = asyncio.Queue()
+
+    streamer = BatchTextIteratorStreamer(
+        tokenizer=tokenizer,
+        timeout=600,
+        skip_prompt=True,
+        skip_special_tokens=True,
+        batch_size=len(prompt),
+    )
+
+    generated_kwargs = dict(
+        max_new_tokens=n_predict,
+        min_new_tokens=n_predict,
+        streamer=streamer,
+        attention_mask=attention_mask,
+        do_sample=False,
+    )
+
+    def model_generate():
+        model.generate(input_ids, **generated_kwargs)
+        torch.xpu.empty_cache()
+        torch.xpu.synchronize()
+
+    t1 = Thread(target=model_generate)
+    t1.start()
+
+    stopped = False
+
+    async def put_item(queue, item):
+        await queue.put(item)
+
+    for i in range(n_predict):
+        tasks = []
+        try:
+            output_token = next(streamer)
+        except StopIteration:
+            stopped = True
+        for index, request_id in enumerate(request_ids):
+            task = asyncio.create_task(
+                put_item(
+                    streamer_dict[request_id],
+                    (0 if stopped else n_predict - 1 - i, output_token[index]),
+                )
+            )
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+        if stopped:
+            break


-class PromptRequest(BaseModel):
-    prompt: str
-    n_predict: int = 32
-
-empty_req = PromptRequest(prompt="", n_predict=0)

app = FastAPI()

-from collections import deque
-rest_req_deque = deque(maxlen=128)
-request_queue: asyncio.Queue = asyncio.Queue()
-result_dict: Dict[str, str] = {}
+
+async def stream_generator(token_queue, request_id):
+    index = 0
+    while True:
+        if not token_queue.empty():
+            remain, token = await token_queue.get()
+            response = {
+                "index": index,
+                "message": {"role": "assistant", "content": token},
+                "finish_reason": None,
+            }
+            yield json.dumps(response) + "\n"
+            index = index + 1
+            if remain == 0:
+                response = {
+                    "index": index,
+                    "message": {"role": "assistant", "content": None},
+                    "finish_reason": "length",
+                }
+                yield json.dumps(response) + "\n"
+                break
+        else:
+            await asyncio.sleep(0)
+    streamer_dict.pop(request_id, None)
+
+
+async def generator(token_queue, request_id):
+    while True:
+        if not token_queue.empty():
+            remain, token = await token_queue.get()
+            yield token
+            if remain == 0:
+                break
+        else:
+            await asyncio.sleep(0)
+    streamer_dict.pop(request_id, None)
+

@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4())
    await request_queue.put((request_id, prompt_request))
    while True:
-        await asyncio.sleep(0.1)
-        if request_id in result_dict:
-            output_str = result_dict.pop(request_id)
-            return {"generated_text": output_str}
+        await asyncio.sleep(0)
+        if request_id in streamer_dict:
+            output_str = []
+            token_queue = streamer_dict[request_id]
+            async for item in generator(token_queue, request_id):
+                output_str.append(item)
+
+            return {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "".join(output_str),
+                },
+                "finish_reason": "stop",
+            }
+
+
+@app.post("/generate_stream/")
+async def generate_stream(prompt_request: PromptRequest):
+    request_id = str(uuid.uuid4()) + "stream"
+    await request_queue.put((request_id, prompt_request))
+    while True:
+        await asyncio.sleep(0)
+        if request_id in streamer_dict:
+            token_queue = streamer_dict[request_id]
+
+            return StreamingResponse(
+                stream_generator(token_queue, request_id), media_type="application/json"
+            )


async def process_requests():
@@ -164,7 +291,9 @@ async def process_requests():
            while rest_req_deque:
                request_id, rest_request = rest_req_deque.popleft()
                prompt = rest_request.prompt
-                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
+                cur_prompt_len = tokenizer(
+                    prompt_request.prompt, return_tensors="pt"
+                ).input_ids.size(1)
                cur_batched_tokens += cur_prompt_len
                if cur_batched_tokens > max_num_batched_tokens:
                    cur_batched_tokens -= cur_prompt_len
@@ -179,9 +308,9 @@ async def process_requests():
                if request_queue.empty():
                    break
                request_id, prompt_request = await request_queue.get()
-                # import pdb
-                # pdb.set_trace()
-                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
+                cur_prompt_len = tokenizer(
+                    prompt_request.prompt, return_tensors="pt"
+                ).input_ids.size(1)
                cur_batched_tokens += cur_prompt_len
                if cur_batched_tokens > max_num_batched_tokens:
                    cur_batched_tokens -= cur_prompt_len
@@ -193,21 +322,28 @@ async def process_requests():
        if local_rank == 0 and prompt_requests:
            object_list = prompt_requests
            if len(object_list) < max_num_seqs:
-                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
-            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
+                object_list = object_list + [empty_req] * (
+                    max_num_seqs - len(object_list)
+                )
+            logger.info(
+                f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}"
+            )
            dist.broadcast_object_list(object_list, src=0)

            start_time = time.time()
-            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            await generate_stream_gate(
+                [req.prompt for req in object_list],
+                [req.n_predict for req in object_list],
+                request_ids,
+            )

            generate_time = time.time() - start_time
-            outputs = outputs.cpu()
-            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-            output_strs = output_strs[:len(prompt_requests)]

            for request_id, output_str in zip(request_ids, output_strs):
                result_dict[request_id] = output_str
-            logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}")
+            logger.info(
+                f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}"
+            )

-        await asyncio.sleep(0.1)
+        await asyncio.sleep(0)


@app.on_event("startup")
@@ -216,19 +352,44 @@ async def startup_event():
        asyncio.create_task(process_requests())


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
-    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
-                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
-                             ', or the path to the huggingface checkpoint folder')
-    parser.add_argument('--low-bit', type=str, default='sym_int4',
-                    help='The quantization type the model will convert to.')
-    parser.add_argument('--port', type=int, default=8000,
-                    help='The port number on which the server will run.')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=4096,
-                    help='Max tokens can be batched by this service.')
-    parser.add_argument('--max-num-seqs', type=int, default=8,
-                    help='Max requests can be batched by this service.')
+async def main():
+    parser = argparse.ArgumentParser(
+        description="Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP"
+    )
+    parser.add_argument(
+        "--repo-id-or-model-path",
+        type=str,
+        default="meta-llama/Llama-2-7b-chat-hf",
+        help="The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded"
+        ", or the path to the huggingface checkpoint folder",
+    )
+    parser.add_argument(
+        "--low-bit",
+        type=str,
+        default="sym_int4",
+        help="The quantization type the model will convert to.",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="The port number on which the server will run.",
+    )
+    parser.add_argument(
+        "--max-num-batched-tokens",
+        type=int,
+        default=4096,
+        help="Max tokens can be batched by this service.",
+    )
+    parser.add_argument(
+        "--max-num-seqs",
+        type=int,
+        default=8,
+        help="Max requests can be batched by this service.",
+    )
+
+    global max_num_seqs
+    global max_num_batched_tokens

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
@@ -236,10 +397,21 @@ if __name__ == "__main__":
    max_num_seqs = args.max_num_seqs
    max_num_batched_tokens = args.max_num_batched_tokens
    load_model(model_path, low_bit)

+    config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
+    server = uvicorn.Server(config)
+
    if local_rank == 0:
-        uvicorn.run(app, host="0.0.0.0", port=args.port)
+        await server.serve()
    else:
        while True:
            object_list = [None] * max_num_seqs
            dist.broadcast_object_list(object_list, src=0)
-            output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            await generate_stream_gate(
+                [req.prompt for req in object_list],
+                [req.n_predict for req in object_list],
+            )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
@@ -33,5 +33,4 @@ export TORCH_LLM_ALLREDUCE=0
export WORLD_SIZE=2

mpirun -np $NUM_GPUS --prepend-rank \
-        python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
-
+        python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'fp8' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
python/llm/src/ipex_llm/transformers/streamer.py (new file, 114 lines)

@@ -0,0 +1,114 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
#

from typing import Optional, List

import torch
from transformers import TextIteratorStreamer


class BatchTextIteratorStreamer(TextIteratorStreamer):
    """
    A specialized version of TextIteratorStreamer that handles text streams in batches, providing
    an efficient way to process large volumes of text data asynchronously. This class is designed
    to aggregate multiple texts into batches, making it ideal for applications that need to
    perform batch operations on streamed text data, such as bulk text processing or machine
    learning model inference in an interactive environment.

        Parameters:
                tokenizer (`AutoTokenizer`):
                        The tokenizer used to decode the tokens.
                skip_prompt (`bool`, *optional*, defaults to `False`):
                        Whether to skip the prompt to `.generate()` or not.
                timeout (`float`, *optional*):
                        The timeout for the text queue. If `None`, the queue will
                        block indefinitely. Useful to handle exceptions
                        in `.generate()`, when it is called in a separate thread.
                decode_kwargs (`dict`, *optional*):
                        Additional keyword arguments to pass to the tokenizer's `decode` method.
                batch_size (`int`):
                        The size of the batches to process. This parameter must be specified and
                        determines how many texts are processed together as a single batch.
    """

    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size = batch_size
        self.token_cache = [[] for _ in range(batch_size)]
        self.print_len = [0 for _ in range(batch_size)]
        self.generate_exception = None

    def put(self, value):
        if len(value.shape) != 2:
            value = torch.reshape(
                value, (self.batch_size, value.shape[0] // self.batch_size)
            )

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        printable_texts = list()
        for idx in range(self.batch_size):
            self.token_cache[idx].extend(value[idx].tolist())
            text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)

            if text.endswith("\n"):
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
                # If the last token is a CJK character, we print the characters.
            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
                printable_text = text[self.print_len[idx]:]
                self.print_len[idx] += len(printable_text)
            else:
                printable_text = text[self.print_len[idx]:text.rfind(" ") + 1]
                self.print_len[idx] += len(printable_text)
            printable_texts.append(printable_text)

        self.on_finalized_text(printable_texts)

    def end(self):
        printable_texts = list()
        for idx in range(self.batch_size):
            if len(self.token_cache[idx]) > 0:
                text = self.tokenizer.decode(
                    self.token_cache[idx], **self.decode_kwargs
                )
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            else:
                printable_text = ""
            printable_texts.append(printable_text)

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_texts, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        self.text_queue.put(texts, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)
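For context, `BatchTextIteratorStreamer` is consumed the same way as the `transformers` `TextIteratorStreamer` it subclasses: pass it to `generate()` running in a background thread and iterate over it, except that each iteration yields a list of `batch_size` text fragments instead of a single string. The sketch below is illustrative only; it is not part of this commit, uses a placeholder model path and prompts, and omits the IPEX-LLM low-bit optimization and XPU placement that the serving script performs.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer

from ipex_llm.transformers.streamer import BatchTextIteratorStreamer

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint path
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_path)

prompts = ["What is AI?", "What is deep learning?"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# One streamer serves the whole batch; skip_special_tokens is forwarded to decode().
streamer = BatchTextIteratorStreamer(
    batch_size=len(prompts),
    tokenizer=tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

# generate() blocks until the sequences finish, so run it in a thread
# and consume the streamer from the caller (as serving.py does).
thread = Thread(
    target=model.generate,
    kwargs=dict(
        inputs=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=32,
        do_sample=False,
        streamer=streamer,
    ),
)
thread.start()

# Each item from the streamer is a List[str]: one newly decoded fragment per prompt.
for fragments in streamer:
    for i, fragment in enumerate(fragments):
        print(f"[prompt {i}] {fragment!r}")
thread.join()
```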