Add tensor parallel for vLLM (#10879)
* initial * test initial tp * initial sup * fix format * fix * fix
This commit is contained in:
parent
d058f2b403
commit
990535b1cf
4 changed files with 507 additions and 10 deletions
|
|
@ -117,6 +117,10 @@ def is_linear_module(module):
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
|
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_group,
|
||||||
|
get_tensor_model_parallel_world_size
|
||||||
|
)
|
||||||
VLLM_LINEAR_LIST = [
|
VLLM_LINEAR_LIST = [
|
||||||
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
|
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
|
||||||
]
|
]
|
||||||
|
|
@ -125,6 +129,12 @@ def is_linear_module(module):
|
||||||
out_features = module.output_size
|
out_features = module.output_size
|
||||||
result = True
|
result = True
|
||||||
mp_group = None
|
mp_group = None
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
if isinstance(module, RowParallelLinear) and tp_size >= 2:
|
||||||
|
mp_group = get_tensor_model_parallel_group()
|
||||||
|
in_features = module.input_size_per_partition
|
||||||
|
elif isinstance(module, ColumnParallelLinear) and tp_size >= 2:
|
||||||
|
out_features = module.output_size_per_partition
|
||||||
else:
|
else:
|
||||||
result = False
|
result = False
|
||||||
elif is_gptq_linear(module):
|
elif is_gptq_linear(module):
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ from typing import Optional, TypeVar, Union, overload
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor, device, dtype, nn
|
from torch import Tensor, device, dtype, nn
|
||||||
from operator import mul
|
from operator import mul
|
||||||
|
|
@ -52,6 +53,7 @@ from functools import reduce
|
||||||
from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
|
from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
|
||||||
from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
|
from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
|
||||||
get_ipex_version
|
get_ipex_version
|
||||||
|
from ipex_llm.transformers.convert import is_deepspeed_available, is_vllm_available
|
||||||
|
|
||||||
T = TypeVar("T", bound="torch.nn.Module")
|
T = TypeVar("T", bound="torch.nn.Module")
|
||||||
|
|
||||||
|
|
@ -702,8 +704,14 @@ class LowBitLinear(nn.Linear):
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
result = result.view(new_shape)
|
result = result.view(new_shape)
|
||||||
if self.mp_group is not None:
|
if self.mp_group is not None:
|
||||||
|
# FIXME: the user may install both vllm and deepspeed
|
||||||
|
if is_deepspeed_available():
|
||||||
from deepspeed import comm as dist
|
from deepspeed import comm as dist
|
||||||
dist.inference_all_reduce(result, group=self.mp_group)
|
dist.inference_all_reduce(result, group=self.mp_group)
|
||||||
|
elif is_vllm_available():
|
||||||
|
torch.distributed.all_reduce(result, group=self.mp_group)
|
||||||
|
else:
|
||||||
|
invalidInputError(False, "mp_group is not None, but no supported backend found")
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
result += self.bias
|
result += self.bias
|
||||||
else:
|
else:
|
||||||
|
|
@ -729,6 +737,7 @@ class LowBitLinear(nn.Linear):
|
||||||
result = result.view(new_shape)
|
result = result.view(new_shape)
|
||||||
# allreduce to combine partial results and add bias if necessary
|
# allreduce to combine partial results and add bias if necessary
|
||||||
if self.mp_group is not None:
|
if self.mp_group is not None:
|
||||||
|
# TODO: implement for CPU logic for vLLM tp
|
||||||
# deepspeed distibuted mode
|
# deepspeed distibuted mode
|
||||||
from deepspeed import comm as dist
|
from deepspeed import comm as dist
|
||||||
dist.inference_all_reduce(result, group=self.mp_group)
|
dist.inference_all_reduce(result, group=self.mp_group)
|
||||||
|
|
@ -780,8 +789,13 @@ class FP16Linear(nn.Linear):
|
||||||
self.weight_type = 2
|
self.weight_type = 2
|
||||||
result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
|
result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
|
||||||
if self.mp_group is not None:
|
if self.mp_group is not None:
|
||||||
|
if is_deepspeed_available():
|
||||||
from deepspeed import comm as dist
|
from deepspeed import comm as dist
|
||||||
dist.inference_all_reduce(result, group=self.mp_group)
|
dist.inference_all_reduce(result, group=self.mp_group)
|
||||||
|
elif is_vllm_available():
|
||||||
|
torch.distributed.all_reduce(result, group=self.mp_group)
|
||||||
|
else:
|
||||||
|
invalidInputError(False, "mp_group is not None, but no supported backend found")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
if self.in_len == 4096 and self.weight_type != 3 or \
|
if self.in_len == 4096 and self.weight_type != 3 or \
|
||||||
|
|
@ -817,8 +831,13 @@ class FP16Linear(nn.Linear):
|
||||||
new_shape = x_shape[:-1] + (self.out_len,)
|
new_shape = x_shape[:-1] + (self.out_len,)
|
||||||
result = result.view(new_shape)
|
result = result.view(new_shape)
|
||||||
if self.mp_group is not None:
|
if self.mp_group is not None:
|
||||||
|
if is_deepspeed_available():
|
||||||
from deepspeed import comm as dist
|
from deepspeed import comm as dist
|
||||||
dist.inference_all_reduce(result, group=self.mp_group)
|
dist.inference_all_reduce(result, group=self.mp_group)
|
||||||
|
elif is_vllm_available():
|
||||||
|
torch.distributed.all_reduce(result, group=self.mp_group)
|
||||||
|
else:
|
||||||
|
invalidInputError(False, "mp_group is not None, but no supported backend found")
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
result += self.bias
|
result += self.bias
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,8 +45,9 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
|
||||||
parallel_config = engine_configs[2]
|
parallel_config = engine_configs[2]
|
||||||
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
|
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
|
||||||
initialize_ray_cluster(parallel_config)
|
initialize_ray_cluster(parallel_config)
|
||||||
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
|
# from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
|
||||||
executor_class = RayGPUExecutorAsync
|
from ipex_llm.vllm.ipex_llm_gpu_executor import get_gpu_executor_class_async
|
||||||
|
executor_class = get_gpu_executor_class_async(load_in_low_bit)
|
||||||
else:
|
else:
|
||||||
invalidInputError(parallel_config.world_size == 1, (
|
invalidInputError(parallel_config.world_size == 1, (
|
||||||
"Ray is required if parallel_config.world_size > 1."))
|
"Ray is required if parallel_config.world_size > 1."))
|
||||||
|
|
@ -130,8 +131,9 @@ class IPEXLLMLLMEngine(LLMEngine):
|
||||||
# Initialize the cluster and specify the executor class.
|
# Initialize the cluster and specify the executor class.
|
||||||
if parallel_config.worker_use_ray:
|
if parallel_config.worker_use_ray:
|
||||||
initialize_ray_cluster(parallel_config)
|
initialize_ray_cluster(parallel_config)
|
||||||
from vllm.executor.ray_gpu_executor import RayGPUExecutor
|
# from vllm.executor.ray_gpu_executor import RayGPUExecutor
|
||||||
executor_class = RayGPUExecutor
|
from ipex_llm.vllm.ipex_llm_gpu_executor import get_gpu_executor_class
|
||||||
|
executor_class = get_gpu_executor_class(load_in_low_bit)
|
||||||
else:
|
else:
|
||||||
invalidInputError(parallel_config.world_size == 1,
|
invalidInputError(parallel_config.world_size == 1,
|
||||||
"Ray is required if parallel_config.world_size > 1.")
|
"Ray is required if parallel_config.world_size > 1.")
|
||||||
|
|
|
||||||
466
python/llm/src/ipex_llm/vllm/ipex_llm_gpu_executor.py
Normal file
466
python/llm/src/ipex_llm/vllm/ipex_llm_gpu_executor.py
Normal file
|
|
@ -0,0 +1,466 @@
|
||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
|
from collections import defaultdict
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import importlib
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||||
|
ParallelConfig, SchedulerConfig, LoRAConfig)
|
||||||
|
from vllm.engine.ray_utils import RayWorkerVllm, ray
|
||||||
|
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||||
|
from vllm.executor.utils import check_block_size_valid
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
|
from vllm.utils import (set_cuda_visible_devices, get_ip, get_open_port,
|
||||||
|
get_distributed_init_method, make_async)
|
||||||
|
import functools
|
||||||
|
from ipex_llm.utils.common import invalidInputError
|
||||||
|
|
||||||
|
if ray is not None:
|
||||||
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ray.util.placement_group import PlacementGroup
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# A map between the device type (in device config) to its worker module.
|
||||||
|
DEVICE_TO_WORKER_MODULE_MAP = {
|
||||||
|
"cuda": "vllm.worker.worker",
|
||||||
|
"xpu": "vllm.worker.worker",
|
||||||
|
"neuron": "vllm.worker.neuron_worker",
|
||||||
|
}
|
||||||
|
|
||||||
|
# If the env var is set, it uses the Ray's compiled DAG API
|
||||||
|
# which optimizes the control plane overhead.
|
||||||
|
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||||
|
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
|
||||||
|
|
||||||
|
|
||||||
|
class IPEXLLMGPUExecutor(ExecutorBase):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
|
load_in_low_bit: str,
|
||||||
|
) -> None:
|
||||||
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.device_config = device_config
|
||||||
|
self.load_in_low_bit = load_in_low_bit
|
||||||
|
|
||||||
|
invalidInputError(self.parallel_config.worker_use_ray,
|
||||||
|
"worker_use_ray is False, but use ray worker")
|
||||||
|
placement_group = self.parallel_config.placement_group
|
||||||
|
|
||||||
|
# Disable Ray usage stats collection.
|
||||||
|
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||||
|
if ray_usage != "1":
|
||||||
|
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||||
|
|
||||||
|
# Create the parallel GPU workers.
|
||||||
|
self._init_workers_ray(placement_group)
|
||||||
|
|
||||||
|
# Profile the memory usage and initialize the cache.
|
||||||
|
self._init_cache()
|
||||||
|
|
||||||
|
self.forward_dag = None
|
||||||
|
if USE_RAY_COMPILED_DAG:
|
||||||
|
self.forward_dag = self._compiled_ray_dag()
|
||||||
|
|
||||||
|
def _dispatch_worker(self):
|
||||||
|
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
|
||||||
|
self.device_config.device_type]
|
||||||
|
imported_worker = importlib.import_module(worker_module)
|
||||||
|
Worker = imported_worker.Worker
|
||||||
|
return Worker
|
||||||
|
|
||||||
|
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||||
|
**ray_remote_kwargs):
|
||||||
|
if self.parallel_config.tensor_parallel_size == 1:
|
||||||
|
# For single GPU case, we use a ray worker with constrained memory.
|
||||||
|
num_gpus = self.cache_config.gpu_memory_utilization
|
||||||
|
else:
|
||||||
|
# Otherwise, the ray workers are allocated with a full GPU.
|
||||||
|
num_gpus = 1
|
||||||
|
|
||||||
|
# The driver dummy worker does not actually use any resources.
|
||||||
|
# It holds the resource for the driver worker.
|
||||||
|
self.driver_dummy_worker: RayWorkerVllm = None
|
||||||
|
# The remaining workers are the actual ray actors.
|
||||||
|
self.workers: List[RayWorkerVllm] = []
|
||||||
|
|
||||||
|
# Create the workers.
|
||||||
|
driver_ip = get_ip()
|
||||||
|
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||||
|
if not bundle.get("GPU", 0):
|
||||||
|
continue
|
||||||
|
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=placement_group,
|
||||||
|
placement_group_capture_child_tasks=True,
|
||||||
|
placement_group_bundle_index=bundle_id,
|
||||||
|
)
|
||||||
|
worker = ray.remote(
|
||||||
|
num_cpus=0,
|
||||||
|
num_gpus=num_gpus,
|
||||||
|
scheduling_strategy=scheduling_strategy,
|
||||||
|
**ray_remote_kwargs,
|
||||||
|
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
|
||||||
|
|
||||||
|
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||||
|
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||||
|
# If the worker is on the same node as the driver, we use it
|
||||||
|
# as the resource holder for the driver process.
|
||||||
|
self.driver_dummy_worker = worker
|
||||||
|
else:
|
||||||
|
# Else, added to the list of workers.
|
||||||
|
self.workers.append(worker)
|
||||||
|
|
||||||
|
if self.driver_dummy_worker is None:
|
||||||
|
invalidInputError(False,
|
||||||
|
"Ray does not allocate any GPUs on the driver node. Consider "
|
||||||
|
"adjusting the Ray placement group or running the driver on a "
|
||||||
|
"GPU node.")
|
||||||
|
# Get the set of GPU IDs used on each node.
|
||||||
|
driver_node_id, driver_gpu_ids = ray.get(
|
||||||
|
self.driver_dummy_worker.get_node_and_gpu_ids.remote())
|
||||||
|
worker_node_and_gpu_ids = ray.get(
|
||||||
|
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
|
||||||
|
|
||||||
|
node_workers = defaultdict(list)
|
||||||
|
node_gpus = defaultdict(list)
|
||||||
|
|
||||||
|
node_workers[driver_node_id].append(0)
|
||||||
|
node_gpus[driver_node_id].extend(driver_gpu_ids)
|
||||||
|
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
|
||||||
|
start=1):
|
||||||
|
node_workers[node_id].append(i)
|
||||||
|
node_gpus[node_id].extend(gpu_ids)
|
||||||
|
for node_id, gpu_ids in node_gpus.items():
|
||||||
|
node_gpus[node_id] = sorted(gpu_ids)
|
||||||
|
|
||||||
|
# Set CUDA_VISIBLE_DEVICES for the driver and workers.
|
||||||
|
set_cuda_visible_devices(node_gpus[driver_node_id])
|
||||||
|
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
|
||||||
|
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
|
||||||
|
|
||||||
|
distributed_init_method = get_distributed_init_method(
|
||||||
|
driver_ip, get_open_port())
|
||||||
|
|
||||||
|
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||||
|
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||||
|
Worker = self._dispatch_worker()
|
||||||
|
|
||||||
|
model_config = copy.deepcopy(self.model_config)
|
||||||
|
parallel_config = copy.deepcopy(self.parallel_config)
|
||||||
|
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||||
|
device_config = copy.deepcopy(self.device_config)
|
||||||
|
lora_config = copy.deepcopy(self.lora_config)
|
||||||
|
kv_cache_dtype = self.cache_config.cache_dtype
|
||||||
|
|
||||||
|
# Initialize the actual workers with the Worker class.
|
||||||
|
for rank, (worker, (node_id, _)) in enumerate(
|
||||||
|
zip(self.workers, worker_node_and_gpu_ids),
|
||||||
|
start=1,
|
||||||
|
):
|
||||||
|
local_rank = node_workers[node_id].index(rank)
|
||||||
|
from ipex_llm.vllm.model_convert import _ipex_llm_convert
|
||||||
|
|
||||||
|
def create_worker_function(rank, local_rank, load_in_low_bit):
|
||||||
|
def worker_function():
|
||||||
|
_ipex_llm_convert(load_in_low_bit)
|
||||||
|
return Worker(
|
||||||
|
model_config,
|
||||||
|
parallel_config,
|
||||||
|
scheduler_config,
|
||||||
|
device_config,
|
||||||
|
local_rank,
|
||||||
|
rank,
|
||||||
|
distributed_init_method,
|
||||||
|
lora_config=lora_config,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
)
|
||||||
|
return worker_function
|
||||||
|
worker.init_worker.remote(create_worker_function(rank,
|
||||||
|
local_rank,
|
||||||
|
self.load_in_low_bit))
|
||||||
|
|
||||||
|
# Initialize the driver worker with the Worker class.
|
||||||
|
driver_rank = 0
|
||||||
|
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
||||||
|
self.driver_worker = Worker(
|
||||||
|
self.model_config,
|
||||||
|
self.parallel_config,
|
||||||
|
self.scheduler_config,
|
||||||
|
self.device_config,
|
||||||
|
driver_local_rank,
|
||||||
|
driver_rank,
|
||||||
|
distributed_init_method,
|
||||||
|
lora_config=self.lora_config,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
is_driver_worker=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We want to apply patch here before we loading the model
|
||||||
|
# FIXME(woosuk): We are not properly initializing cupy NCCL when
|
||||||
|
# we have multiple nodes.
|
||||||
|
self._run_workers("init_model",
|
||||||
|
cupy_port=get_open_port()
|
||||||
|
if not model_config.enforce_eager else None)
|
||||||
|
self._run_workers(
|
||||||
|
"load_model",
|
||||||
|
max_concurrent_workers=self.parallel_config.
|
||||||
|
max_parallel_loading_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_cache(self) -> None:
|
||||||
|
"""Profiles the memory usage and initializes the KV cache.
|
||||||
|
|
||||||
|
The engine will first conduct a profiling of the existing memory usage.
|
||||||
|
Then, it calculate the maximum possible number of GPU and CPU blocks
|
||||||
|
that can be allocated with the remaining free memory.
|
||||||
|
More details can be found in the
|
||||||
|
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
|
||||||
|
from class :class:`~vllm.worker.Worker`.
|
||||||
|
|
||||||
|
Afterwards, as there may be multiple workers,
|
||||||
|
we take the minimum number of blocks across all workers
|
||||||
|
to ensure this can be applied to all of them.
|
||||||
|
|
||||||
|
Finally, the engine will initialize the KV cache
|
||||||
|
with the calculated number of blocks.
|
||||||
|
|
||||||
|
.. tip::
|
||||||
|
You may limit the usage of GPU memory
|
||||||
|
by adjusting the `gpu_memory_utilization` parameter.
|
||||||
|
"""
|
||||||
|
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||||
|
num_blocks = self._run_workers(
|
||||||
|
"profile_num_available_blocks",
|
||||||
|
block_size=self.cache_config.block_size,
|
||||||
|
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
|
||||||
|
cpu_swap_space=self.cache_config.swap_space_bytes,
|
||||||
|
cache_dtype=self.cache_config.cache_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since we use a shared centralized controller, we take the minimum
|
||||||
|
# number of blocks across all workers to make sure all the memory
|
||||||
|
# operators can be applied to all workers.
|
||||||
|
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||||
|
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||||
|
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
||||||
|
f"# CPU blocks: {num_cpu_blocks}")
|
||||||
|
|
||||||
|
check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
|
||||||
|
self.model_config.max_model_len)
|
||||||
|
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
# Initialize the cache.
|
||||||
|
self._run_workers("init_cache_engine", cache_config=self.cache_config)
|
||||||
|
# Warm up the model. This includes capturing the model into CUDA graph
|
||||||
|
# if enforce_eager is False.
|
||||||
|
self._run_workers("warm_up_model")
|
||||||
|
|
||||||
|
def execute_model(self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
blocks_to_swap_in: Dict[int, int],
|
||||||
|
blocks_to_swap_out: Dict[int, int],
|
||||||
|
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
|
||||||
|
all_outputs = self._run_workers(
|
||||||
|
"execute_model",
|
||||||
|
driver_kwargs={
|
||||||
|
"seq_group_metadata_list": seq_group_metadata_list,
|
||||||
|
"blocks_to_swap_in": blocks_to_swap_in,
|
||||||
|
"blocks_to_swap_out": blocks_to_swap_out,
|
||||||
|
"blocks_to_copy": blocks_to_copy,
|
||||||
|
},
|
||||||
|
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
|
||||||
|
|
||||||
|
# Only the driver worker returns the sampling results.
|
||||||
|
output = all_outputs[0]
|
||||||
|
return output
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
invalidInputError(lora_request.lora_int_id > 0,
|
||||||
|
"lora_id must be greater than 0.")
|
||||||
|
return self._run_workers(
|
||||||
|
"add_lora",
|
||||||
|
lora_request=lora_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
|
invalidInputError(lora_id > 0, "lora_id must be greater than 0.")
|
||||||
|
return self._run_workers(
|
||||||
|
"remove_lora",
|
||||||
|
lora_id=lora_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_loras(self) -> List[int]:
|
||||||
|
return self._run_workers("list_loras")
|
||||||
|
|
||||||
|
def _run_workers(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
*args,
|
||||||
|
driver_args: Optional[List[Any]]=None,
|
||||||
|
driver_kwargs: Optional[Dict[str, Any]]=None,
|
||||||
|
max_concurrent_workers: Optional[int] = None,
|
||||||
|
use_ray_compiled_dag: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""Runs the given method on all workers."""
|
||||||
|
|
||||||
|
if max_concurrent_workers:
|
||||||
|
invalidInputError(False,
|
||||||
|
"max_concurrent_workers is not supported yet.")
|
||||||
|
|
||||||
|
if use_ray_compiled_dag:
|
||||||
|
# Right now, compiled DAG can only accept a single
|
||||||
|
# input. TODO(sang): Fix it.
|
||||||
|
output_channels = self.forward_dag.execute(1)
|
||||||
|
else:
|
||||||
|
# Start the ray workers first.
|
||||||
|
ray_worker_outputs = [
|
||||||
|
worker.execute_method.remote(method, *args, **kwargs)
|
||||||
|
for worker in self.workers
|
||||||
|
]
|
||||||
|
|
||||||
|
if driver_args is None:
|
||||||
|
driver_args = args
|
||||||
|
if driver_kwargs is None:
|
||||||
|
driver_kwargs = kwargs
|
||||||
|
|
||||||
|
# Start the driver worker after all the ray workers.
|
||||||
|
driver_worker_output = getattr(self.driver_worker,
|
||||||
|
method)(*driver_args, **driver_kwargs)
|
||||||
|
|
||||||
|
# Get the results of the ray workers.
|
||||||
|
if self.workers:
|
||||||
|
if use_ray_compiled_dag:
|
||||||
|
try:
|
||||||
|
ray_worker_outputs = [
|
||||||
|
pickle.loads(chan.begin_read())
|
||||||
|
for chan in output_channels
|
||||||
|
]
|
||||||
|
finally:
|
||||||
|
# Has to call end_read in order to reuse the DAG.
|
||||||
|
for chan in output_channels:
|
||||||
|
chan.end_read()
|
||||||
|
else:
|
||||||
|
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||||
|
|
||||||
|
return [driver_worker_output] + ray_worker_outputs
|
||||||
|
|
||||||
|
def _compiled_ray_dag(self):
|
||||||
|
import pkg_resources
|
||||||
|
required_version = "2.9"
|
||||||
|
current_version = pkg_resources.get_distribution("ray").version
|
||||||
|
if current_version < required_version:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Ray version {required_version} or greater is "
|
||||||
|
f"required, but found {current_version}")
|
||||||
|
|
||||||
|
from ray.dag import MultiOutputNode, InputNode
|
||||||
|
invalidInputError(self.parallel_config.worker_use_ray,
|
||||||
|
"Use ray worker, but worker_use_ray is False")
|
||||||
|
|
||||||
|
# Right now, compiled DAG requires at least 1 arg. We send
|
||||||
|
# a dummy value for now. It will be fixed soon.
|
||||||
|
with InputNode() as input_data:
|
||||||
|
forward_dag = MultiOutputNode([
|
||||||
|
worker.execute_model_compiled_dag_remote.bind(input_data)
|
||||||
|
for worker in self.workers
|
||||||
|
])
|
||||||
|
return forward_dag.experimental_compile()
|
||||||
|
|
||||||
|
def check_health(self) -> None:
|
||||||
|
"""Raises an error if engine is unhealthy."""
|
||||||
|
self._check_if_any_actor_is_dead()
|
||||||
|
|
||||||
|
def _check_if_any_actor_is_dead(self):
|
||||||
|
if not self.workers:
|
||||||
|
return
|
||||||
|
|
||||||
|
dead_actors = []
|
||||||
|
for actor in self.workers:
|
||||||
|
actor_state = ray.state.actors(actor._ray_actor_id.hex())
|
||||||
|
if actor_state["State"] == "DEAD":
|
||||||
|
dead_actors.append(actor)
|
||||||
|
if dead_actors:
|
||||||
|
invalidInputError("At least one Worker is dead. "
|
||||||
|
f"Dead Workers: {dead_actors}. ")
|
||||||
|
|
||||||
|
|
||||||
|
class IPEXLLMGPUExecutorAsync(IPEXLLMGPUExecutor, ExecutorAsyncBase):
|
||||||
|
|
||||||
|
async def _run_workers_async(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
*args,
|
||||||
|
driver_args: Optional[List[Any]]=None,
|
||||||
|
driver_kwargs: Optional[Dict[str, Any]]=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""Runs the given method on all workers."""
|
||||||
|
coros = []
|
||||||
|
|
||||||
|
if driver_args is None:
|
||||||
|
driver_args = args
|
||||||
|
if driver_kwargs is None:
|
||||||
|
driver_kwargs = kwargs
|
||||||
|
|
||||||
|
# Run the driver worker asynchronously.
|
||||||
|
driver_executor = make_async(getattr(self.driver_worker, method))
|
||||||
|
coros.append(driver_executor(*driver_args, **driver_kwargs))
|
||||||
|
|
||||||
|
# Run the ray workers asynchronously.
|
||||||
|
for worker in self.workers:
|
||||||
|
coros.append(worker.execute_method.remote(method, *args, **kwargs))
|
||||||
|
|
||||||
|
all_outputs = await asyncio.gather(*coros)
|
||||||
|
return all_outputs
|
||||||
|
|
||||||
|
async def execute_model_async(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
blocks_to_swap_in: Dict[int, int],
|
||||||
|
blocks_to_swap_out: Dict[int, int],
|
||||||
|
blocks_to_copy: Dict[int, List[int]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
all_outputs = await self._run_workers_async(
|
||||||
|
"execute_model",
|
||||||
|
driver_kwargs={
|
||||||
|
"seq_group_metadata_list": seq_group_metadata_list,
|
||||||
|
"blocks_to_swap_in": blocks_to_swap_in,
|
||||||
|
"blocks_to_swap_out": blocks_to_swap_out,
|
||||||
|
"blocks_to_copy": blocks_to_copy,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Only the driver worker returns the sampling results.
|
||||||
|
output = all_outputs[0]
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def check_health_async(self) -> None:
|
||||||
|
"""Raises an error if engine is unhealthy."""
|
||||||
|
self._check_if_any_actor_is_dead()
|
||||||
|
|
||||||
|
|
||||||
|
def get_gpu_executor_class(load_in_low_bit):
|
||||||
|
return functools.partial(IPEXLLMGPUExecutor, load_in_low_bit=load_in_low_bit)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gpu_executor_class_async(load_in_low_bit):
|
||||||
|
return functools.partial(IPEXLLMGPUExecutorAsync, load_in_low_bit=load_in_low_bit)
|
||||||
Loading…
Reference in a new issue