Upgrade to vllm 0.6.2 (#12338)
* Initial updates for vllm 0.6.2
* fix
* Change Dockerfile to support v062
* Fix
* fix examples
* Fix
* done
* fix
* Update engine.py
* Fix Dockerfile to original path
* fix
* add option
* fix
* fix
* fix
* fix

Co-authored-by: xiangyuT <xiangyu.tian@intel.com>
parent 4376fdee62
commit 0ee54fc55f

13 changed files with 617 additions and 520 deletions
@@ -5,6 +5,8 @@ ARG https_proxy
 ENV TZ=Asia/Shanghai
 ENV PYTHONUNBUFFERED=1
+# To prevent RPC_TIMEOUT ERROR for the first request
+ENV VLLM_RPC_TIMEOUT=100000


 # Disable pip's cache behavior
@@ -42,6 +44,14 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
     pip install --upgrade colorama && \
     # Download all-in-one benchmark and examples
     git clone https://github.com/intel-analytics/ipex-llm && \
+    # The following comment segment is used when building from source...
+    # cd ipex-llm && \
+    # git fetch origin pull/12338/head:local_pr && \
+    # git checkout local_pr && \
+    # pip uninstall -y ipex-llm && \
+    # cd python/llm && \
+    # python setup.py install && \
+    # cd ../../../ && \
     cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \
     cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \
     # Install vllm dependencies
@@ -76,13 +86,16 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
     rm -rf /tmp/neo && \
     mkdir -p /llm && \
     cd /llm && \
-    git clone -b 0.5.4 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
+    git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
     cd /llm/vllm && \
-    pip install -r /llm/vllm/requirements-xpu.txt && \
-    VLLM_TARGET_DEVICE=xpu python setup.py install && \
+    pip install setuptools-scm && \
+    pip install --upgrade cmake && \
+    VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
+    # pip install -r /llm/vllm/requirements-xpu.txt && \
+    # VLLM_TARGET_DEVICE=xpu python setup.py install && \
     pip install mpi4py fastapi uvicorn openai && \
     pip install gradio==4.43.0 && \
-    pip install transformers==4.44.2 && \
+    # pip install transformers==4.44.2 && \
     # patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch && \
     pip install ray && \
     patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch
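With 0.6.2 the XPU build of vLLM is driven through pip (`VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation`) rather than `python setup.py install`. A quick, hedged smoke test that the rebuilt image picked up the new version could look like the following; the image tag is a placeholder, not part of this PR.

```bash
# Build the image (tag is assumed) and check the vLLM version inside it.
docker build -t ipex-llm-serving-xpu:test .
docker run --rm --entrypoint python ipex-llm-serving-xpu:test \
    -c "import vllm; print(vllm.__version__)"   # expected to report 0.6.2
```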
@@ -1,14 +1,22 @@
 """Benchmark offline inference throughput."""
 import argparse
+import dataclasses
 import json
 import random
 import time
 from typing import List, Optional, Tuple

 import torch
+import uvloop
+from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
-from tqdm import tqdm
+
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.entrypoints.openai.api_server import (
+    build_async_engine_client_from_engine_args)
+# from vllm.sampling_params import BeamSearchParams
+from vllm.utils import FlexibleArgumentParser, merge_async_iterators


 def sample_requests(
@@ -29,22 +37,23 @@ def sample_requests(
     dataset = [(data["conversations"][0]["value"],
                 data["conversations"][1]["value"]) for data in dataset]

-    # Tokenize the prompts and completions.
-    prompts = [prompt for prompt, _ in dataset]
-    prompt_token_ids = tokenizer(prompts).input_ids
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer(completions).input_ids
-    tokenized_dataset = []
-    for i in range(len(dataset)):
-        output_len = len(completion_token_ids[i])
-        if fixed_output_len is not None:
-            output_len = fixed_output_len
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+    # Shuffle the dataset.
+    random.shuffle(dataset)

-    # Filter out too long sequences.
+    # Filter out sequences that are too long or too short
     filtered_dataset: List[Tuple[str, int, int]] = []
-    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = dataset[i][1]
+        completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
             continue
@@ -53,51 +62,18 @@ def sample_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

-    # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
-    return sampled_requests
+    return filtered_dataset


 def run_vllm(
     requests: List[Tuple[str, int, int]],
-    model: str,
-    tokenizer: str,
-    quantization: Optional[str],
-    tensor_parallel_size: int,
-    seed: int,
     n: int,
-    use_beam_search: bool,
-    trust_remote_code: bool,
-    dtype: str,
-    max_model_len: Optional[int],
-    enforce_eager: bool,
-    kv_cache_dtype: str,
-    device: str,
-    enable_prefix_caching: bool,
-    gpu_memory_utilization: float = 0.9,
-    load_in_low_bit: str = "sym_int4",
-    max_num_batched_tokens: int = 5000,
-    max_num_seqs: int = 256,
+    low_bit: str,
+    engine_args: EngineArgs,
 ) -> float:
     from vllm import SamplingParams
     from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
-    llm = LLM(model=model,
-              tokenizer=tokenizer,
-              quantization=quantization,
-              tensor_parallel_size=tensor_parallel_size,
-              seed=seed,
-              trust_remote_code=trust_remote_code,
-              dtype=dtype,
-              max_model_len=max_model_len,
-              gpu_memory_utilization=gpu_memory_utilization,
-              enforce_eager=enforce_eager,
-              kv_cache_dtype=kv_cache_dtype,
-              device=device,
-              enable_prefix_caching=enable_prefix_caching,
-              load_in_low_bit=load_in_low_bit,
-              max_num_batched_tokens=max_num_batched_tokens,
-              max_num_seqs=max_num_seqs,)
+    llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit=low_bit)

     # Add the requests to the engine.
     warm_prompt = "hi " * (1024 - 1)
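The refactor above replaces the long explicit parameter list with a single `EngineArgs` dataclass that is splatted into the IPEX-LLM `LLM` class. A minimal sketch of the same pattern outside the benchmark, with a placeholder model path and assumed values:

```python
import dataclasses

from vllm.engine.arg_utils import EngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

# Engine options now live in one dataclass instead of individual arguments.
engine_args = EngineArgs(model="YOUR_MODEL",      # placeholder path
                         device="xpu",
                         dtype="float16",
                         max_model_len=2048)

# load_in_low_bit stays an IPEX-LLM-specific keyword on top of the vLLM args.
llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit="sym_int4")
```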
@@ -111,14 +87,14 @@ def run_vllm(
         sampling_params.append(
             SamplingParams(
                 n=n,
-                temperature=0.0 if use_beam_search else 1.0,
+                temperature=0.0,
                 top_p=1.0,
-                use_beam_search=use_beam_search,
                 ignore_eos=True,
                 max_tokens=output_len,
             ))
     llm.generate(prompts, sampling_params, use_tqdm=True)

+    # Add the requests to the engine.
     prompts: List[str] = []
     sampling_params: List[SamplingParams] = []
     for prompt, _, output_len in requests:
@@ -126,16 +102,67 @@ def run_vllm(
         sampling_params.append(
             SamplingParams(
                 n=n,
-                temperature=0.0 if use_beam_search else 1.0,
+                temperature=1.0,
                 top_p=1.0,
-                use_beam_search=use_beam_search,
                 ignore_eos=True,
                 max_tokens=output_len,
             ))
-    start = time.perf_counter()
-    llm.generate(prompts, sampling_params, use_tqdm=True)
-    end = time.perf_counter()
-    return end - start
+
+    use_beam_search = False
+
+    if not use_beam_search:
+        start = time.perf_counter()
+        llm.generate(prompts, sampling_params, use_tqdm=True)
+        end = time.perf_counter()
+    else:
+        prompts = [prompt for prompt, _, _ in requests]
+        # output_len should be the same for all requests.
+        output_len = requests[0][2]
+        for prompt, input_len, _output_len in requests:
+            assert _output_len == output_len
+        start = time.perf_counter()
+        llm.beam_search(prompts,
+                beam_width=n,
+                max_tokens=output_len,
+                ignore_eos=True)
+        end = time.perf_counter()
+    return end - start
+
+
+async def run_vllm_async(
+    requests: List[Tuple[str, int, int]],
+    n: int,
+    engine_args: AsyncEngineArgs,
+    disable_frontend_multiprocessing: bool = False,
+) -> float:
+    from vllm import SamplingParams
+
+    async with build_async_engine_client_from_engine_args(
+            engine_args, disable_frontend_multiprocessing) as llm:
+
+        # Add the requests to the engine.
+        prompts: List[str] = []
+        sampling_params: List[SamplingParams] = []
+        for prompt, _, output_len in requests:
+            prompts.append(prompt)
+            sampling_params.append(
+                SamplingParams(
+                    n=n,
+                    temperature=1.0,
+                    top_p=1.0,
+                    ignore_eos=True,
+                    max_tokens=output_len,
+                ))
+
+        generators = []
+        start = time.perf_counter()
+        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+            generator = llm.generate(prompt, sp, request_id=f"test{i}")
+            generators.append(generator)
+        all_gens = merge_async_iterators(*generators)
+        async for i, res in all_gens:
+            pass
+        end = time.perf_counter()
         return end - start
@@ -144,11 +171,9 @@ def run_hf(
     model: str,
     tokenizer: PreTrainedTokenizerBase,
     n: int,
-    use_beam_search: bool,
     max_batch_size: int,
     trust_remote_code: bool,
 ) -> float:
-    assert not use_beam_search
     llm = AutoModelForCausalLM.from_pretrained(
         model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
     if llm.config.model_type == "llama":
@@ -180,7 +205,7 @@ def run_hf(
                              padding=True).input_ids
         llm_outputs = llm.generate(
             input_ids=input_ids.cuda(),
-            do_sample=not use_beam_search,
+            do_sample=True,
             num_return_sequences=n,
             temperature=1.0,
             top_p=1.0,
@@ -205,13 +230,15 @@ def run_mii(
     tensor_parallel_size: int,
     output_len: int,
 ) -> float:
-    from mii import pipeline
-    llm = pipeline(model, tensor_parallel=tensor_parallel_size)
+    from mii import client, serve
+    llm = serve(model, tensor_parallel=tensor_parallel_size)
     prompts = [prompt for prompt, _, _ in requests]

     start = time.perf_counter()
-    llm(prompts, max_new_tokens=output_len)
+    llm.generate(prompts, max_new_tokens=output_len)
     end = time.perf_counter()
+    client = client(model)
+    client.terminate_server()
     return end - start
@@ -224,7 +251,16 @@ def main(args: argparse.Namespace):
         args.tokenizer, trust_remote_code=args.trust_remote_code)
     if args.dataset is None:
         # Synthesize a prompt with the given input length.
-        prompt = "hi" * (args.input_len - 1)
+        # As tokenizer may add additional tokens like BOS, we need to try
+        # different lengths to get the desired input length.
+        for i in range(-10, 10):
+            prompt = "hi " * (args.input_len + i)
+            tokenized_prompt = tokenizer(prompt).input_ids
+            if len(tokenized_prompt) == args.input_len:
+                break
+        else:
+            raise ValueError(
+                f"Failed to synthesize a prompt with {args.input_len} tokens.")
         requests = [(prompt, args.input_len, args.output_len)
                     for _ in range(args.num_prompts)]
     else:
@@ -232,18 +268,21 @@ def main(args: argparse.Namespace):
                                    args.output_len)

     if args.backend == "vllm":
-        elapsed_time = run_vllm(
-            requests, args.model, args.tokenizer, args.quantization,
-            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
-            args.trust_remote_code, args.dtype, args.max_model_len,
-            args.enforce_eager, args.kv_cache_dtype, args.device,
-            args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit,
-            args.max_num_batched_tokens,args.max_num_seqs)
+        if args.async_engine:
+            elapsed_time = uvloop.run(
+                run_vllm_async(
+                    requests,
+                    args.n,
+                    AsyncEngineArgs.from_cli_args(args),
+                    args.disable_frontend_multiprocessing,
+                ))
+        else:
+            elapsed_time = run_vllm(requests, args.n, args.load_in_low_bit,
+                                    EngineArgs.from_cli_args(args))
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
-                              args.use_beam_search, args.hf_max_batch_size,
-                              args.trust_remote_code)
+                              args.hf_max_batch_size, args.trust_remote_code)
     elif args.backend == "mii":
         elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                                args.output_len)
@@ -251,12 +290,26 @@ def main(args: argparse.Namespace):
         raise ValueError(f"Unknown backend: {args.backend}")
     total_num_tokens = sum(prompt_len + output_len
                            for _, prompt_len, output_len in requests)
+    total_output_tokens = sum(output_len for _, _, output_len in requests)
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
-          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
+          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
+          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
+
+    # Output JSON results if specified
+    if args.output_json:
+        results = {
+            "elapsed_time": elapsed_time,
+            "num_requests": len(requests),
+            "total_num_tokens": total_num_tokens,
+            "requests_per_second": len(requests) / elapsed_time,
+            "tokens_per_second": total_num_tokens / elapsed_time,
+        }
+        with open(args.output_json, "w") as f:
+            json.dump(results, f, indent=4)


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
+    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
     parser.add_argument("--backend",
                         type=str,
                         choices=["vllm", "hf", "mii"],
@@ -274,89 +327,38 @@ if __name__ == "__main__":
                         default=None,
                         help="Output length for each request. Overrides the "
                         "output length from the dataset.")
-    parser.add_argument("--model", type=str, default="facebook/opt-125m")
-    parser.add_argument("--tokenizer", type=str, default=None)
-    parser.add_argument('--quantization',
-                        '-q',
-                        choices=['awq', 'gptq', 'squeezellm', None],
-                        default=None)
-    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
                         type=int,
                         default=1,
                         help="Number of generated sequences per prompt.")
-    parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument("--num-prompts",
                         type=int,
                         default=1000,
                         help="Number of prompts to process.")
-    parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--hf-max-batch-size",
                         type=int,
                         default=None,
                         help="Maximum batch size for HF backend.")
-    parser.add_argument('--trust-remote-code',
-                        action='store_true',
-                        help='trust remote code from huggingface')
     parser.add_argument(
-        '--max-model-len',
-        type=int,
-        default=None,
-        help='Maximum length of a sequence (including prompt and output). '
-        'If None, will be derived from the model.')
-    parser.add_argument(
-        '--dtype',
-        type=str,
-        default='auto',
-        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
-        help='data type for model weights and activations. '
-        'The "auto" option will use FP16 precision '
-        'for FP32 and FP16 models, and BF16 precision '
-        'for BF16 models.')
-    parser.add_argument('--gpu-memory-utilization',
-                        type=float,
-                        default=0.9,
-                        help='the fraction of GPU memory to be used for '
-                        'the model executor, which can range from 0 to 1.'
-                        'If unspecified, will use the default value of 0.9.')
-    parser.add_argument("--enforce-eager",
-                        action="store_true",
-                        help="enforce eager execution")
-    parser.add_argument(
-        "--kv-cache-dtype",
-        type=str,
-        choices=["auto", "fp8_e5m2"],
-        default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type.')
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cuda",
-        choices=["cuda", "xpu"],
-        help='device type for vLLM execution, supporting CUDA only currently.')
-    parser.add_argument(
-        "--enable-prefix-caching",
-                        action='store_true',
-        help="enable automatic prefix caching for vLLM backend.")
+        '--output-json',
+        type=str,
+        default=None,
+        help='Path to save the throughput results in JSON format.')
+    parser.add_argument("--async-engine",
+                        action='store_true',
+                        default=False,
+                        help="Use vLLM async engine rather than LLM class.")
+    parser.add_argument("--disable-frontend-multiprocessing",
+                        action='store_true',
+                        default=False,
+                        help="Disable decoupled async engine frontend.")
     parser.add_argument(
         "--load-in-low-bit",
         type=str,
         choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
         default="sym_int4",
         help="Low-bit format quantization with IPEX-LLM")
-    parser.add_argument('--max-num-batched-tokens',
-                        type=int,
-                        default=4096,
-                        help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs',
-                        type=int,
-                        default=256,
-                        help='Maximum number of sequences per iteration.')
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
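Because the engine options are now registered through `AsyncEngineArgs.add_cli_args(parser)`, flags such as `--model`, `--device`, `--dtype` and `--max-model-len` come from vLLM itself rather than being re-declared in the script. A hedged example invocation of the updated benchmark (the script name and model path are assumptions, not taken from this diff):

```bash
python benchmark_vllm_throughput.py \
  --backend vllm \
  --model YOUR_MODEL \
  --device xpu \
  --dtype float16 \
  --load-in-low-bit sym_int4 \
  --input-len 1024 --output-len 128 \
  --num-prompts 256 \
  --async-engine \
  --output-json results.json
```

With `--async-engine`, the run goes through `run_vllm_async` and the new multiprocessing frontend; `--output-json` writes the throughput summary introduced above.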
@@ -379,8 +381,6 @@ if __name__ == "__main__":
             raise ValueError("dtype must be auto for MII backend.")
         if args.n != 1:
             raise ValueError("n must be 1 for MII backend.")
-        if args.use_beam_search:
-            raise ValueError("Beam search is not supported for MII backend.")
         if args.quantization is not None:
             raise ValueError("Quantization is only for vLLM backend.")
         if args.hf_max_batch_size is not None:
@@ -388,5 +388,4 @@ if __name__ == "__main__":
         if args.tokenizer != args.model:
             raise ValueError("Tokenizer must be the same as the model for MII "
                              "backend.")
     main(args)
-
@@ -28,4 +28,6 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --max-model-len 2048 \
   --max-num-batched-tokens 4000 \
   --max-num-seqs 12 \
-  --tensor-parallel-size 1
+  --tensor-parallel-size 1 \
+  --disable-async-output-proc \
+  --distributed-executor-backend ray
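Once the server is up, a quick request against the OpenAI-compatible endpoint verifies the whole stack. This is a hedged smoke test: it assumes the default port 8000 and uses the served model name as a placeholder.

```bash
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "YOUR_MODEL_NAME", "prompt": "San Francisco is a", "max_tokens": 32}'
```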
@@ -51,6 +51,8 @@ llm = LLM(model="YOUR_MODEL",
           enforce_eager=True,
           load_in_low_bit="fp8",
           tensor_parallel_size=1,
+          disable_async_output_proc=True,
+          distributed_executor_backend="ray",
           max_model_len=2000,
           max_num_batched_tokens=2000)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
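For context, a self-contained sketch of the offline example around the constructor shown above. The `IPEXLLMClass` import, the model path, and the prompt are assumptions consistent with the rest of this PR rather than lines from the changed file:

```python
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

prompts = ["What is the capital of France?"]          # placeholder prompt
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="YOUR_MODEL",                         # placeholder path
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          load_in_low_bit="fp8",
          tensor_parallel_size=1,
          disable_async_output_proc=True,
          distributed_executor_backend="ray",
          max_model_len=2000,
          max_num_batched_tokens=2000)

# Each RequestOutput carries the prompt and the generated completions.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)
```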
@@ -2,7 +2,7 @@

 This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with IPEX-LLM low-bits optimizations).

-The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.3.3).
+The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2).

 Currently, we support the following models for vLLM engine:
@@ -17,7 +17,7 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-

 ### 0. Environment

-To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.0. Please check the requirements at [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU#requirements).
+To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html).

 After install the toolkit, run the following commands in your environment before starting vLLM GPU:
 ```bash
@@ -44,14 +44,12 @@ conda create -n ipex-vllm python=3.11
 conda activate ipex-vllm
 # Install dependencies
 pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install setuptools-scm
+pip install --upgrade cmake
 # cd to your workdir
-git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git
+git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git
 cd vllm
-pip install -r requirements-xpu.txt
-pip install --no-deps xformers
-VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e .
-pip install outlines==0.0.34 --no-deps
-pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy
+VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v .
 # For Qwen model support
 pip install transformers_stream_generator einops tiktoken
 ```
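A small, optional sanity check after the install steps; the expected version is assumed to be 0.6.2 given the branch checked out above:

```bash
python -c "import vllm; print(vllm.__version__)"
```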
@@ -60,7 +58,8 @@ pip install transformers_stream_generator einops tiktoken

 ```bash
 export USE_XETLA=OFF
-export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
+export SYCL_CACHE_PERSISTENT=1
 ```
 ### 3. Offline inference/Service
@@ -86,6 +85,7 @@ For vLLM, you can start the service using the following command:
 #!/bin/bash
 model="YOUR_MODEL_PATH"
 served_model_name="YOUR_MODEL_NAME"
+export VLLM_RPC_TIMEOUT=100000

  # You may need to adjust the value of
  # --max-model-len, --max-num-batched-tokens, --max-num-seqs
@@ -104,7 +104,8 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --max-model-len 4096 \
   --max-num-batched-tokens 10240 \
   --max-num-seqs 12 \
-  --tensor-parallel-size 1
+  --tensor-parallel-size 1 \
+  --disable-async-output-proc
 ```

 You can tune the service using these four arguments:
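A hedged client-side check once the service above is running, using the `openai` Python package (assumed to be installed); the base URL assumes the default port 8000 and the model name is the value passed as `--served-model-name`:

```python
from openai import OpenAI

# The server exposes an OpenAI-compatible API, so the stock client works.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(model="YOUR_MODEL_NAME",
                                        prompt="San Francisco is a",
                                        max_tokens=32)
print(completion.choices[0].text)
```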
@@ -200,5 +201,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --max-model-len 4096 \
   --max-num-batched-tokens 10240 \
   --max-num-seqs 12 \
-  --tensor-parallel-size 2
+  --tensor-parallel-size 2 \
+  --distributed-executor-backend ray \
+  --disable-async-output-proc
 ```
@@ -49,8 +49,10 @@ llm = LLM(model="YOUR_MODEL",
           device="xpu",
           dtype="float16",
           enforce_eager=True,
-          load_in_low_bit="sym_int4",
-          tensor_parallel_size=1)
+          load_in_low_bit="fp8",
+          tensor_parallel_size=1,
+          max_model_len=2000,
+          max_num_batched_tokens=2000)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
@@ -93,7 +93,7 @@ class VLLMWorker(BaseModelWorker):
         request_id = params.pop("request_id")
         temperature = float(params.get("temperature", 1.0))
         top_p = float(params.get("top_p", 1.0))
-        top_k = params.get("top_k", -1.0)
+        top_k = params.get("top_k", -1)
         presence_penalty = float(params.get("presence_penalty", 0.0))
         frequency_penalty = float(params.get("frequency_penalty", 0.0))
         max_new_tokens = params.get("max_new_tokens", 256)
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass
+from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
 __all__ = [
     "IPEXLLMAsyncLLMEngine",
     "IPEXLLMLLMEngine",
     "IPEXLLMClass",
+    "run_mp_engine",
 ]
@@ -19,9 +19,12 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.llm import LLM
 from vllm.utils import Counter
+from vllm.config import EngineConfig
 from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
 from vllm.usage.usage_lib import UsageContext
 from vllm.engine.metrics import StatLoggerBase
+from vllm.engine.multiprocessing.engine import MQLLMEngine
+import signal


 class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
@@ -32,6 +35,7 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[EngineConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         load_in_low_bit: str = "sym_int4",
@@ -40,7 +44,9 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         _ipex_llm_convert(load_in_low_bit)
-        return super().from_engine_args(engine_args, start_engine_loop, usage_context, stat_loggers)
+        return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
+                                        start_engine_loop=start_engine_loop,
+                                        usage_context=usage_context, stat_loggers=stat_loggers)


 class IPEXLLMClass(LLM):
@@ -117,3 +123,27 @@ class IPEXLLMLLMEngine(LLMEngine):
         # Create the engine configs.
         _ipex_llm_convert(load_in_low_bit)
         return super().from_engine_args(engine_args, usage_context, stat_loggers)
+
+
+class IPEXLLMMQLLMEngine(MQLLMEngine):
+    @classmethod
+    def from_engine_args(cls, engine_args: AsyncEngineArgs,
+                         usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
+        _ipex_llm_convert(load_in_low_bit)
+        return super().from_engine_args(engine_args, usage_context, ipc_path)
+
+
+def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
+                  ipc_path: str, load_in_low_bit: str):
+
+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm
+        raise KeyboardInterrupt("MQLLMEngine terminated")  # noqa
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
+                                                 usage_context=usage_context,
+                                                 ipc_path=ipc_path,
+                                                 load_in_low_bit=load_in_low_bit)
+    engine.start()
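The new `run_mp_engine` entry point follows vLLM 0.6.2's multiprocessing engine design: the API server spawns it in a separate process and talks to it through `MQLLMEngineClient`, while `load_in_low_bit` is threaded through so the IPEX-LLM conversion happens inside the engine process. A hedged sketch of how it would be launched standalone; the IPC path and argument values are assumptions, not taken from this diff:

```python
from multiprocessing import get_context

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext

from ipex_llm.vllm.xpu.engine import run_mp_engine

engine_args = AsyncEngineArgs(model="YOUR_MODEL", device="xpu")  # placeholders
ipc_path = "ipc:///tmp/ipex_llm_mq_example"                      # assumed IPC address

# Spawn the engine loop in its own process, as the API server does.
ctx = get_context("spawn")
engine_proc = ctx.Process(target=run_mp_engine,
                          args=(engine_args, UsageContext.OPENAI_API_SERVER,
                                ipc_path, "sym_int4"))
engine_proc.start()
```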
@@ -1,25 +1,34 @@
 import asyncio
 import importlib
 import inspect
+import multiprocessing
+import os
 import re
+import signal
+import socket
+import tempfile
 from argparse import Namespace
 from contextlib import asynccontextmanager
+from functools import partial
 from http import HTTPStatus
-from multiprocessing import Process
 from typing import AsyncIterator, Set

+import uvloop
 from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from prometheus_client import make_asgi_app
+from starlette.datastructures import State
 from starlette.routing import Mount
+from typing_extensions import assert_never

 import vllm.envs as envs
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
+from ipex_llm.vllm.xpu.engine import run_mp_engine
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
| 
						 | 
@@ -28,154 +37,269 @@ from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
                                               CompletionRequest,
+                                              CompletionResponse,
                                               DetokenizeRequest,
                                               DetokenizeResponse,
-                                              EmbeddingRequest, ErrorResponse,
+                                              EmbeddingRequest,
+                                              EmbeddingResponse, ErrorResponse,
+                                              LoadLoraAdapterRequest,
                                               TokenizeRequest,
-                                              TokenizeResponse)
-from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
-from ipex_llm.vllm.xpu.entrypoints.openai.rpc.server import run_rpc_server
+                                              TokenizeResponse,
+                                              UnloadLoraAdapterRequest)
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, get_open_port
+from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
 from vllm.version import __version__ as VLLM_VERSION

 TIMEOUT_KEEP_ALIVE = 5  # seconds

-async_engine_client: AsyncEngineClient
-engine_args: AsyncEngineArgs
-openai_serving_chat: OpenAIServingChat
-openai_serving_completion: OpenAIServingCompletion
-openai_serving_embedding: OpenAIServingEmbedding
-openai_serving_tokenization: OpenAIServingTokenization
+prometheus_multiproc_dir: tempfile.TemporaryDirectory

+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
 logger = init_logger('vllm.entrypoints.openai.api_server')

 _running_tasks: Set[asyncio.Task] = set()


-def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
-    return ModelConfig(model=model_name,
-                       tokenizer=model_name,
-                       tokenizer_mode="auto",
-                       trust_remote_code=trust_remote_code,
-                       seed=0,
-                       dtype="float16").embedding_mode
-
-
 @asynccontextmanager
 async def lifespan(app: FastAPI):
+    try:
+        if app.state.log_stats:
+            engine_client: EngineClient = app.state.engine_client

             async def _force_log():
                 while True:
-            await asyncio.sleep(10)
-            await async_engine_client.do_log_stats()
+                    await asyncio.sleep(10.)
+                    await engine_client.do_log_stats()

-    if not engine_args.disable_log_stats:
             task = asyncio.create_task(_force_log())
             _running_tasks.add(task)
             task.add_done_callback(_running_tasks.remove)
+        else:
+            task = None

+        try:
             yield
+        finally:
+            if task is not None:
+                task.cancel()
+    finally:
+        # Ensure app state including engine ref is gc'd
+        del app.state


 @asynccontextmanager
-async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
-    # Context manager to handle async_engine_client lifecycle
+async def build_async_engine_client(
+        args: Namespace) -> AsyncIterator[EngineClient]:
+
+    # Context manager to handle engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
-    global engine_args
     engine_args = AsyncEngineArgs.from_cli_args(args)

-    # Backend itself still global for the silly lil' health handler
-    global async_engine_client
-
-    # If manually triggered or embedding model, use AsyncLLMEngine in process.
-    # TODO: support embedding model via RPC.
-    if (model_is_embedding(args.model, args.trust_remote_code)
-            or args.disable_frontend_multiprocessing):
-        async_engine_client = AsyncLLMEngine.from_engine_args(
-            engine_args, usage_context=UsageContext.OPENAI_API_SERVER,
-            load_in_low_bit=args.load_in_low_bit)
-        yield async_engine_client
+    async with build_async_engine_client_from_engine_args(
+            engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine:
+        yield engine
+
+
+@asynccontextmanager
+async def build_async_engine_client_from_engine_args(
+    engine_args: AsyncEngineArgs,
+    disable_frontend_multiprocessing: bool = False,
+    load_in_low_bit: str = 'sym_int4',
+) -> AsyncIterator[EngineClient]:
+    """
+    Create EngineClient, either:
+        - in-process using the AsyncLLMEngine Directly
+        - multiprocess using AsyncLLMEngine RPC
+
+    Returns the Client or None if the creation failed.
+    """
+
+    # Fall back
+    # TODO: fill out feature matrix.
+    if (MQLLMEngineClient.is_unsupported_config(engine_args)
+            or disable_frontend_multiprocessing):
+        engine_config = engine_args.create_engine_config()
+        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
+                           "uses_ray", False)
+
+        build_engine = partial(AsyncLLMEngine.from_engine_args,
+                               engine_args=engine_args,
+                               load_in_low_bit=load_in_low_bit,
+                               engine_config=engine_config,
+                               usage_context=UsageContext.OPENAI_API_SERVER)
+        if uses_ray:
+            # Must run in main thread with ray for its signal handlers to work
+            engine_client = build_engine()
+        else:
+            engine_client = await asyncio.get_running_loop().run_in_executor(
+                None, build_engine)
+
+        yield engine_client
         return

     # Otherwise, use the multiprocessing AsyncLLMEngine.
     else:
-        # Start RPCServer in separate process (holds the AsyncLLMEngine).
-        port = get_open_port(envs.VLLM_RPC_PORT)
-        load_in_low_bit = args.load_in_low_bit
-        rpc_server_process = Process(target=run_rpc_server,
+        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+            # Make TemporaryDirectory for prometheus multiprocessing
+            # Note: global TemporaryDirectory will be automatically
+            #   cleaned up upon exit.
+            global prometheus_multiproc_dir
+            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+            os.environ[
+                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+        else:
+            logger.warning(
+                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
+                "This directory must be wiped between vLLM runs or "
+                "you will find inaccurate metrics. Unset the variable "
+                "and vLLM will properly handle cleanup.")
+
+        # Select random path for IPC.
+        ipc_path = get_open_zmq_ipc_path()
+        logger.info("Multiprocessing frontend to use %s for IPC Path.",
+                    ipc_path)
+
+        # Start RPCServer in separate process (holds the LLMEngine).
+        # the current process might have CUDA context,
+        # so we need to spawn a new process
+        context = multiprocessing.get_context("spawn")
+
+        engine_process = context.Process(target=run_mp_engine,
                                          args=(engine_args,
                                                UsageContext.OPENAI_API_SERVER,
-                                           port, load_in_low_bit))
-        rpc_server_process.start()
+                                               ipc_path,
+                                               load_in_low_bit))
+        engine_process.start()
+        logger.info("Started engine process with PID %d", engine_process.pid)

-        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
-        async_engine_client = AsyncEngineRPCClient(port)
-        await async_engine_client.setup()
+        # Build RPCClient, which conforms to EngineClient Protocol.
+        # NOTE: Actually, this is not true yet. We still need to support
+        # embedding models via RPC (see TODO above)
+        engine_config = engine_args.create_engine_config()
+        mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)

         try:
-            yield async_engine_client
+            while True:
+                try:
+                    await mp_engine_client.setup()
+                    break
+                except TimeoutError:
+                    if not engine_process.is_alive():
+                        raise RuntimeError(
+                            "Engine process failed to start") from None
+
+            yield mp_engine_client  # type: ignore[misc]
         finally:
             # Ensure rpc server process was terminated
-            rpc_server_process.terminate()
+            engine_process.terminate()

             # Close all open connections to the backend
-            async_engine_client.close()
+            mp_engine_client.close()

-            # Wait for server process to join
-            rpc_server_process.join()
+            # Wait for engine process to join
+            engine_process.join(4)
+            if engine_process.exitcode is None:
+                # Kill if taking longer than 5 seconds to stop
+                engine_process.kill()
+
+            # Lazy import for prometheus multiprocessing.
+            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+            # before prometheus_client is imported.
+            # See https://prometheus.github.io/client_python/multiprocess/
+            from prometheus_client import multiprocess
+            multiprocess.mark_process_dead(engine_process.pid)


 router = APIRouter()


 def mount_metrics(app: FastAPI):
+    # Lazy import for prometheus multiprocessing.
+    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+    # before prometheus_client is imported.
+    # See https://prometheus.github.io/client_python/multiprocess/
+    from prometheus_client import (CollectorRegistry, make_asgi_app,
+                                   multiprocess)
+
+    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
+    if prometheus_multiproc_dir_path is not None:
+        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
+                    prometheus_multiproc_dir_path)
+        registry = CollectorRegistry()
+        multiprocess.MultiProcessCollector(registry)
+
+        # Add prometheus asgi middleware to route /metrics requests
+        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+    else:
         # Add prometheus asgi middleware to route /metrics requests
         metrics_route = Mount("/metrics", make_asgi_app())

     # Workaround for 307 Redirect for /metrics
-    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
     app.routes.append(metrics_route)


+def chat(request: Request) -> OpenAIServingChat:
+    return request.app.state.openai_serving_chat
+
+
+def completion(request: Request) -> OpenAIServingCompletion:
+    return request.app.state.openai_serving_completion
+
+
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
+
+
+def embedding(request: Request) -> OpenAIServingEmbedding:
+    return request.app.state.openai_serving_embedding
+
+
+def engine_client(request: Request) -> EngineClient:
+    return request.app.state.engine_client
+
+
 @router.get("/health")
-async def health() -> Response:
+async def health(raw_request: Request) -> Response:
     """Health check."""
-    await async_engine_client.check_health()
+    await engine_client(raw_request).check_health()
     return Response(status_code=200)


 @router.post("/tokenize")
-async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_tokenization.create_tokenize(request)
+async def tokenize(request: TokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    else:
-        assert isinstance(generator, TokenizeResponse)
+    elif isinstance(generator, TokenizeResponse):
         return JSONResponse(content=generator.model_dump())

+    assert_never(generator)
+

 @router.post("/detokenize")
-async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_tokenization.create_detokenize(request)
+async def detokenize(request: DetokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    else:
-        assert isinstance(generator, DetokenizeResponse)
+    elif isinstance(generator, DetokenizeResponse):
         return JSONResponse(content=generator.model_dump())

+    assert_never(generator)
+

 @router.get("/v1/models")
-async def show_available_models():
-    models = await openai_serving_completion.show_available_models()
+async def show_available_models(raw_request: Request):
+    models = await completion(raw_request).show_available_models()
     return JSONResponse(content=models.model_dump())

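The new builder is an async context manager rather than a pair of module-level globals, so a frontend or test harness would consume it roughly as follows. A minimal sketch, assuming a placeholder model and the default low-bit format:

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs


    async def main():
        # Placeholder engine args; any model supported by ipex-llm/vLLM works here.
        engine_args = AsyncEngineArgs(model="facebook/opt-125m")

        # build_async_engine_client_from_engine_args is defined in the hunk above.
        async with build_async_engine_client_from_engine_args(
                engine_args,
                disable_frontend_multiprocessing=False,
                load_in_low_bit="sym_int4") as engine_client:
            # The yielded object follows the EngineClient protocol.
            await engine_client.check_health()


    asyncio.run(main())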
@@ -188,45 +312,109 @@ async def show_version():

 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
-    generator = await openai_serving_chat.create_chat_completion(
+
+    generator = await chat(raw_request).create_chat_completion(
         request, raw_request)
+
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    if request.stream:
-        return StreamingResponse(content=generator,
-                                 media_type="text/event-stream")
-    else:
-        assert isinstance(generator, ChatCompletionResponse)
+
+    elif isinstance(generator, ChatCompletionResponse):
         return JSONResponse(content=generator.model_dump())

+    return StreamingResponse(content=generator, media_type="text/event-stream")
+

 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await openai_serving_completion.create_completion(
+    generator = await completion(raw_request).create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    if request.stream:
-        return StreamingResponse(content=generator,
-                                 media_type="text/event-stream")
-    else:
+    elif isinstance(generator, CompletionResponse):
         return JSONResponse(content=generator.model_dump())

+    return StreamingResponse(content=generator, media_type="text/event-stream")
+

 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-    generator = await openai_serving_embedding.create_embedding(
+    generator = await embedding(raw_request).create_embedding(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    else:
+    elif isinstance(generator, EmbeddingResponse):
         return JSONResponse(content=generator.model_dump())

+    assert_never(generator)
+
+
+if envs.VLLM_TORCH_PROFILER_DIR:
+    logger.warning(
+        "Torch Profiler is enabled in the API server. This should ONLY be "
+        "used for local development!")
+
+    @router.post("/start_profile")
+    async def start_profile(raw_request: Request):
+        logger.info("Starting profiler...")
+        await engine_client(raw_request).start_profile()
+        logger.info("Profiler started.")
+        return Response(status_code=200)
+
+    @router.post("/stop_profile")
+    async def stop_profile(raw_request: Request):
+        logger.info("Stopping profiler...")
+        await engine_client(raw_request).stop_profile()
+        logger.info("Profiler stopped.")
+        return Response(status_code=200)
+
+
+if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+    logger.warning(
+        "Lora dynamic loading & unloading is enabled in the API server. "
+        "This should ONLY be used for local development!")
+
+    @router.post("/v1/load_lora_adapter")
+    async def load_lora_adapter(request: LoadLoraAdapterRequest,
+                                raw_request: Request):
+        response = await chat(raw_request).load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        response = await completion(raw_request).load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        return Response(status_code=200, content=response)
+
+    @router.post("/v1/unload_lora_adapter")
+    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+                                  raw_request: Request):
+        response = await chat(raw_request).unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        response = await completion(raw_request).unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        return Response(status_code=200, content=response)
+
+
 def build_app(args: Namespace) -> FastAPI:
+    if args.disable_fastapi_docs:
+        app = FastAPI(openapi_url=None,
+                      docs_url=None,
+                      redoc_url=None,
+                      lifespan=lifespan)
+    else:
         app = FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
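The new profiler routes only exist when VLLM_TORCH_PROFILER_DIR is set and take no request body, so exercising them from a client is a plain POST. A small sketch; the base URL assumes the server's default host and port, adjust for your deployment:

    import requests

    BASE_URL = "http://localhost:8000"  # assumption: default host/port

    requests.get(f"{BASE_URL}/health").raise_for_status()

    # Capture a profile window around the requests of interest.
    requests.post(f"{BASE_URL}/start_profile").raise_for_status()
    # ... issue /v1/chat/completions or /v1/completions requests here ...
    requests.post(f"{BASE_URL}/stop_profile").raise_for_status()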
@@ -243,7 +431,8 @@ def build_app(args: Namespace) -> FastAPI:

     @app.exception_handler(RequestValidationError)
     async def validation_exception_handler(_, exc):
-        err = openai_serving_chat.create_error_response(message=str(exc))
+        chat = app.state.openai_serving_chat
+        err = chat.create_error_response(message=str(exc))
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)

@@ -275,74 +464,87 @@ def build_app(args: Namespace) -> FastAPI:
     return app


-async def init_app(
-    async_engine_client: AsyncEngineClient,
+def init_app_state(
+    engine_client: EngineClient,
+    model_config: ModelConfig,
+    state: State,
     args: Namespace,
-) -> FastAPI:
-    app = build_app(args)
-
+) -> None:
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]

-    model_config = await async_engine_client.get_model_config()
-
     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

-    global openai_serving_chat
-    global openai_serving_completion
-    global openai_serving_embedding
-    global openai_serving_tokenization
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]

-    openai_serving_chat = OpenAIServingChat(
-        async_engine_client,
+    state.engine_client = engine_client
+    state.log_stats = not args.disable_log_stats
+
+    state.openai_serving_chat = OpenAIServingChat(
+        engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         chat_template=args.chat_template,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    )
-    openai_serving_completion = OpenAIServingCompletion(
-        async_engine_client,
+        enable_auto_tools=args.enable_auto_tool_choice,
+        tool_parser=args.tool_call_parser)
+    state.openai_serving_completion = OpenAIServingCompletion(
+        engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
-    openai_serving_embedding = OpenAIServingEmbedding(
-        async_engine_client,
+    state.openai_serving_embedding = OpenAIServingEmbedding(
+        engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         request_logger=request_logger,
     )
-    openai_serving_tokenization = OpenAIServingTokenization(
-        async_engine_client,
+    state.openai_serving_tokenization = OpenAIServingTokenization(
+        engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         request_logger=request_logger,
         chat_template=args.chat_template,
     )
-    app.root_path = args.root_path
-
-    return app


 async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)

-    async with build_async_engine_client(args) as async_engine_client:
-        app = await init_app(async_engine_client, args)
+    temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    temp_socket.bind(("", args.port))
+
+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    async with build_async_engine_client(args) as engine_client:
+        app = build_app(args)
+
+        model_config = await engine_client.get_model_config()
+        init_app_state(engine_client, model_config, app.state, args)
+
+        temp_socket.close()
+
         shutdown_task = await serve_http(
             app,
@@ -369,4 +571,4 @@ if __name__ == "__main__":
     parser = make_arg_parser(parser)
     args = parser.parse_args()

-    asyncio.run(run_server(args))
+    uvloop.run(run_server(args))
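The server entry point can also be driven programmatically with the same uvloop-based startup. A minimal sketch; the api_server module path and the model path are assumptions for illustration:

    import uvloop
    from vllm.utils import FlexibleArgumentParser
    from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
    # Module path assumed to mirror the cli_args package layout.
    from ipex_llm.vllm.xpu.entrypoints.openai.api_server import run_server

    parser = make_arg_parser(FlexibleArgumentParser())
    # Placeholder model path; any ipex-llm supported model may be used.
    args = parser.parse_args(["--model", "/llm/models/Llama-2-7b-chat-hf",
                              "--load-in-low-bit", "sym_int4"])

    uvloop.run(run_server(args))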
@@ -7,6 +7,7 @@ purposes.
 import argparse
 import json
 import ssl
+from typing import List, Optional, Sequence, Union

 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
@@ -16,18 +17,55 @@ from vllm.utils import FlexibleArgumentParser

 class LoRAParserAction(argparse.Action):

-    def __call__(self, parser, namespace, values, option_string=None):
-        lora_list = []
+    def __call__(
+        self,
+        parser: argparse.ArgumentParser,
+        namespace: argparse.Namespace,
+        values: Optional[Union[str, Sequence[str]]],
+        option_string: Optional[str] = None,
+    ):
+        if values is None:
+            values = []
+        if isinstance(values, str):
+            raise TypeError("Expected values to be a list")  # noqa
+
+        lora_list: List[LoRAModulePath] = []
         for item in values:
+            if item in [None, '']:  # Skip if item is None or empty string
+                continue
+            if '=' in item and ',' not in item:  # Old format: name=path
                 name, path = item.split('=')
                 lora_list.append(LoRAModulePath(name, path))
+            else:  # Assume JSON format
+                try:
+                    lora_dict = json.loads(item)
+                    lora = LoRAModulePath(**lora_dict)
+                    lora_list.append(lora)
+                except json.JSONDecodeError:
+                    parser.error(
+                        f"Invalid JSON format for --lora-modules: {item}")
+                except TypeError as e:
+                    parser.error(
+                        f"Invalid fields for --lora-modules: {item} - {str(e)}"
+                    )
         setattr(namespace, self.dest, lora_list)


 class PromptAdapterParserAction(argparse.Action):

-    def __call__(self, parser, namespace, values, option_string=None):
-        adapter_list = []
+    def __call__(
+        self,
+        parser: argparse.ArgumentParser,
+        namespace: argparse.Namespace,
+        values: Optional[Union[str, Sequence[str]]],
+        option_string: Optional[str] = None,
+    ):
+        if values is None:
+            values = []
+        if isinstance(values, str):
+            raise TypeError("Expected values to be a list")  # noqa
+
+        adapter_list: List[PromptAdapterPath] = []
         for item in values:
             name, path = item.split('=')
             adapter_list.append(PromptAdapterPath(name, path))
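LoRAParserAction now tolerates empty entries and accepts a JSON object per module in addition to the legacy name=path form. A small sketch of the legacy path, which is fully covered by the code above; the adapter name and path are placeholders, and importing the action from the ipex-llm cli_args module is assumed:

    import argparse

    from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import LoRAParserAction

    parser = argparse.ArgumentParser()
    parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)

    # Legacy 'name=path' entries still work; empty strings are now skipped.
    args = parser.parse_args(["--lora-modules", "sql-lora=/path/to/sql-lora", ""])
    print(args.lora_modules)  # one LoRAModulePath built from name and path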
@@ -72,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=None,
         nargs='+',
         action=LoRAParserAction,
-        help="LoRA module configurations in the format name=path. "
-        "Multiple modules can be specified.")
+        help="LoRA module configurations in either 'name=path' format"
+        "or JSON format. "
+        "Example (old format): 'name=path' "
+        "Example (new format): "
+        "'{\"name\": \"name\", \"local_path\": \"path\", "
+        "\"base_model_name\": \"id\"}'")
     parser.add_argument(
         "--prompt-adapters",
         type=nullable_str,
 | 
				
			||||||
    parser.add_argument("--response-role",
 | 
					    parser.add_argument("--response-role",
 | 
				
			||||||
                        type=nullable_str,
 | 
					                        type=nullable_str,
 | 
				
			||||||
                        default="assistant",
 | 
					                        default="assistant",
 | 
				
			||||||
                        help="The role name to return if "
 | 
					                        help="The role name to return if `request.add_generation_prompt=true`.")
 | 
				
			||||||
                        "`request.add_generation_prompt=true`.")
 | 
					 | 
				
			||||||
    parser.add_argument("--ssl-keyfile",
 | 
					    parser.add_argument("--ssl-keyfile",
 | 
				
			||||||
                        type=nullable_str,
 | 
					                        type=nullable_str,
 | 
				
			||||||
                        default=None,
 | 
					                        default=None,
 | 
				
			||||||
| 
						 | 
					@ -139,6 +180,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
 | 
				
			||||||
        action="store_true",
 | 
					        action="store_true",
 | 
				
			||||||
        help="If specified, will run the OpenAI frontend server in the same "
 | 
					        help="If specified, will run the OpenAI frontend server in the same "
 | 
				
			||||||
        "process as the model serving engine.")
 | 
					        "process as the model serving engine.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "--enable-auto-tool-choice",
 | 
				
			||||||
 | 
					        action="store_true",
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        help="Enable auto tool choice for supported models. Use --tool-call-parser"
 | 
				
			||||||
 | 
					        "to specify which parser to use")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "--tool-call-parser",
 | 
				
			||||||
 | 
					        type=str,
 | 
				
			||||||
 | 
					        choices=["mistral", "hermes"],
 | 
				
			||||||
 | 
					        default=None,
 | 
				
			||||||
 | 
					        help="Select the tool call parser depending on the model that you're using."
 | 
				
			||||||
 | 
					        " This is used to parse the model-generated tool call into OpenAI API "
 | 
				
			||||||
 | 
					        "format. Required for --enable-auto-tool-choice.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    parser.add_argument(
 | 
					    parser.add_argument(
 | 
				
			||||||
        "--load-in-low-bit",
 | 
					        "--load-in-low-bit",
 | 
				
			||||||
        type=str,
 | 
					        type=str,
 | 
				
			||||||
| 
						 | 
					@ -154,6 +212,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
 | 
				
			||||||
                        'ID numbers being printed in log.'
 | 
					                        'ID numbers being printed in log.'
 | 
				
			||||||
                        '\n\nDefault: Unlimited')
 | 
					                        '\n\nDefault: Unlimited')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "--disable-fastapi-docs",
 | 
				
			||||||
 | 
					        action='store_true',
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return parser
 | 
					    return parser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
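Taken together, the new flags can be combined with the existing ipex-llm option in one launch. An illustrative parse only; the model path and served name are placeholders, and --load-in-low-bit falls back to sym_int4 when omitted:

    from vllm.utils import FlexibleArgumentParser
    from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser

    parser = make_arg_parser(FlexibleArgumentParser())
    args = parser.parse_args([
        "--model", "/llm/models/Llama-2-7b-chat-hf",   # placeholder path
        "--served-model-name", "llama-2-7b-chat",      # placeholder name
        "--enable-auto-tool-choice",
        "--tool-call-parser", "hermes",
        "--load-in-low-bit", "sym_int4",
        "--disable-fastapi-docs",
    ])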
@@ -1,221 +0,0 @@
-import asyncio
-import signal
-from typing import Any, Coroutine
-
-import cloudpickle
-import zmq
-import zmq.asyncio
-from typing_extensions import Never
-
-from vllm import AsyncEngineArgs
-from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
-
-from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
-                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                         RPCGenerateRequest, RPCUtilityRequest)
-from vllm.logger import init_logger
-from vllm.usage.usage_lib import UsageContext
-
-logger = init_logger(__name__)
-
-
-class AsyncEngineRPCServer:
-
-    def __init__(self, async_engine_args: AsyncEngineArgs,
-                 usage_context: UsageContext, port: int, load_in_low_bit: str):
-        # Initialize engine first.
-        self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
-                                                      usage_context=usage_context,
-                                                      load_in_low_bit=load_in_low_bit)
-
-        # Initialize context.
-        self.context = zmq.asyncio.Context()
-
-        # Init socket for readiness state.
-        self.socket = self.context.socket(zmq.constants.ROUTER)
-        # Note numeric form of localhost should be used for zmq bind(),
-        # see https://stackoverflow.com/a/8958414
-        self.socket.bind(f"tcp://127.0.0.1:{port}")
-
-    def cleanup(self):
-        """Cleanup all resources."""
-        self.socket.close()
-        self.context.destroy()
-
-    async def get_model_config(self, identity):
-        """Send the ModelConfig"""
-        model_config = await self.engine.get_model_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(model_config)])
-
-    async def get_decoding_config(self, identity):
-        """Send the DecodingConfig"""
-        decoding_config = await self.engine.get_decoding_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(decoding_config)])
-
-    async def get_lora_config(self, identity):
-        lora_config = await self.engine.get_lora_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(lora_config)])
-
-    async def get_scheduler_config(self, identity):
-        """Send the SchedulerConfig"""
-        parallel_config = await self.engine.get_scheduler_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(parallel_config)])
-
-    async def get_parallel_config(self, identity):
-        """Send the ParallelConfig"""
-        parallel_config = await self.engine.get_parallel_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(parallel_config)])
-
-    async def is_tracing_enabled(self, identity):
-        """Send the is_tracing_enabled flag"""
-        tracing_flag = await self.engine.is_tracing_enabled()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(tracing_flag)])
-
-    async def do_log_stats(self, identity):
-        """Log stats and confirm success."""
-        await self.engine.do_log_stats()
-
-        await self.socket.send_multipart([
-            identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
-
-    async def is_server_ready(self, identity):
-        """Notify the client that we are ready."""
-        await self.socket.send_multipart([
-            identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
-
-    async def abort(self, identity, request: RPCAbortRequest):
-        """Abort request and notify the client of success."""
-        # Abort the request in the llm engine.
-        await self.engine.abort(request.request_id)
-
-        # Send confirmation to the client.
-        await self.socket.send_multipart([
-            identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
-
-    async def generate(self, identity, generate_request: RPCGenerateRequest):
-        try:
-            results_generator = self.engine.generate(
-                generate_request.inputs,
-                sampling_params=generate_request.sampling_params,
-                request_id=generate_request.request_id,
-                lora_request=generate_request.lora_request,
-                trace_headers=generate_request.trace_headers,
-                prompt_adapter_request=generate_request.prompt_adapter_request)
-
-            async for request_output in results_generator:
-                await self.socket.send_multipart(
-                    [identity, cloudpickle.dumps(request_output)])
-
-        except Exception as e:
-            # Notify client of all failures
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
-
-    async def check_health(self, identity):
-        try:
-            await self.engine.check_health()
-            await self.socket.send_multipart(
-                [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
-        except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
-
-    def _make_handler_coro(self, identity,
-                           message) -> Coroutine[Any, Any, Never]:
-        """Route the zmq message to the handler coroutine."""
-
-        request = cloudpickle.loads(message)
-
-        if isinstance(request, RPCGenerateRequest):
-            return self.generate(identity, request)
-
-        elif isinstance(request, RPCAbortRequest):
-            return self.abort(identity, request)
-
-        elif isinstance(request, RPCUtilityRequest):
-            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
-                return self.get_model_config(identity)
-            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
-                return self.get_parallel_config(identity)
-            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
-                return self.get_decoding_config(identity)
-            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
-                return self.get_scheduler_config(identity)
-            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
-                return self.get_lora_config(identity)
-            elif request == RPCUtilityRequest.DO_LOG_STATS:
-                return self.do_log_stats(identity)
-            elif request == RPCUtilityRequest.IS_SERVER_READY:
-                return self.is_server_ready(identity)
-            elif request == RPCUtilityRequest.CHECK_HEALTH:
-                return self.check_health(identity)
-            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
-                return self.is_tracing_enabled(identity)
-            else:
-                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")  # noqa
-
-        else:
-            raise ValueError(f"Unknown RPCRequest type: {request}")  # noqa
-
-    async def run_server_loop(self):
-        """Inner RPC Server Loop"""
-
-        running_tasks = set()
-        while True:
-            # Wait for a request.
-            identity, message = await self.socket.recv_multipart()
-
-            # Process the request async.
-            task = asyncio.create_task(
-                self._make_handler_coro(identity, message))
-
-            # We need to keep around a strong reference to the task,
-            # to avoid the task disappearing mid-execution as running tasks
-            # can be GC'ed. Below is a common "fire-and-forget" tasks
-            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
-            running_tasks.add(task)
-            task.add_done_callback(running_tasks.discard)
-
-
-async def run_server(server: AsyncEngineRPCServer):
-    # Put the server task into the asyncio loop.
-    loop = asyncio.get_running_loop()
-    server_task = loop.create_task(server.run_server_loop())
-
-    # Interruption handling.
-    def signal_handler() -> None:
-        # Kill the server on interrupt / terminate
-        server_task.cancel()
-
-    loop.add_signal_handler(signal.SIGINT, signal_handler)
-    loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
-    try:
-        await server_task
-    except asyncio.CancelledError:
-        logger.info("vLLM ZMQ RPC Server was interrupted.")
-    finally:
-        # Clean up all resources.
-        server.cleanup()
-
-
-def run_rpc_server(async_engine_args: AsyncEngineArgs,
-                   usage_context: UsageContext, port: int, load_in_low_bit: str):
-    server = AsyncEngineRPCServer(async_engine_args, usage_context, port, load_in_low_bit)
-    asyncio.run(run_server(server))
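Aside (illustrative, not part of this diff): the file removed above pairs a zmq ROUTER socket with cloudpickle-serialized request objects, so every reply is routed back by the client's identity frame. A minimal DEALER-side client sketch under those assumptions; the port and payload below are placeholders:

# Hypothetical client sketch for a ROUTER/cloudpickle server like the one removed above.
import asyncio

import cloudpickle
import zmq
import zmq.asyncio


async def rpc_request(port: int, payload):
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.DEALER)
    sock.connect(f"tcp://127.0.0.1:{port}")  # mirrors the server's 127.0.0.1 bind
    try:
        # ROUTER receives [identity, frame]; the DEALER only sends and receives the frame.
        await sock.send_multipart([cloudpickle.dumps(payload)])
        frames = await sock.recv_multipart()
        return cloudpickle.loads(frames[0])
    finally:
        sock.close()
        ctx.destroy()

A one-shot utility call would be driven as asyncio.run(rpc_request(port, RPCUtilityRequest.IS_SERVER_READY)); a streaming generate request would instead loop over recv_multipart until the request completes.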
@@ -75,14 +75,13 @@ def get_load_function(low_bit):
         _model_sample_convert()

         # from vllm.utils import measure_device_memory
-        from vllm.utils import CudaMemoryProfiler
-        with CudaMemoryProfiler() as m:
+        from vllm.utils import DeviceMemoryProfiler
+        with DeviceMemoryProfiler() as m:
             self.model = get_model(
                 model_config=self.model_config,
                 device_config=DeviceConfig("cpu"),
                 load_config=self.load_config,
                 lora_config=self.lora_config,
-                multimodal_config=self.multimodal_config,
                 parallel_config=self.parallel_config,
                 scheduler_config=self.scheduler_config,
                 cache_config=self.cache_config,
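Aside (illustrative, not part of this diff): the CudaMemoryProfiler -> DeviceMemoryProfiler switch tracks vLLM's move to a device-agnostic profiler; the underlying idea is simply to snapshot allocator statistics around model loading on whichever backend is active. A rough, hypothetical sketch of that pattern (not vLLM's implementation), assuming a CUDA or IPEX/XPU build of PyTorch:

# Hypothetical device-agnostic load-time memory profiler, in the spirit of DeviceMemoryProfiler.
import torch


class SimpleDeviceMemoryProfiler:
    def __init__(self, device=None):
        self.device = device

    def _allocated(self) -> int:
        # Prefer XPU allocator stats when torch.xpu is present (IPEX), else CUDA, else 0.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.xpu.memory_allocated(self.device)
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated(self.device)
        return 0

    def __enter__(self):
        self.initial_memory = self._allocated()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Bytes newly held by the allocator while the block ran, e.g. the loaded weights.
        self.consumed_memory = self._allocated() - self.initial_memory
        return False

Used the same way as in the hunk above: with SimpleDeviceMemoryProfiler() as m: load the model, then read m.consumed_memory.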