Add tensor parallel for vLLM (#10879)

* initial

* test initial tp

* initial sup

* fix format

* fix

* fix
Guancheng Fu 2024-04-26 17:10:49 +08:00 committed by GitHub
parent d058f2b403
commit 990535b1cf
4 changed files with 507 additions and 10 deletions
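For context, a minimal usage sketch (not part of this commit): with tensor_parallel_size > 1, vLLM requires Ray workers, and the engines patched below now select the new IPEX-LLM Ray GPU executor instead of vLLM's RayGPUExecutor. The ipex_llm import path and the load_in_low_bit keyword on from_engine_args are assumptions inferred from the hunks that follow; EngineArgs, SamplingParams, add_request, and step are standard vLLM 0.3.x API.

# Hypothetical sketch -- the ipex_llm import path and the from_engine_args
# signature are assumed, not confirmed by this commit.
from vllm import EngineArgs, SamplingParams
from ipex_llm.vllm.engine import IPEXLLMLLMEngine  # import path is an assumption

engine_args = EngineArgs(
    model="meta-llama/Llama-2-7b-chat-hf",
    tensor_parallel_size=2,  # > 1 requires Ray, so the new executor below is used
)
engine = IPEXLLMLLMEngine.from_engine_args(engine_args, load_in_low_bit="sym_int4")
engine.add_request("0", "What is tensor parallelism?",
                   SamplingParams(temperature=0, max_tokens=64))
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)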

View file

@@ -117,6 +117,10 @@ def is_linear_module(module):
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size
)
VLLM_LINEAR_LIST = [
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
]
@@ -125,6 +129,12 @@ def is_linear_module(module):
out_features = module.output_size
result = True
mp_group = None
tp_size = get_tensor_model_parallel_world_size()
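# RowParallelLinear shards the input dimension, so each rank produces a partial
# sum that must be all-reduced across the tensor-parallel group;
# ColumnParallelLinear (and its QKV/Merged variants) shards the output dimension
# and needs no reduction here.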
if isinstance(module, RowParallelLinear) and tp_size >= 2:
mp_group = get_tensor_model_parallel_group()
in_features = module.input_size_per_partition
elif isinstance(module, ColumnParallelLinear) and tp_size >= 2:
out_features = module.output_size_per_partition
else:
result = False
elif is_gptq_linear(module):

View file

@@ -45,6 +45,7 @@ from typing import Optional, TypeVar, Union, overload
from ipex_llm.utils.common import invalidInputError
import os
import torch
import torch.distributed
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from operator import mul
@@ -52,6 +53,7 @@ from functools import reduce
from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
get_ipex_version
from ipex_llm.transformers.convert import is_deepspeed_available, is_vllm_available
T = TypeVar("T", bound="torch.nn.Module")
@@ -702,8 +704,14 @@ class LowBitLinear(nn.Linear):
torch.xpu.empty_cache()
result = result.view(new_shape)
if self.mp_group is not None:
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
# FIXME: the user may have both vllm and deepspeed installed
if is_deepspeed_available():
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
elif is_vllm_available():
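# vLLM builds its tensor-parallel group on torch.distributed, so the partial
# outputs from each rank can be combined with a plain all_reduce.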
torch.distributed.all_reduce(result, group=self.mp_group)
else:
invalidInputError(False, "mp_group is not None, but no supported backend found")
if self.bias is not None:
result += self.bias
else:
@@ -729,6 +737,7 @@ class LowBitLinear(nn.Linear):
result = result.view(new_shape)
# allreduce to combine partial results and add bias if necessary
if self.mp_group is not None:
# TODO: implement the CPU logic for vLLM tp
# deepspeed distributed mode
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
@@ -780,8 +789,13 @@ class FP16Linear(nn.Linear):
self.weight_type = 2
result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
if self.mp_group is not None:
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
if is_deepspeed_available():
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
elif is_vllm_available():
torch.distributed.all_reduce(result, group=self.mp_group)
else:
invalidInputError(False, "mp_group is not None, but no supported backend found")
return result
else:
if self.in_len == 4096 and self.weight_type != 3 or \
@@ -817,8 +831,13 @@ class FP16Linear(nn.Linear):
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
if self.mp_group is not None:
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
if is_deepspeed_available():
from deepspeed import comm as dist
dist.inference_all_reduce(result, group=self.mp_group)
elif is_vllm_available():
torch.distributed.all_reduce(result, group=self.mp_group)
else:
invalidInputError(False, "mp_group is not None, but no supported backend found")
if self.bias is not None:
result += self.bias
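
For intuition, a small self-contained sketch (not part of the commit) of the row-parallel pattern that the all-reduce above completes; it assumes a torch.distributed process group is already initialized (e.g. launched via torchrun with two ranks).

import torch
import torch.distributed as dist

def row_parallel_linear(x_shard: torch.Tensor, w_shard: torch.Tensor, group=None):
    # Each rank holds a slice of the input features and the matching shard of the
    # weight, so its matmul yields only a partial sum of the full output.
    partial = x_shard @ w_shard
    # Summing the partials across ranks reconstructs the complete output, which is
    # what torch.distributed.all_reduce(result, group=self.mp_group) does above.
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=group)
    return partial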

View file

@@ -45,8 +45,9 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
parallel_config = engine_configs[2]
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
# from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
from ipex_llm.vllm.ipex_llm_gpu_executor import get_gpu_executor_class_async
executor_class = get_gpu_executor_class_async(load_in_low_bit)
else:
invalidInputError(parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1."))
@@ -130,8 +131,9 @@ class IPEXLLMLLMEngine(LLMEngine):
# Initialize the cluster and specify the executor class.
if parallel_config.worker_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
# from vllm.executor.ray_gpu_executor import RayGPUExecutor
from ipex_llm.vllm.ipex_llm_gpu_executor import get_gpu_executor_class
executor_class = get_gpu_executor_class(load_in_low_bit)
else:
invalidInputError(parallel_config.world_size == 1,
"Ray is required if parallel_config.world_size > 1.")

View file

@@ -0,0 +1,466 @@
import asyncio
import copy
from collections import defaultdict
import os
import pickle
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (set_cuda_visible_devices, get_ip, get_open_port,
get_distributed_init_method, make_async)
import functools
from ipex_llm.utils.common import invalidInputError
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"xpu": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}
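# The "xpu" entry reuses vLLM's default worker module; the IPEX-LLM low-bit
# conversion is applied later inside each worker process via _ipex_llm_convert.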
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
class IPEXLLMGPUExecutor(ExecutorBase):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
load_in_low_bit: str,
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.load_in_low_bit = load_in_low_bit
invalidInputError(self.parallel_config.worker_use_ray,
"worker_use_ray must be True to use the Ray-based GPU executor")
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
# Profile the memory usage and initialize the cache.
self._init_cache()
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
# For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization
else:
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus = 1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: RayWorkerVllm = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = []
# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
else:
# Otherwise, add it to the list of workers.
self.workers.append(worker)
if self.driver_dummy_worker is None:
invalidInputError(False,
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
# Get the set of GPU IDs used on each node.
driver_node_id, driver_gpu_ids = ray.get(
self.driver_dummy_worker.get_node_and_gpu_ids.remote())
worker_node_and_gpu_ids = ray.get(
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
node_workers = defaultdict(list)
node_gpus = defaultdict(list)
node_workers[driver_node_id].append(0)
node_gpus[driver_node_id].extend(driver_gpu_ids)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
start=1):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
# Set CUDA_VISIBLE_DEVICES for the driver and workers.
set_cuda_visible_devices(node_gpus[driver_node_id])
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker = self._dispatch_worker()
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
kv_cache_dtype = self.cache_config.cache_dtype
# Initialize the actual workers with the Worker class.
for rank, (worker, (node_id, _)) in enumerate(
zip(self.workers, worker_node_and_gpu_ids),
start=1,
):
local_rank = node_workers[node_id].index(rank)
from ipex_llm.vllm.model_convert import _ipex_llm_convert
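# Build the Worker inside a closure so that each Ray actor applies the IPEX-LLM
# low-bit conversion patch (with its own rank/local_rank) before the vLLM Worker
# is constructed in that process.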
def create_worker_function(rank, local_rank, load_in_low_bit):
def worker_function():
_ipex_llm_convert(load_in_low_bit)
return Worker(
model_config,
parallel_config,
scheduler_config,
device_config,
local_rank,
rank,
distributed_init_method,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
)
return worker_function
worker.init_worker.remote(create_worker_function(rank,
local_rank,
self.load_in_low_bit))
# Initialize the driver worker with the Worker class.
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
driver_local_rank,
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=True,
)
# We want to apply the patch here before loading the model
# FIXME(woosuk): We are not properly initializing cupy NCCL when
# we have multiple nodes.
self._run_workers("init_model",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self._run_workers("warm_up_model")
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
# Only the driver worker returns the sampling results.
output = all_outputs[0]
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
invalidInputError(lora_request.lora_int_id > 0,
"lora_id must be greater than 0.")
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
invalidInputError(lora_id > 0, "lora_id must be greater than 0.")
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> List[int]:
return self._run_workers("list_loras")
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if max_concurrent_workers:
invalidInputError(False,
"max_concurrent_workers is not supported yet.")
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
]
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Start the driver worker after all the ray workers.
driver_worker_output = getattr(self.driver_worker,
method)(*driver_args, **driver_kwargs)
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs
def _compiled_ray_dag(self):
import pkg_resources
required_version = "2.9"
current_version = pkg_resources.get_distribution("ray").version
if current_version < required_version:
invalidInputError(False,
f"Ray version {required_version} or greater is "
f"required, but found {current_version}")
from ray.dag import MultiOutputNode, InputNode
invalidInputError(self.parallel_config.worker_use_ray,
"worker_use_ray must be True when using the compiled Ray DAG")
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote.bind(input_data)
for worker in self.workers
])
return forward_dag.experimental_compile()
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
def _check_if_any_actor_is_dead(self):
if not self.workers:
return
dead_actors = []
for actor in self.workers:
actor_state = ray.state.actors(actor._ray_actor_id.hex())
if actor_state["State"] == "DEAD":
dead_actors.append(actor)
if dead_actors:
invalidInputError("At least one Worker is dead. "
f"Dead Workers: {dead_actors}. ")
class IPEXLLMGPUExecutorAsync(IPEXLLMGPUExecutor, ExecutorAsyncBase):
async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Run the driver worker asynchronously.
driver_executor = make_async(getattr(self.driver_worker, method))
coros.append(driver_executor(*driver_args, **driver_kwargs))
# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))
all_outputs = await asyncio.gather(*coros)
return all_outputs
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
all_outputs = await self._run_workers_async(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
})
# Only the driver worker returns the sampling results.
output = all_outputs[0]
return output
async def check_health_async(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
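# The engine instantiates executor_class(model_config, cache_config, ...), so
# load_in_low_bit is bound with functools.partial to keep that call signature intact.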
def get_gpu_executor_class(load_in_low_bit):
return functools.partial(IPEXLLMGPUExecutor, load_in_low_bit=load_in_low_bit)
def get_gpu_executor_class_async(load_in_low_bit):
return functools.partial(IPEXLLMGPUExecutorAsync, load_in_low_bit=load_in_low_bit)