diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 9fa62956..e73fe09a 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -53,6 +53,7 @@ import subprocess
 import sys
 
 _IS_VLLM_AVAILABLE = None
+_USE_VLLM = False
 
 
 def is_auto_gptq_available():
@@ -76,6 +77,10 @@ def is_vllm_available():
     return _IS_VLLM_AVAILABLE
 
 
+def get_use_vllm():
+    return _USE_VLLM
+
+
 def is_torch_distributed_initialized():
     return torch.distributed.is_initialized()
 
@@ -119,14 +124,15 @@ def is_gptq_linear(module):
 
 
 def is_linear_module(module):
+    global _USE_VLLM
+
     in_features = None
     out_features = None
     mp_group = None
 
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
-
     if is_vllm_available():
-        # TODO: add tensor parallel feature later
+        # Only convert vllm modules
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
@@ -148,16 +154,9 @@ def is_linear_module(module):
                 in_features = module.input_size_per_partition
             elif isinstance(module, ColumnParallelLinear) and tp_size >= 2:
                 out_features = module.output_size_per_partition
-        else:
-            # Also check for Linear module
-            if isinstance(module, nn.Linear) or is_awq:
-                in_features = module.in_features
-                out_features = module.out_features
-                mp_group = None
-                result = True
-            else:
-                result = False
-    elif is_gptq_linear(module):
+            _USE_VLLM = True
+            return result, (in_features, out_features, mp_group)
+    if is_gptq_linear(module):
         in_features = module.infeatures
         out_features = module.outfeatures
         mp_group = None
diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 69ae91f1..a23126a0 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -53,7 +53,7 @@ from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
     get_ipex_version
-from ipex_llm.transformers.convert import is_deepspeed_available, is_vllm_available
+from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm
 
 T = TypeVar("T", bound="torch.nn.Module")
 
@@ -737,12 +737,11 @@ class LowBitLinear(nn.Linear):
                 torch.xpu.empty_cache()
             result = result.view(new_shape)
             if self.mp_group is not None:
-                # FIXME: the user may install both vllm and deepspeed
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             if self.bias is not None:
@@ -822,11 +821,11 @@ class FP16Linear(nn.Linear):
                 self.weight_type = 2
             result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
             if self.mp_group is not None:
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             return result
@@ -859,11 +858,11 @@ class FP16Linear(nn.Linear):
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.mp_group is not None:
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             if self.bias is not None:
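
A minimal standalone sketch (not part of the patch) of the dispatch this change introduces: once is_linear_module has converted a vLLM parallel linear, get_use_vllm() returns True and the converted layer reduces its partial output with torch.distributed.all_reduce; otherwise the existing DeepSpeed inference_all_reduce path is kept. The helper name allreduce_output is hypothetical and only illustrates the branch taken inline in LowBitLinear.forward / FP16Linear.forward.

# Sketch only -- the real logic is inline in LowBitLinear.forward / FP16Linear.forward;
# `allreduce_output` is a hypothetical helper used here for illustration.
import torch

from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm


def allreduce_output(result: torch.Tensor, mp_group):
    # No tensor parallelism: nothing to reduce.
    if mp_group is None:
        return result
    if get_use_vllm():
        # Layers converted from vLLM parallel linears reuse vLLM's
        # torch.distributed process group directly.
        torch.distributed.all_reduce(result, group=mp_group)
    elif is_deepspeed_available():
        # DeepSpeed AutoTP deployments keep using DeepSpeed's inference all-reduce.
        from deepspeed import comm as dist
        dist.inference_all_reduce(result, group=mp_group)
    else:
        raise RuntimeError("mp_group is not None, but no supported backend found")
    return result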