[LLM]Reopen autotp generate_stream (#11120)
* reopen autotp generate_stream
* fix style error
* update
This commit is contained in:
parent 1dc680341b
commit 63e95698eb

4 changed files with 414 additions and 69 deletions
````diff
@@ -58,6 +58,8 @@ If you successfully run the serving, you can get output like this:
 
 We can use `curl` to test serving api
 
+#### generate()
+
 ```bash
 # Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
 export http_proxy=
@@ -77,10 +79,68 @@ And you should get output like this:
 
 ```json
 {
-  "generated_text": "What is AI? Artificial intelligence (AI) refers to the development of computer systems able to perform tasks that would normally require human intelligence, such as visual perception, speech",
-  "generate_time": "0.45149803161621094s"
+  "index": 0,
+  "message": {
+    "role": "assistant",
+    "content": "\n\nArtificial intelligence (AI) is a branch of computer science that deals with the creation of intelligent machines that can perform tasks that typically "
+  },
+  "finish_reason": "stop"
 }
 
+```
+
+#### generate_stream()
+
+```bash
+# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
+export http_proxy=
+export https_proxy=
+
+curl -X 'POST' \
+  'http://127.0.0.1:8000/generate_stream/' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "prompt": "What is AI?",
+  "n_predict": 32
+}'
+```
+
+And you should get output like this:
+```json
+{"index": 0, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
+{"index": 1, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
+{"index": 2, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 3, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 4, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 5, "message": {"role": "assistant", "content": "Artificial "}, "finish_reason": null}
+{"index": 6, "message": {"role": "assistant", "content": "intelligence "}, "finish_reason": null}
+{"index": 7, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 8, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 9, "message": {"role": "assistant", "content": "(AI) "}, "finish_reason": null}
+{"index": 10, "message": {"role": "assistant", "content": "is "}, "finish_reason": null}
+{"index": 11, "message": {"role": "assistant", "content": "a "}, "finish_reason": null}
+{"index": 12, "message": {"role": "assistant", "content": "branch "}, "finish_reason": null}
+{"index": 13, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
+{"index": 14, "message": {"role": "assistant", "content": "computer "}, "finish_reason": null}
+{"index": 15, "message": {"role": "assistant", "content": "science "}, "finish_reason": null}
+{"index": 16, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 17, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 18, "message": {"role": "assistant", "content": "deals "}, "finish_reason": null}
+{"index": 19, "message": {"role": "assistant", "content": "with "}, "finish_reason": null}
+{"index": 20, "message": {"role": "assistant", "content": "the "}, "finish_reason": null}
+{"index": 21, "message": {"role": "assistant", "content": "creation "}, "finish_reason": null}
+{"index": 22, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
+{"index": 23, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
+{"index": 24, "message": {"role": "assistant", "content": "intelligent "}, "finish_reason": null}
+{"index": 25, "message": {"role": "assistant", "content": "machines "}, "finish_reason": null}
+{"index": 26, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 27, "message": {"role": "assistant", "content": "can "}, "finish_reason": null}
+{"index": 28, "message": {"role": "assistant", "content": "perform "}, "finish_reason": null}
+{"index": 29, "message": {"role": "assistant", "content": "tasks "}, "finish_reason": null}
+{"index": 30, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
+{"index": 31, "message": {"role": "assistant", "content": "typically "}, "finish_reason": null}
+{"index": 32, "message": {"role": "assistant", "content": null}, "finish_reason": "length"}
+
+
 ```
 
 **Important**: The first token latency is much larger than rest token latency, you could use [our benchmark tool](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/dev/benchmark/README.md) to obtain more details about first and rest token latency.
````
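For reference, a minimal client sketch (not part of this commit) for reading the newline-delimited JSON that `/generate_stream/` emits; the host, port, and request body follow the `curl` example above, and the `requests` dependency is an assumption:

```python
import json
import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate_stream/",
    json={"prompt": "What is AI?", "n_predict": 32},
    stream=True,  # keep the connection open and read chunks as they arrive
)
for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    # Each chunk follows the documented format: index, message, finish_reason.
    print(chunk["message"]["content"] or "", end="", flush=True)
    if chunk["finish_reason"] is not None:
        break
```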
````diff
@@ -18,17 +18,25 @@ import os
 import torch
 import transformers
 import time
+import json
 import argparse
 import torch.distributed as dist
 
 from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+
 from pydantic import BaseModel
 import uvicorn
+from threading import Thread
+from ipex_llm.transformers.streamer import BatchTextIteratorStreamer
+
+
 import asyncio, uuid
+from collections import deque
 from typing import Dict, List, Optional
 
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
 from ipex_llm.utils.benchmark_util import BenchmarkWrapper
````
````diff
@@ -42,17 +50,31 @@ def get_int_from_env(env_keys, default):
             return val
     return int(default)
 
+
 global max_num_seqs
 global max_num_batched_tokens
 
-local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
-world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
+local_rank = get_int_from_env(["LOCAL_RANK", "PMI_RANK"], "0")
+world_size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE"], "1")
 os.environ["RANK"] = str(local_rank)
 os.environ["WORLD_SIZE"] = str(world_size)
 os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
 
 global model, tokenizer
 
+
+class PromptRequest(BaseModel):
+    prompt: str
+    n_predict: int = 32
+
+
+rest_req_deque = deque(maxlen=128)
+request_queue: asyncio.Queue = asyncio.Queue()
+result_dict: Dict[str, str] = {}
+streamer_dict = {}
+empty_req = PromptRequest(prompt="", n_predict=0)
+
+
 def load_model(model_path, low_bit):
 
     from ipex_llm import optimize_model
````
````diff
@@ -61,7 +83,9 @@ def load_model(model_path, low_bit):
     import time
     import argparse
 
-    from transformers import AutoModelForCausalLM  # export AutoModelForCausalLM from transformers so that deepspeed use it
+    from transformers import (
+        AutoModelForCausalLM,
+    )  # export AutoModelForCausalLM from transformers so that deepspeed use it
     from transformers import LlamaTokenizer, AutoTokenizer
     import deepspeed
     from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
````
````diff
@@ -73,12 +97,14 @@ def load_model(model_path, low_bit):
     current_accel = CPU_Accelerator()
     set_accelerator(current_accel)
     global model, tokenizer
-    model = AutoModelForCausalLM.from_pretrained(model_path,
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
         device_map={"": "cpu"},
         low_cpu_mem_usage=True,
         torch_dtype=torch.float16,
         trust_remote_code=True,
-                                                 use_cache=True)
+        use_cache=True,
+    )
 
     model = deepspeed.init_inference(
         model,
````
````diff
@@ -89,14 +115,14 @@ def load_model(model_path, low_bit):
 
     # Use IPEX-LLM `optimize_model` to convert the model into optimized low bit format
     # Convert the rest of the model into float16 to reduce allreduce traffic
-    model = optimize_model(model.module.to(f'cpu'), low_bit=low_bit).to(torch.float16)
+    model = optimize_model(model.module.to(f"cpu"), low_bit=low_bit).to(torch.float16)
 
     # Next, use XPU as accelerator to speed up inference
     current_accel = XPU_Accelerator()
     set_accelerator(current_accel)
 
     # Move model back to xpu
-    model = model.to(f'xpu:{local_rank}')
+    model = model.to(f"xpu:{local_rank}")
     model = BenchmarkWrapper(model)
 
     # Modify backend related settings
````
````diff
@@ -104,55 +130,156 @@ def load_model(model_path, low_bit):
         get_accelerator().set_device(local_rank)
     dist_backend = get_accelerator().communication_backend_name()
     import deepspeed.comm.comm
 
     deepspeed.comm.comm.cdb = None
     from deepspeed.comm.comm import init_distributed
 
     init_distributed()
 
     # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path, trust_remote_code=True, padding_side="left"
+    )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-def generate_text(prompt: List[str], n_predict = 32):
+
+async def generate_stream_gate(prompt: List[str], n_predict=32, request_ids=[]):
     while prompt[-1] == "":
         prompt = prompt[:-1]
     if isinstance(n_predict, list):
         n_predict = max(n_predict)
 
     inputs = tokenizer(prompt, return_tensors="pt", padding=True)
-    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
-    # print(input_ids)
-    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
-    output = model.generate(input_ids,
-                            attention_mask=attention_mask,
-                            max_new_tokens=n_predict,
-                            use_cache=True)
+    input_ids = inputs.input_ids.to(f"xpu:{local_rank}")
+    attention_mask = inputs.attention_mask.to(f"xpu:{local_rank}")
+
+    for request_id in request_ids:
+        if request_id not in streamer_dict:
+            streamer_dict[request_id] = asyncio.Queue()
+
+    streamer = BatchTextIteratorStreamer(
+        tokenizer=tokenizer,
+        timeout=600,
+        skip_prompt=True,
+        skip_special_tokens=True,
+        batch_size=len(prompt),
+    )
+
+    generated_kwargs = dict(
+        max_new_tokens=n_predict,
+        min_new_tokens=n_predict,
+        streamer=streamer,
+        attention_mask=attention_mask,
+        do_sample=False,
+    )
+
+    def model_generate():
+        model.generate(input_ids, **generated_kwargs)
+        torch.xpu.empty_cache()
         torch.xpu.synchronize()
-    return output
 
+    t1 = Thread(target=model_generate)
+    t1.start()
 
-class PromptRequest(BaseModel):
-    prompt: str
-    n_predict: int = 32
+    stopped = False
+
+    async def put_item(queue, item):
+        await queue.put(item)
+
+    for i in range(n_predict):
+        tasks = []
+        try:
+            output_token = next(streamer)
+        except StopIteration:
+            stopped = True
+        for index, request_id in enumerate(request_ids):
+            task = asyncio.create_task(
+                put_item(
+                    streamer_dict[request_id],
+                    (0 if stopped else n_predict - 1 - i, output_token[index]),
+                )
+            )
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+        if stopped:
+            break
 
-empty_req = PromptRequest(prompt="", n_predict=0)
 
 app = FastAPI()
 
-from collections import deque
-rest_req_deque = deque(maxlen=128)
-request_queue: asyncio.Queue = asyncio.Queue()
-result_dict: Dict[str, str] = {}
+
+async def stream_generator(token_queue, request_id):
+    index = 0
+    while True:
+        if not token_queue.empty():
+            remain, token = await token_queue.get()
+            response = {
+                "index": index,
+                "message": {"role": "assistant", "content": token},
+                "finish_reason": None,
+            }
+            yield json.dumps(response) + "\n"
+            index = index + 1
+            if remain == 0:
+                response = {
+                    "index": index,
+                    "message": {"role": "assistant", "content": None},
+                    "finish_reason": "length",
+                }
+                yield json.dumps(response) + "\n"
+                break
+        else:
+            await asyncio.sleep(0)
+    streamer_dict.pop(request_id, None)
+
+
+async def generator(token_queue, request_id):
+    while True:
+        if not token_queue.empty():
+            remain, token = await token_queue.get()
+            yield token
+            if remain == 0:
+                break
+        else:
+            await asyncio.sleep(0)
+    streamer_dict.pop(request_id, None)
+
 
 @app.post("/generate/")
 async def generate(prompt_request: PromptRequest):
     request_id = str(uuid.uuid4())
     await request_queue.put((request_id, prompt_request))
     while True:
-        await asyncio.sleep(0.1)
-        if request_id in result_dict:
-            output_str = result_dict.pop(request_id)
-            return {"generated_text": output_str}
+        await asyncio.sleep(0)
+        if request_id in streamer_dict:
+            output_str = []
+            token_queue = streamer_dict[request_id]
+            async for item in generator(token_queue, request_id):
+                output_str.append(item)
+
+            return {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "".join(output_str),
+                },
+                "finish_reason": "stop",
+            }
+
+
+@app.post("/generate_stream/")
+async def generate_stream(prompt_request: PromptRequest):
+    request_id = str(uuid.uuid4()) + "stream"
+    await request_queue.put((request_id, prompt_request))
+    while True:
+        await asyncio.sleep(0)
+        if request_id in streamer_dict:
+            token_queue = streamer_dict[request_id]
+
+            return StreamingResponse(
+                stream_generator(token_queue, request_id), media_type="application/json"
+            )
+
 
 async def process_requests():
````
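The hunk above replaces the blocking `generate_text()` path with `generate_stream_gate()`: generation runs in a background thread, each decoded step is pushed as a `(remaining, token)` pair into a per-request `asyncio.Queue`, and `stream_generator()` drains that queue into newline-delimited JSON. A minimal, self-contained sketch of that producer/consumer pattern (toy tokens, no model, not the PR's code) looks like this:

```python
import asyncio
import json


async def produce(queue: asyncio.Queue, tokens):
    # Stand-in for generate_stream_gate: push each token with a countdown of remaining items.
    for i, tok in enumerate(tokens):
        await queue.put((len(tokens) - 1 - i, tok))


async def stream(queue: asyncio.Queue):
    # Stand-in for stream_generator: one JSON line per token, plus a final "length" record.
    index = 0
    while True:
        remain, token = await queue.get()
        yield json.dumps({"index": index, "content": token, "finish_reason": None}) + "\n"
        index += 1
        if remain == 0:
            yield json.dumps({"index": index, "content": None, "finish_reason": "length"}) + "\n"
            break


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    await produce(queue, ["Artificial ", "intelligence ", "is "])
    async for line in stream(queue):
        print(line, end="")


asyncio.run(main())
```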
````diff
@@ -164,7 +291,9 @@ async def process_requests():
             while rest_req_deque:
                 request_id, rest_request = rest_req_deque.popleft()
                 prompt = rest_request.prompt
-                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
+                cur_prompt_len = tokenizer(
+                    prompt_request.prompt, return_tensors="pt"
+                ).input_ids.size(1)
                 cur_batched_tokens += cur_prompt_len
                 if cur_batched_tokens > max_num_batched_tokens:
                     cur_batched_tokens -= cur_prompt_len
````
````diff
@@ -179,9 +308,9 @@ async def process_requests():
                 if request_queue.empty():
                     break
                 request_id, prompt_request = await request_queue.get()
-                # import pdb
-                # pdb.set_trace()
-                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
+                cur_prompt_len = tokenizer(
+                    prompt_request.prompt, return_tensors="pt"
+                ).input_ids.size(1)
                 cur_batched_tokens += cur_prompt_len
                 if cur_batched_tokens > max_num_batched_tokens:
                     cur_batched_tokens -= cur_prompt_len
````
````diff
@@ -193,21 +322,28 @@ async def process_requests():
         if local_rank == 0 and prompt_requests:
             object_list = prompt_requests
             if len(object_list) < max_num_seqs:
-                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
-            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
+                object_list = object_list + [empty_req] * (
+                    max_num_seqs - len(object_list)
+                )
+            logger.info(
+                f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}"
+            )
             dist.broadcast_object_list(object_list, src=0)
 
             start_time = time.time()
-            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            await generate_stream_gate(
+                [req.prompt for req in object_list],
+                [req.n_predict for req in object_list],
+                request_ids,
+            )
+
             generate_time = time.time() - start_time
-            outputs = outputs.cpu()
-            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-            output_strs = output_strs[:len(prompt_requests)]
 
-            for request_id, output_str in zip(request_ids, output_strs):
-                result_dict[request_id] = output_str
-            logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}")
+            logger.info(
+                f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}"
+            )
 
-        await asyncio.sleep(0.1)
+        await asyncio.sleep(0)
 
+
 @app.on_event("startup")
````
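For context, `process_requests()` keeps admitting pending prompts until either `max_num_seqs` requests or `max_num_batched_tokens` prompt tokens would be exceeded. A toy sketch of that budget check (an assumption-laden stand-in: whitespace splitting replaces the real tokenizer, and no model is involved):

```python
from collections import deque

max_num_seqs = 8
max_num_batched_tokens = 16

pending = deque(["what is ai", "tell me a joke", "explain tensor parallel inference today"])
batch, cur_batched_tokens = [], 0

while pending and len(batch) < max_num_seqs:
    prompt = pending[0]
    cur_prompt_len = len(prompt.split())  # stand-in for tokenizer(...).input_ids.size(1)
    if cur_batched_tokens + cur_prompt_len > max_num_batched_tokens:
        break  # the real loop defers the request to a later round instead
    cur_batched_tokens += cur_prompt_len
    batch.append(pending.popleft())

print(batch, cur_batched_tokens)
```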
````diff
@@ -216,19 +352,44 @@ async def startup_event():
         asyncio.create_task(process_requests())
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
-    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
-                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
-                             ', or the path to the huggingface checkpoint folder')
-    parser.add_argument('--low-bit', type=str, default='sym_int4',
-                    help='The quantization type the model will convert to.')
-    parser.add_argument('--port', type=int, default=8000,
-                    help='The port number on which the server will run.')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=4096,
-                    help='Max tokens can be batched by this service.')
-    parser.add_argument('--max-num-seqs', type=int, default=8,
-                    help='Max requests can be batched by this service.')
+async def main():
+    parser = argparse.ArgumentParser(
+        description="Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP"
+    )
+    parser.add_argument(
+        "--repo-id-or-model-path",
+        type=str,
+        default="meta-llama/Llama-2-7b-chat-hf",
+        help="The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded"
+        ", or the path to the huggingface checkpoint folder",
+    )
+    parser.add_argument(
+        "--low-bit",
+        type=str,
+        default="sym_int4",
+        help="The quantization type the model will convert to.",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="The port number on which the server will run.",
+    )
+    parser.add_argument(
+        "--max-num-batched-tokens",
+        type=int,
+        default=4096,
+        help="Max tokens can be batched by this service.",
+    )
+    parser.add_argument(
+        "--max-num-seqs",
+        type=int,
+        default=8,
+        help="Max requests can be batched by this service.",
+    )
+
+    global max_num_seqs
+    global max_num_batched_tokens
+
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
````
````diff
@@ -236,10 +397,21 @@ if __name__ == "__main__":
     max_num_seqs = args.max_num_seqs
     max_num_batched_tokens = args.max_num_batched_tokens
     load_model(model_path, low_bit)
+
+    config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
+    server = uvicorn.Server(config)
+
     if local_rank == 0:
-        uvicorn.run(app, host="0.0.0.0", port=args.port)
+        await server.serve()
     else:
         while True:
             object_list = [None] * max_num_seqs
             dist.broadcast_object_list(object_list, src=0)
-            output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            await generate_stream_gate(
+                [req.prompt for req in object_list],
+                [req.n_predict for req in object_list],
+            )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
````
````diff
@@ -33,5 +33,4 @@ export TORCH_LLM_ALLREDUCE=0
 export WORLD_SIZE=2
 
 mpirun -np $NUM_GPUS --prepend-rank \
-        python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
-
+        python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'fp8' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
````
python/llm/src/ipex_llm/transformers/streamer.py (new file, 114 lines):

```python
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
#

from typing import Optional, List

import torch
from transformers import TextIteratorStreamer


class BatchTextIteratorStreamer(TextIteratorStreamer):
    """
    A specialized version of TextIteratorStreamer that handles text streams in batches, providing
    an efficient way to process large volumes of text data asynchronously. This class is designed
    to aggregate multiple texts into batches, making it ideal for applications that need to
    perform batch operations on streamed text data, such as bulk text processing or machine
    learning model inference in an interactive environment.

        Parameters:
                tokenizer (`AutoTokenizer`):
                        The tokenizer used to decode the tokens.
                skip_prompt (`bool`, *optional*, defaults to `False`):
                        Whether to skip the prompt to `.generate()` or not.
                timeout (`float`, *optional*):
                        The timeout for the text queue. If `None`, the queue will
                        block indefinitely. Useful to handle exceptions
                        in `.generate()`, when it is called in a separate thread.
                decode_kwargs (`dict`, *optional*):
                        Additional keyword arguments to pass to the tokenizer's `decode` method.
                batch_size(`int`)
                        The size of the batches to process. This parameter must be specified and
                        determines how many texts are processed together as a single batch.
    """

    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size = batch_size
        self.token_cache = [[] for _ in range(batch_size)]
        self.print_len = [0 for _ in range(batch_size)]
        self.generate_exception = None

    def put(self, value):
        if len(value.shape) != 2:
            value = torch.reshape(
                value, (self.batch_size, value.shape[0] // self.batch_size)
            )

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        printable_texts = list()
        for idx in range(self.batch_size):
            self.token_cache[idx].extend(value[idx].tolist())
            text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)

            if text.endswith("\n"):
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
                # If the last token is a CJK character, we print the characters.
            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
                printable_text = text[self.print_len[idx]:]
                self.print_len[idx] += len(printable_text)
            else:
                printable_text = text[self.print_len[idx]:text.rfind(" ") + 1]
                self.print_len[idx] += len(printable_text)
            printable_texts.append(printable_text)

        self.on_finalized_text(printable_texts)

    def end(self):
        printable_texts = list()
        for idx in range(self.batch_size):
            if len(self.token_cache[idx]) > 0:
                text = self.tokenizer.decode(
                    self.token_cache[idx], **self.decode_kwargs
                )
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            else:
                printable_text = ""
            printable_texts.append(printable_text)

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_texts, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        self.text_queue.put(texts, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)
```
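A usage sketch for the new `BatchTextIteratorStreamer` (not taken from this commit): it mirrors how serving.py drives the class, with `generate()` running in a thread while the main thread iterates the streamer and receives one list of text deltas per step. The model path and generation settings below are placeholders.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer
from ipex_llm.transformers.streamer import BatchTextIteratorStreamer

MODEL_PATH = "path/to/your/model"  # assumption: any local causal LM checkpoint
prompts = ["What is AI?", "What is deep learning?"]

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

inputs = tokenizer(prompts, return_tensors="pt", padding=True)
streamer = BatchTextIteratorStreamer(
    batch_size=len(prompts),
    tokenizer=tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

# generate() feeds the streamer from a background thread; the main thread drains it.
Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=16, do_sample=False),
).start()

for texts in streamer:  # each item is a list of text deltas, one per prompt in the batch
    print(texts)
```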