Fix vllm condition (#11169)

* add use-vllm

* done

* fix style

* fix done
Guancheng Fu 2024-05-30 15:23:17 +08:00 committed by GitHub
parent dcbf4d3d0a
commit 50ee004ac7
2 changed files with 21 additions and 23 deletions


@@ -53,6 +53,7 @@ import subprocess
 import sys

 _IS_VLLM_AVAILABLE = None
+_USE_VLLM = False


 def is_auto_gptq_available():
@@ -76,6 +77,10 @@ def is_vllm_available():
     return _IS_VLLM_AVAILABLE


+def get_use_vllm():
+    return _USE_VLLM
+
+
 def is_torch_distributed_initialized():
     return torch.distributed.is_initialized()
@@ -119,14 +124,15 @@ def is_gptq_linear(module):
 def is_linear_module(module):
+    global _USE_VLLM
     in_features = None
     out_features = None
     mp_group = None
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # TODO: add tensor parallel feature later
         # Only convert vllm modules
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
@@ -148,16 +154,9 @@ def is_linear_module(module):
                 in_features = module.input_size_per_partition
             elif isinstance(module, ColumnParallelLinear) and tp_size >= 2:
                 out_features = module.output_size_per_partition
-        else:
-            # Also check for Linear module
-            if isinstance(module, nn.Linear) or is_awq:
-                in_features = module.in_features
-                out_features = module.out_features
-                mp_group = None
-                result = True
-            else:
-                result = False
-    elif is_gptq_linear(module):
+            _USE_VLLM = True
+            return result, (in_features, out_features, mp_group)
+    if is_gptq_linear(module):
         in_features = module.infeatures
         out_features = module.outfeatures
         mp_group = None
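
For context on the convert-side change: the module-level _USE_VLLM flag starts as False and is flipped to True only when a vLLM parallel linear layer is actually converted, so get_use_vllm() answers "are vLLM layers really in use?" rather than "is vLLM merely installed?". A minimal standalone sketch of that pattern (illustrative only; mark_vllm_layer_converted is a hypothetical stand-in for the assignment inside is_linear_module, not a function in the library):

# Sketch of the flag pattern introduced in convert.py (simplified, not the real module).
_USE_VLLM = False  # becomes True only once a vLLM parallel linear layer is converted


def get_use_vllm():
    # Cheap runtime query used later by the linear layers' forward pass.
    return _USE_VLLM


def mark_vllm_layer_converted():
    # Hypothetical helper mirroring the `_USE_VLLM = True` assignment that
    # is_linear_module() performs when it meets a vLLM ColumnParallelLinear /
    # RowParallelLinear / QKVParallelLinear / MergedColumnParallelLinear.
    global _USE_VLLM
    _USE_VLLM = True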


@@ -53,7 +53,7 @@ from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
     get_ipex_version
-from ipex_llm.transformers.convert import is_deepspeed_available, is_vllm_available
+from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm

 T = TypeVar("T", bound="torch.nn.Module")
@@ -737,12 +737,11 @@ class LowBitLinear(nn.Linear):
                     torch.xpu.empty_cache()
             result = result.view(new_shape)
             if self.mp_group is not None:
-                # FIXME: the user may install both vllm and deepspeed
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             if self.bias is not None:
@@ -822,11 +821,11 @@ class FP16Linear(nn.Linear):
                 self.weight_type = 2
             result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
             if self.mp_group is not None:
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             return result
@@ -859,11 +858,11 @@ class FP16Linear(nn.Linear):
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.mp_group is not None:
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             if self.bias is not None:
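
Taken together, the runtime change means the all-reduce dispatch now follows whether vLLM layers are actually in use rather than which packages happen to be importable, resolving the removed FIXME about vLLM and DeepSpeed being installed side by side. A minimal sketch of that ordering, assuming the convert-module imports shown in the diff above (run_all_reduce itself is an illustrative wrapper, not a function in the library, and a plain exception stands in for invalidInputError):

import torch
from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm


def run_all_reduce(result: torch.Tensor, mp_group) -> torch.Tensor:
    # vLLM path first: get_use_vllm() is True only when vLLM parallel layers were
    # actually converted, so a plain torch.distributed all-reduce over the
    # tensor-parallel group recorded at conversion time is the right call here.
    if get_use_vllm():
        torch.distributed.all_reduce(result, group=mp_group)
    # DeepSpeed fallback: with the old availability-based check, having DeepSpeed
    # installed alongside vLLM sent vLLM runs down this branch by mistake.
    elif is_deepspeed_available():
        from deepspeed import comm as dist
        dist.inference_all_reduce(result, group=mp_group)
    else:
        # The library raises via invalidInputError(False, ...); a plain exception
        # stands in for it in this sketch.
        raise RuntimeError("mp_group is not None, but no supported backend found")
    return result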