Add tensor parallel for vLLM (#10879)
* initial * test initial tp * initial sup * fix format * fix * fix
This commit is contained in:
		
							parent
							
								
									d058f2b403
								
							
						
					
					
						commit
						990535b1cf
					
				
					 4 changed files with 507 additions and 10 deletions
				
			
		| 
						 | 
					@ -117,6 +117,10 @@ def is_linear_module(module):
 | 
				
			||||||
        from vllm.model_executor.layers.linear import (
 | 
					        from vllm.model_executor.layers.linear import (
 | 
				
			||||||
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
 | 
					            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        from vllm.model_executor.parallel_utils.parallel_state import (
 | 
				
			||||||
 | 
					            get_tensor_model_parallel_group,
 | 
				
			||||||
 | 
					            get_tensor_model_parallel_world_size
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        VLLM_LINEAR_LIST = [
 | 
					        VLLM_LINEAR_LIST = [
 | 
				
			||||||
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
 | 
					            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
| 
						 | 
					@ -125,6 +129,12 @@ def is_linear_module(module):
 | 
				
			||||||
            out_features = module.output_size
 | 
					            out_features = module.output_size
 | 
				
			||||||
            result = True
 | 
					            result = True
 | 
				
			||||||
            mp_group = None
 | 
					            mp_group = None
 | 
				
			||||||
 | 
					            tp_size = get_tensor_model_parallel_world_size()
 | 
				
			||||||
 | 
					            if isinstance(module, RowParallelLinear) and tp_size >= 2:
 | 
				
			||||||
 | 
					                mp_group = get_tensor_model_parallel_group()
 | 
				
			||||||
 | 
					                in_features = module.input_size_per_partition
 | 
				
			||||||
 | 
					            elif isinstance(module, ColumnParallelLinear) and tp_size >= 2:
 | 
				
			||||||
 | 
					                out_features = module.output_size_per_partition
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            result = False
 | 
					            result = False
 | 
				
			||||||
    elif is_gptq_linear(module):
 | 
					    elif is_gptq_linear(module):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -45,6 +45,7 @@ from typing import Optional, TypeVar, Union, overload
 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.distributed
 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
from torch import Tensor, device, dtype, nn
 | 
					from torch import Tensor, device, dtype, nn
 | 
				
			||||||
from operator import mul
 | 
					from operator import mul
 | 
				
			||||||
| 
						 | 
					@ -52,6 +53,7 @@ from functools import reduce
 | 
				
			||||||
from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 | 
					from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 | 
				
			||||||
from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
 | 
					from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
 | 
				
			||||||
    get_ipex_version
 | 
					    get_ipex_version
 | 
				
			||||||
 | 
					from ipex_llm.transformers.convert import is_deepspeed_available, is_vllm_available
 | 
				
			||||||
 | 
					
 | 
				
			||||||
T = TypeVar("T", bound="torch.nn.Module")
 | 
					T = TypeVar("T", bound="torch.nn.Module")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -702,8 +704,14 @@ class LowBitLinear(nn.Linear):
 | 
				
			||||||
                    torch.xpu.empty_cache()
 | 
					                    torch.xpu.empty_cache()
 | 
				
			||||||
            result = result.view(new_shape)
 | 
					            result = result.view(new_shape)
 | 
				
			||||||
            if self.mp_group is not None:
 | 
					            if self.mp_group is not None:
 | 
				
			||||||
                from deepspeed import comm as dist
 | 
					                # FIXME: the user may install both vllm and deepspeed
 | 
				
			||||||
                dist.inference_all_reduce(result, group=self.mp_group)
 | 
					                if is_deepspeed_available():
 | 
				
			||||||
 | 
					                    from deepspeed import comm as dist
 | 
				
			||||||
 | 
					                    dist.inference_all_reduce(result, group=self.mp_group)
 | 
				
			||||||
 | 
					                elif is_vllm_available():
 | 
				
			||||||
 | 
					                    torch.distributed.all_reduce(result, group=self.mp_group)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    invalidInputError(False, "mp_group is not None, but no supported backend found")
 | 
				
			||||||
            if self.bias is not None:
 | 
					            if self.bias is not None:
 | 
				
			||||||
                result += self.bias
 | 
					                result += self.bias
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
| 
						 | 
					@ -729,6 +737,7 @@ class LowBitLinear(nn.Linear):
 | 
				
			||||||
                    result = result.view(new_shape)
 | 
					                    result = result.view(new_shape)
 | 
				
			||||||
            # allreduce to combine partial results and add bias if necessary
 | 
					            # allreduce to combine partial results and add bias if necessary
 | 
				
			||||||
            if self.mp_group is not None:
 | 
					            if self.mp_group is not None:
 | 
				
			||||||
 | 
					                # TODO: implement for CPU logic for vLLM tp
 | 
				
			||||||
                # deepspeed distibuted mode
 | 
					                # deepspeed distibuted mode
 | 
				
			||||||
                from deepspeed import comm as dist
 | 
					                from deepspeed import comm as dist
 | 
				
			||||||
                dist.inference_all_reduce(result, group=self.mp_group)
 | 
					                dist.inference_all_reduce(result, group=self.mp_group)
 | 
				
			||||||
| 
						 | 
					@ -780,8 +789,13 @@ class FP16Linear(nn.Linear):
 | 
				
			||||||
                    self.weight_type = 2
 | 
					                    self.weight_type = 2
 | 
				
			||||||
                result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
 | 
					                result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
 | 
				
			||||||
            if self.mp_group is not None:
 | 
					            if self.mp_group is not None:
 | 
				
			||||||
                from deepspeed import comm as dist
 | 
					                if is_deepspeed_available():
 | 
				
			||||||
                dist.inference_all_reduce(result, group=self.mp_group)
 | 
					                    from deepspeed import comm as dist
 | 
				
			||||||
 | 
					                    dist.inference_all_reduce(result, group=self.mp_group)
 | 
				
			||||||
 | 
					                elif is_vllm_available():
 | 
				
			||||||
 | 
					                    torch.distributed.all_reduce(result, group=self.mp_group)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    invalidInputError(False, "mp_group is not None, but no supported backend found")
 | 
				
			||||||
            return result
 | 
					            return result
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if self.in_len == 4096 and self.weight_type != 3 or \
 | 
					            if self.in_len == 4096 and self.weight_type != 3 or \
 | 
				
			||||||
| 
						 | 
					@ -817,8 +831,13 @@ class FP16Linear(nn.Linear):
 | 
				
			||||||
            new_shape = x_shape[:-1] + (self.out_len,)
 | 
					            new_shape = x_shape[:-1] + (self.out_len,)
 | 
				
			||||||
            result = result.view(new_shape)
 | 
					            result = result.view(new_shape)
 | 
				
			||||||
            if self.mp_group is not None:
 | 
					            if self.mp_group is not None:
 | 
				
			||||||
                from deepspeed import comm as dist
 | 
					                if is_deepspeed_available():
 | 
				
			||||||
                dist.inference_all_reduce(result, group=self.mp_group)
 | 
					                    from deepspeed import comm as dist
 | 
				
			||||||
 | 
					                    dist.inference_all_reduce(result, group=self.mp_group)
 | 
				
			||||||
 | 
					                elif is_vllm_available():
 | 
				
			||||||
 | 
					                    torch.distributed.all_reduce(result, group=self.mp_group)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    invalidInputError(False, "mp_group is not None, but no supported backend found")
 | 
				
			||||||
            if self.bias is not None:
 | 
					            if self.bias is not None:
 | 
				
			||||||
                result += self.bias
 | 
					                result += self.bias
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -45,8 +45,9 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
 | 
				
			||||||
        parallel_config = engine_configs[2]
 | 
					        parallel_config = engine_configs[2]
 | 
				
			||||||
        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
 | 
					        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
 | 
				
			||||||
            initialize_ray_cluster(parallel_config)
 | 
					            initialize_ray_cluster(parallel_config)
 | 
				
			||||||
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
 | 
					            # from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
 | 
				
			||||||
            executor_class = RayGPUExecutorAsync
 | 
					            from ipex_llm.vllm.ipex_llm_gpu_executor import get_gpu_executor_class_async
 | 
				
			||||||
 | 
					            executor_class = get_gpu_executor_class_async(load_in_low_bit)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            invalidInputError(parallel_config.world_size == 1, (
 | 
					            invalidInputError(parallel_config.world_size == 1, (
 | 
				
			||||||
                "Ray is required if parallel_config.world_size > 1."))
 | 
					                "Ray is required if parallel_config.world_size > 1."))
 | 
				
			||||||
| 
						 | 
					@ -130,8 +131,9 @@ class IPEXLLMLLMEngine(LLMEngine):
 | 
				
			||||||
        # Initialize the cluster and specify the executor class.
 | 
					        # Initialize the cluster and specify the executor class.
 | 
				
			||||||
        if parallel_config.worker_use_ray:
 | 
					        if parallel_config.worker_use_ray:
 | 
				
			||||||
            initialize_ray_cluster(parallel_config)
 | 
					            initialize_ray_cluster(parallel_config)
 | 
				
			||||||
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
 | 
					            # from vllm.executor.ray_gpu_executor import RayGPUExecutor
 | 
				
			||||||
            executor_class = RayGPUExecutor
 | 
					            from ipex_llm.vllm.ipex_llm_gpu_executor import get_gpu_executor_class
 | 
				
			||||||
 | 
					            executor_class = get_gpu_executor_class(load_in_low_bit)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            invalidInputError(parallel_config.world_size == 1,
 | 
					            invalidInputError(parallel_config.world_size == 1,
 | 
				
			||||||
                              "Ray is required if parallel_config.world_size > 1.")
 | 
					                              "Ray is required if parallel_config.world_size > 1.")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										466
									
								
								python/llm/src/ipex_llm/vllm/ipex_llm_gpu_executor.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										466
									
								
								python/llm/src/ipex_llm/vllm/ipex_llm_gpu_executor.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,466 @@
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
 | 
					from collections import defaultdict
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import pickle
 | 
				
			||||||
 | 
					import importlib
 | 
				
			||||||
 | 
					from typing import TYPE_CHECKING, Any, Dict, List, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
 | 
				
			||||||
 | 
					                         ParallelConfig, SchedulerConfig, LoRAConfig)
 | 
				
			||||||
 | 
					from vllm.engine.ray_utils import RayWorkerVllm, ray
 | 
				
			||||||
 | 
					from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 | 
				
			||||||
 | 
					from vllm.executor.utils import check_block_size_valid
 | 
				
			||||||
 | 
					from vllm.logger import init_logger
 | 
				
			||||||
 | 
					from vllm.lora.request import LoRARequest
 | 
				
			||||||
 | 
					from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 | 
				
			||||||
 | 
					from vllm.utils import (set_cuda_visible_devices, get_ip, get_open_port,
 | 
				
			||||||
 | 
					                        get_distributed_init_method, make_async)
 | 
				
			||||||
 | 
					import functools
 | 
				
			||||||
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if ray is not None:
 | 
				
			||||||
 | 
					    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
 | 
					    from ray.util.placement_group import PlacementGroup
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					logger = init_logger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# A map between the device type (in device config) to its worker module.
 | 
				
			||||||
 | 
					DEVICE_TO_WORKER_MODULE_MAP = {
 | 
				
			||||||
 | 
					    "cuda": "vllm.worker.worker",
 | 
				
			||||||
 | 
					    "xpu": "vllm.worker.worker",
 | 
				
			||||||
 | 
					    "neuron": "vllm.worker.neuron_worker",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# If the env var is set, it uses the Ray's compiled DAG API
 | 
				
			||||||
 | 
					# which optimizes the control plane overhead.
 | 
				
			||||||
 | 
					# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
 | 
				
			||||||
 | 
					USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class IPEXLLMGPUExecutor(ExecutorBase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        model_config: ModelConfig,
 | 
				
			||||||
 | 
					        cache_config: CacheConfig,
 | 
				
			||||||
 | 
					        parallel_config: ParallelConfig,
 | 
				
			||||||
 | 
					        scheduler_config: SchedulerConfig,
 | 
				
			||||||
 | 
					        device_config: DeviceConfig,
 | 
				
			||||||
 | 
					        lora_config: Optional[LoRAConfig],
 | 
				
			||||||
 | 
					        load_in_low_bit: str,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        self.model_config = model_config
 | 
				
			||||||
 | 
					        self.cache_config = cache_config
 | 
				
			||||||
 | 
					        self.lora_config = lora_config
 | 
				
			||||||
 | 
					        self.parallel_config = parallel_config
 | 
				
			||||||
 | 
					        self.scheduler_config = scheduler_config
 | 
				
			||||||
 | 
					        self.device_config = device_config
 | 
				
			||||||
 | 
					        self.load_in_low_bit = load_in_low_bit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        invalidInputError(self.parallel_config.worker_use_ray,
 | 
				
			||||||
 | 
					                          "worker_use_ray is False, but use ray worker")
 | 
				
			||||||
 | 
					        placement_group = self.parallel_config.placement_group
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Disable Ray usage stats collection.
 | 
				
			||||||
 | 
					        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
 | 
				
			||||||
 | 
					        if ray_usage != "1":
 | 
				
			||||||
 | 
					            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Create the parallel GPU workers.
 | 
				
			||||||
 | 
					        self._init_workers_ray(placement_group)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Profile the memory usage and initialize the cache.
 | 
				
			||||||
 | 
					        self._init_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.forward_dag = None
 | 
				
			||||||
 | 
					        if USE_RAY_COMPILED_DAG:
 | 
				
			||||||
 | 
					            self.forward_dag = self._compiled_ray_dag()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _dispatch_worker(self):
 | 
				
			||||||
 | 
					        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
 | 
				
			||||||
 | 
					            self.device_config.device_type]
 | 
				
			||||||
 | 
					        imported_worker = importlib.import_module(worker_module)
 | 
				
			||||||
 | 
					        Worker = imported_worker.Worker
 | 
				
			||||||
 | 
					        return Worker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _init_workers_ray(self, placement_group: "PlacementGroup",
 | 
				
			||||||
 | 
					                          **ray_remote_kwargs):
 | 
				
			||||||
 | 
					        if self.parallel_config.tensor_parallel_size == 1:
 | 
				
			||||||
 | 
					            # For single GPU case, we use a ray worker with constrained memory.
 | 
				
			||||||
 | 
					            num_gpus = self.cache_config.gpu_memory_utilization
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # Otherwise, the ray workers are allocated with a full GPU.
 | 
				
			||||||
 | 
					            num_gpus = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # The driver dummy worker does not actually use any resources.
 | 
				
			||||||
 | 
					        # It holds the resource for the driver worker.
 | 
				
			||||||
 | 
					        self.driver_dummy_worker: RayWorkerVllm = None
 | 
				
			||||||
 | 
					        # The remaining workers are the actual ray actors.
 | 
				
			||||||
 | 
					        self.workers: List[RayWorkerVllm] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Create the workers.
 | 
				
			||||||
 | 
					        driver_ip = get_ip()
 | 
				
			||||||
 | 
					        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
 | 
				
			||||||
 | 
					            if not bundle.get("GPU", 0):
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            scheduling_strategy = PlacementGroupSchedulingStrategy(
 | 
				
			||||||
 | 
					                placement_group=placement_group,
 | 
				
			||||||
 | 
					                placement_group_capture_child_tasks=True,
 | 
				
			||||||
 | 
					                placement_group_bundle_index=bundle_id,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            worker = ray.remote(
 | 
				
			||||||
 | 
					                num_cpus=0,
 | 
				
			||||||
 | 
					                num_gpus=num_gpus,
 | 
				
			||||||
 | 
					                scheduling_strategy=scheduling_strategy,
 | 
				
			||||||
 | 
					                **ray_remote_kwargs,
 | 
				
			||||||
 | 
					            )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            worker_ip = ray.get(worker.get_node_ip.remote())
 | 
				
			||||||
 | 
					            if worker_ip == driver_ip and self.driver_dummy_worker is None:
 | 
				
			||||||
 | 
					                # If the worker is on the same node as the driver, we use it
 | 
				
			||||||
 | 
					                # as the resource holder for the driver process.
 | 
				
			||||||
 | 
					                self.driver_dummy_worker = worker
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # Else, added to the list of workers.
 | 
				
			||||||
 | 
					                self.workers.append(worker)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.driver_dummy_worker is None:
 | 
				
			||||||
 | 
					            invalidInputError(False,
 | 
				
			||||||
 | 
					                              "Ray does not allocate any GPUs on the driver node. Consider "
 | 
				
			||||||
 | 
					                              "adjusting the Ray placement group or running the driver on a "
 | 
				
			||||||
 | 
					                              "GPU node.")
 | 
				
			||||||
 | 
					        # Get the set of GPU IDs used on each node.
 | 
				
			||||||
 | 
					        driver_node_id, driver_gpu_ids = ray.get(
 | 
				
			||||||
 | 
					            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
 | 
				
			||||||
 | 
					        worker_node_and_gpu_ids = ray.get(
 | 
				
			||||||
 | 
					            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        node_workers = defaultdict(list)
 | 
				
			||||||
 | 
					        node_gpus = defaultdict(list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        node_workers[driver_node_id].append(0)
 | 
				
			||||||
 | 
					        node_gpus[driver_node_id].extend(driver_gpu_ids)
 | 
				
			||||||
 | 
					        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
 | 
				
			||||||
 | 
					                                               start=1):
 | 
				
			||||||
 | 
					            node_workers[node_id].append(i)
 | 
				
			||||||
 | 
					            node_gpus[node_id].extend(gpu_ids)
 | 
				
			||||||
 | 
					        for node_id, gpu_ids in node_gpus.items():
 | 
				
			||||||
 | 
					            node_gpus[node_id] = sorted(gpu_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
 | 
				
			||||||
 | 
					        set_cuda_visible_devices(node_gpus[driver_node_id])
 | 
				
			||||||
 | 
					        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
 | 
				
			||||||
 | 
					            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        distributed_init_method = get_distributed_init_method(
 | 
				
			||||||
 | 
					            driver_ip, get_open_port())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Lazy import the Worker to avoid importing torch.cuda/xformers
 | 
				
			||||||
 | 
					        # before CUDA_VISIBLE_DEVICES is set in the Worker
 | 
				
			||||||
 | 
					        Worker = self._dispatch_worker()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        model_config = copy.deepcopy(self.model_config)
 | 
				
			||||||
 | 
					        parallel_config = copy.deepcopy(self.parallel_config)
 | 
				
			||||||
 | 
					        scheduler_config = copy.deepcopy(self.scheduler_config)
 | 
				
			||||||
 | 
					        device_config = copy.deepcopy(self.device_config)
 | 
				
			||||||
 | 
					        lora_config = copy.deepcopy(self.lora_config)
 | 
				
			||||||
 | 
					        kv_cache_dtype = self.cache_config.cache_dtype
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Initialize the actual workers with the Worker class.
 | 
				
			||||||
 | 
					        for rank, (worker, (node_id, _)) in enumerate(
 | 
				
			||||||
 | 
					                zip(self.workers, worker_node_and_gpu_ids),
 | 
				
			||||||
 | 
					                start=1,
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            local_rank = node_workers[node_id].index(rank)
 | 
				
			||||||
 | 
					            from ipex_llm.vllm.model_convert import _ipex_llm_convert
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            def create_worker_function(rank, local_rank, load_in_low_bit):
 | 
				
			||||||
 | 
					                def worker_function():
 | 
				
			||||||
 | 
					                    _ipex_llm_convert(load_in_low_bit)
 | 
				
			||||||
 | 
					                    return Worker(
 | 
				
			||||||
 | 
					                        model_config,
 | 
				
			||||||
 | 
					                        parallel_config,
 | 
				
			||||||
 | 
					                        scheduler_config,
 | 
				
			||||||
 | 
					                        device_config,
 | 
				
			||||||
 | 
					                        local_rank,
 | 
				
			||||||
 | 
					                        rank,
 | 
				
			||||||
 | 
					                        distributed_init_method,
 | 
				
			||||||
 | 
					                        lora_config=lora_config,
 | 
				
			||||||
 | 
					                        kv_cache_dtype=kv_cache_dtype,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                return worker_function
 | 
				
			||||||
 | 
					            worker.init_worker.remote(create_worker_function(rank,
 | 
				
			||||||
 | 
					                                                             local_rank,
 | 
				
			||||||
 | 
					                                                             self.load_in_low_bit))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Initialize the driver worker with the Worker class.
 | 
				
			||||||
 | 
					        driver_rank = 0
 | 
				
			||||||
 | 
					        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
 | 
				
			||||||
 | 
					        self.driver_worker = Worker(
 | 
				
			||||||
 | 
					            self.model_config,
 | 
				
			||||||
 | 
					            self.parallel_config,
 | 
				
			||||||
 | 
					            self.scheduler_config,
 | 
				
			||||||
 | 
					            self.device_config,
 | 
				
			||||||
 | 
					            driver_local_rank,
 | 
				
			||||||
 | 
					            driver_rank,
 | 
				
			||||||
 | 
					            distributed_init_method,
 | 
				
			||||||
 | 
					            lora_config=self.lora_config,
 | 
				
			||||||
 | 
					            kv_cache_dtype=kv_cache_dtype,
 | 
				
			||||||
 | 
					            is_driver_worker=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # We want to apply patch here before we loading the model
 | 
				
			||||||
 | 
					        # FIXME(woosuk): We are not properly initializing cupy NCCL when
 | 
				
			||||||
 | 
					        # we have multiple nodes.
 | 
				
			||||||
 | 
					        self._run_workers("init_model",
 | 
				
			||||||
 | 
					                          cupy_port=get_open_port()
 | 
				
			||||||
 | 
					                          if not model_config.enforce_eager else None)
 | 
				
			||||||
 | 
					        self._run_workers(
 | 
				
			||||||
 | 
					            "load_model",
 | 
				
			||||||
 | 
					            max_concurrent_workers=self.parallel_config.
 | 
				
			||||||
 | 
					            max_parallel_loading_workers,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _init_cache(self) -> None:
 | 
				
			||||||
 | 
					        """Profiles the memory usage and initializes the KV cache.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        The engine will first conduct a profiling of the existing memory usage.
 | 
				
			||||||
 | 
					        Then, it calculate the maximum possible number of GPU and CPU blocks
 | 
				
			||||||
 | 
					        that can be allocated with the remaining free memory.
 | 
				
			||||||
 | 
					        More details can be found in the
 | 
				
			||||||
 | 
					        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
 | 
				
			||||||
 | 
					        from class :class:`~vllm.worker.Worker`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Afterwards, as there may be multiple workers,
 | 
				
			||||||
 | 
					        we take the minimum number of blocks across all workers
 | 
				
			||||||
 | 
					        to ensure this can be applied to all of them.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Finally, the engine will initialize the KV cache
 | 
				
			||||||
 | 
					        with the calculated number of blocks.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        .. tip::
 | 
				
			||||||
 | 
					            You may limit the usage of GPU memory
 | 
				
			||||||
 | 
					            by adjusting the `gpu_memory_utilization` parameter.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # Get the maximum number of blocks that can be allocated on GPU and CPU.
 | 
				
			||||||
 | 
					        num_blocks = self._run_workers(
 | 
				
			||||||
 | 
					            "profile_num_available_blocks",
 | 
				
			||||||
 | 
					            block_size=self.cache_config.block_size,
 | 
				
			||||||
 | 
					            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
 | 
				
			||||||
 | 
					            cpu_swap_space=self.cache_config.swap_space_bytes,
 | 
				
			||||||
 | 
					            cache_dtype=self.cache_config.cache_dtype,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Since we use a shared centralized controller, we take the minimum
 | 
				
			||||||
 | 
					        # number of blocks across all workers to make sure all the memory
 | 
				
			||||||
 | 
					        # operators can be applied to all workers.
 | 
				
			||||||
 | 
					        num_gpu_blocks = min(b[0] for b in num_blocks)
 | 
				
			||||||
 | 
					        num_cpu_blocks = min(b[1] for b in num_blocks)
 | 
				
			||||||
 | 
					        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
 | 
				
			||||||
 | 
					                    f"# CPU blocks: {num_cpu_blocks}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
 | 
				
			||||||
 | 
					                               self.model_config.max_model_len)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.cache_config.num_gpu_blocks = num_gpu_blocks
 | 
				
			||||||
 | 
					        self.cache_config.num_cpu_blocks = num_cpu_blocks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Initialize the cache.
 | 
				
			||||||
 | 
					        self._run_workers("init_cache_engine", cache_config=self.cache_config)
 | 
				
			||||||
 | 
					        # Warm up the model. This includes capturing the model into CUDA graph
 | 
				
			||||||
 | 
					        # if enforce_eager is False.
 | 
				
			||||||
 | 
					        self._run_workers("warm_up_model")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def execute_model(self,
 | 
				
			||||||
 | 
					                      seq_group_metadata_list: List[SequenceGroupMetadata],
 | 
				
			||||||
 | 
					                      blocks_to_swap_in: Dict[int, int],
 | 
				
			||||||
 | 
					                      blocks_to_swap_out: Dict[int, int],
 | 
				
			||||||
 | 
					                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
 | 
				
			||||||
 | 
					        all_outputs = self._run_workers(
 | 
				
			||||||
 | 
					            "execute_model",
 | 
				
			||||||
 | 
					            driver_kwargs={
 | 
				
			||||||
 | 
					                "seq_group_metadata_list": seq_group_metadata_list,
 | 
				
			||||||
 | 
					                "blocks_to_swap_in": blocks_to_swap_in,
 | 
				
			||||||
 | 
					                "blocks_to_swap_out": blocks_to_swap_out,
 | 
				
			||||||
 | 
					                "blocks_to_copy": blocks_to_copy,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Only the driver worker returns the sampling results.
 | 
				
			||||||
 | 
					        output = all_outputs[0]
 | 
				
			||||||
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_lora(self, lora_request: LoRARequest) -> bool:
 | 
				
			||||||
 | 
					        invalidInputError(lora_request.lora_int_id > 0,
 | 
				
			||||||
 | 
					                          "lora_id must be greater than 0.")
 | 
				
			||||||
 | 
					        return self._run_workers(
 | 
				
			||||||
 | 
					            "add_lora",
 | 
				
			||||||
 | 
					            lora_request=lora_request,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def remove_lora(self, lora_id: int) -> bool:
 | 
				
			||||||
 | 
					        invalidInputError(lora_id > 0, "lora_id must be greater than 0.")
 | 
				
			||||||
 | 
					        return self._run_workers(
 | 
				
			||||||
 | 
					            "remove_lora",
 | 
				
			||||||
 | 
					            lora_id=lora_id,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def list_loras(self) -> List[int]:
 | 
				
			||||||
 | 
					        return self._run_workers("list_loras")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _run_workers(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        method: str,
 | 
				
			||||||
 | 
					        *args,
 | 
				
			||||||
 | 
					        driver_args: Optional[List[Any]]=None,
 | 
				
			||||||
 | 
					        driver_kwargs: Optional[Dict[str, Any]]=None,
 | 
				
			||||||
 | 
					        max_concurrent_workers: Optional[int] = None,
 | 
				
			||||||
 | 
					        use_ray_compiled_dag: bool = False,
 | 
				
			||||||
 | 
					        **kwargs,
 | 
				
			||||||
 | 
					    ) -> Any:
 | 
				
			||||||
 | 
					        """Runs the given method on all workers."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if max_concurrent_workers:
 | 
				
			||||||
 | 
					            invalidInputError(False,
 | 
				
			||||||
 | 
					                              "max_concurrent_workers is not supported yet.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if use_ray_compiled_dag:
 | 
				
			||||||
 | 
					            # Right now, compiled DAG can only accept a single
 | 
				
			||||||
 | 
					            # input. TODO(sang): Fix it.
 | 
				
			||||||
 | 
					            output_channels = self.forward_dag.execute(1)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # Start the ray workers first.
 | 
				
			||||||
 | 
					            ray_worker_outputs = [
 | 
				
			||||||
 | 
					                worker.execute_method.remote(method, *args, **kwargs)
 | 
				
			||||||
 | 
					                for worker in self.workers
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if driver_args is None:
 | 
				
			||||||
 | 
					            driver_args = args
 | 
				
			||||||
 | 
					        if driver_kwargs is None:
 | 
				
			||||||
 | 
					            driver_kwargs = kwargs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Start the driver worker after all the ray workers.
 | 
				
			||||||
 | 
					        driver_worker_output = getattr(self.driver_worker,
 | 
				
			||||||
 | 
					                                       method)(*driver_args, **driver_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Get the results of the ray workers.
 | 
				
			||||||
 | 
					        if self.workers:
 | 
				
			||||||
 | 
					            if use_ray_compiled_dag:
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    ray_worker_outputs = [
 | 
				
			||||||
 | 
					                        pickle.loads(chan.begin_read())
 | 
				
			||||||
 | 
					                        for chan in output_channels
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
 | 
					                finally:
 | 
				
			||||||
 | 
					                    # Has to call end_read in order to reuse the DAG.
 | 
				
			||||||
 | 
					                    for chan in output_channels:
 | 
				
			||||||
 | 
					                        chan.end_read()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                ray_worker_outputs = ray.get(ray_worker_outputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return [driver_worker_output] + ray_worker_outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _compiled_ray_dag(self):
 | 
				
			||||||
 | 
					        import pkg_resources
 | 
				
			||||||
 | 
					        required_version = "2.9"
 | 
				
			||||||
 | 
					        current_version = pkg_resources.get_distribution("ray").version
 | 
				
			||||||
 | 
					        if current_version < required_version:
 | 
				
			||||||
 | 
					            invalidInputError(False,
 | 
				
			||||||
 | 
					                              f"Ray version {required_version} or greater is "
 | 
				
			||||||
 | 
					                              f"required, but found {current_version}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        from ray.dag import MultiOutputNode, InputNode
 | 
				
			||||||
 | 
					        invalidInputError(self.parallel_config.worker_use_ray,
 | 
				
			||||||
 | 
					                          "Use ray worker, but worker_use_ray is False")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Right now, compiled DAG requires at least 1 arg. We send
 | 
				
			||||||
 | 
					        # a dummy value for now. It will be fixed soon.
 | 
				
			||||||
 | 
					        with InputNode() as input_data:
 | 
				
			||||||
 | 
					            forward_dag = MultiOutputNode([
 | 
				
			||||||
 | 
					                worker.execute_model_compiled_dag_remote.bind(input_data)
 | 
				
			||||||
 | 
					                for worker in self.workers
 | 
				
			||||||
 | 
					            ])
 | 
				
			||||||
 | 
					        return forward_dag.experimental_compile()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_health(self) -> None:
 | 
				
			||||||
 | 
					        """Raises an error if engine is unhealthy."""
 | 
				
			||||||
 | 
					        self._check_if_any_actor_is_dead()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _check_if_any_actor_is_dead(self):
 | 
				
			||||||
 | 
					        if not self.workers:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        dead_actors = []
 | 
				
			||||||
 | 
					        for actor in self.workers:
 | 
				
			||||||
 | 
					            actor_state = ray.state.actors(actor._ray_actor_id.hex())
 | 
				
			||||||
 | 
					            if actor_state["State"] == "DEAD":
 | 
				
			||||||
 | 
					                dead_actors.append(actor)
 | 
				
			||||||
 | 
					        if dead_actors:
 | 
				
			||||||
 | 
					            invalidInputError("At least one Worker is dead. "
 | 
				
			||||||
 | 
					                              f"Dead Workers: {dead_actors}. ")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class IPEXLLMGPUExecutorAsync(IPEXLLMGPUExecutor, ExecutorAsyncBase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _run_workers_async(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        method: str,
 | 
				
			||||||
 | 
					        *args,
 | 
				
			||||||
 | 
					        driver_args: Optional[List[Any]]=None,
 | 
				
			||||||
 | 
					        driver_kwargs: Optional[Dict[str, Any]]=None,
 | 
				
			||||||
 | 
					        **kwargs,
 | 
				
			||||||
 | 
					    ) -> Any:
 | 
				
			||||||
 | 
					        """Runs the given method on all workers."""
 | 
				
			||||||
 | 
					        coros = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if driver_args is None:
 | 
				
			||||||
 | 
					            driver_args = args
 | 
				
			||||||
 | 
					        if driver_kwargs is None:
 | 
				
			||||||
 | 
					            driver_kwargs = kwargs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Run the driver worker asynchronously.
 | 
				
			||||||
 | 
					        driver_executor = make_async(getattr(self.driver_worker, method))
 | 
				
			||||||
 | 
					        coros.append(driver_executor(*driver_args, **driver_kwargs))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Run the ray workers asynchronously.
 | 
				
			||||||
 | 
					        for worker in self.workers:
 | 
				
			||||||
 | 
					            coros.append(worker.execute_method.remote(method, *args, **kwargs))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        all_outputs = await asyncio.gather(*coros)
 | 
				
			||||||
 | 
					        return all_outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def execute_model_async(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        seq_group_metadata_list: List[SequenceGroupMetadata],
 | 
				
			||||||
 | 
					        blocks_to_swap_in: Dict[int, int],
 | 
				
			||||||
 | 
					        blocks_to_swap_out: Dict[int, int],
 | 
				
			||||||
 | 
					        blocks_to_copy: Dict[int, List[int]],
 | 
				
			||||||
 | 
					    ) -> SamplerOutput:
 | 
				
			||||||
 | 
					        all_outputs = await self._run_workers_async(
 | 
				
			||||||
 | 
					            "execute_model",
 | 
				
			||||||
 | 
					            driver_kwargs={
 | 
				
			||||||
 | 
					                "seq_group_metadata_list": seq_group_metadata_list,
 | 
				
			||||||
 | 
					                "blocks_to_swap_in": blocks_to_swap_in,
 | 
				
			||||||
 | 
					                "blocks_to_swap_out": blocks_to_swap_out,
 | 
				
			||||||
 | 
					                "blocks_to_copy": blocks_to_copy,
 | 
				
			||||||
 | 
					            })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Only the driver worker returns the sampling results.
 | 
				
			||||||
 | 
					        output = all_outputs[0]
 | 
				
			||||||
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def check_health_async(self) -> None:
 | 
				
			||||||
 | 
					        """Raises an error if engine is unhealthy."""
 | 
				
			||||||
 | 
					        self._check_if_any_actor_is_dead()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_gpu_executor_class(load_in_low_bit):
 | 
				
			||||||
 | 
					    return functools.partial(IPEXLLMGPUExecutor, load_in_low_bit=load_in_low_bit)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_gpu_executor_class_async(load_in_low_bit):
 | 
				
			||||||
 | 
					    return functools.partial(IPEXLLMGPUExecutorAsync, load_in_low_bit=load_in_low_bit)
 | 
				
			||||||
		Loading…
	
		Reference in a new issue