Fix vllm condition (#11169)
* add use-vllm
* done
* fix style
* fix done
parent dcbf4d3d0a
commit 50ee004ac7

2 changed files with 21 additions and 23 deletions
@@ -53,6 +53,7 @@ import subprocess
 import sys
 
 _IS_VLLM_AVAILABLE = None
+_USE_VLLM = False
 
 
 def is_auto_gptq_available():
@@ -76,6 +77,10 @@ def is_vllm_available():
     return _IS_VLLM_AVAILABLE
 
 
+def get_use_vllm():
+    return _USE_VLLM
+
+
 def is_torch_distributed_initialized():
     return torch.distributed.is_initialized()
 
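
A note on the accessor added above: a plain `from ipex_llm.transformers.convert import _USE_VLLM` would copy the flag's value at import time and never see later updates, which is presumably why the flag is exposed through `get_use_vllm()` instead. A minimal, self-contained sketch of that Python behaviour (the `state` module below is a stand-in, not ipex_llm code):

import types

# Stand-in for ipex_llm.transformers.convert: a module-level flag plus an accessor.
state = types.ModuleType("state")
exec(
    "_USE_VLLM = False\n"
    "def get_use_vllm():\n"
    "    return _USE_VLLM\n",
    state.__dict__,
)

snapshot = state._USE_VLLM      # what `from state import _USE_VLLM` would bind
state._USE_VLLM = True          # the detection code flips the flag later

print(snapshot)                 # False -- the imported copy is stale
print(state.get_use_vllm())     # True  -- the accessor reads the live value
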
@@ -119,14 +124,15 @@ def is_gptq_linear(module):
 
 def is_linear_module(module):
 
+    global _USE_VLLM
 
     in_features = None
     out_features = None
     mp_group = None
 
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
 
     if is_vllm_available():
         # TODO: add tensor parallel feature later
         # Only convert vllm modules
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
@@ -148,16 +154,9 @@ def is_linear_module(module):
                 in_features = module.input_size_per_partition
             elif isinstance(module, ColumnParallelLinear) and tp_size >= 2:
                 out_features = module.output_size_per_partition
-        else:
-            # Also check for Linear module
-            if isinstance(module, nn.Linear) or is_awq:
-                in_features = module.in_features
-                out_features = module.out_features
-                mp_group = None
-                result = True
-            else:
-                result = False
-    elif is_gptq_linear(module):
+            _USE_VLLM = True
+            return result, (in_features, out_features, mp_group)
+    if is_gptq_linear(module):
         in_features = module.infeatures
         out_features = module.outfeatures
         mp_group = None
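
Reading the two is_linear_module hunks together, the behavioural change appears to be: _USE_VLLM is only set (and the function only returns early) when the module really is one of the vLLM parallel linear classes, and the GPTQ check is no longer an elif of is_vllm_available(), so GPTQ layers are still converted when vLLM merely happens to be installed. A small standalone illustration of that condition fix (toy functions, not ipex_llm code); the remaining hunks then move on to the second of the "2 changed files":

def classify_old(vllm_installed, is_vllm_layer, is_gptq_layer):
    # Old shape: the GPTQ branch is only reachable when vLLM is NOT installed.
    if vllm_installed:
        return "vllm" if is_vllm_layer else "unhandled"
    elif is_gptq_layer:
        return "gptq"
    return "unhandled"


def classify_new(vllm_installed, is_vllm_layer, is_gptq_layer):
    # New shape: return early only for actual vLLM layers, then fall through.
    if vllm_installed and is_vllm_layer:
        return "vllm"
    if is_gptq_layer:
        return "gptq"
    return "unhandled"


# A GPTQ layer on a machine that also has vLLM installed:
print(classify_old(True, False, True))   # unhandled -- the old condition hides it
print(classify_new(True, False, True))   # gptq      -- handled after the fix
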
@@ -53,7 +53,7 @@ from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
     get_ipex_version
-from ipex_llm.transformers.convert import is_deepspeed_available, is_vllm_available
+from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm
 
 T = TypeVar("T", bound="torch.nn.Module")
 
@@ -737,12 +737,11 @@ class LowBitLinear(nn.Linear):
                 torch.xpu.empty_cache()
             result = result.view(new_shape)
             if self.mp_group is not None:
-                # FIXME: the user may install both vllm and deepspeed
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             if self.bias is not None:
@@ -822,11 +821,11 @@ class FP16Linear(nn.Linear):
                 self.weight_type = 2
             result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
             if self.mp_group is not None:
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             return result
@@ -859,11 +858,11 @@ class FP16Linear(nn.Linear):
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.mp_group is not None:
-                if is_deepspeed_available():
+                if get_use_vllm():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
-                elif is_vllm_available():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
             if self.bias is not None:
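
The three LowBitLinear / FP16Linear hunks make the same edit in three forward paths, so the resulting dispatch can be restated once. A sketch under stated assumptions: the helper below is not part of the patch, its boolean arguments stand in for get_use_vllm() and is_deepspeed_available(), it raises a plain RuntimeError instead of invalidInputError, and it assumes an initialised process group whenever mp_group is not None.

import torch


def all_reduce_partial(result, mp_group, use_vllm, deepspeed_available):
    """Reduce a tensor-parallel partial result with whichever backend applies.

    Checking the vLLM flag first addresses the removed FIXME: when both vllm
    and deepspeed are installed, the backend that actually produced the
    parallel layers (tracked by the _USE_VLLM flag) takes precedence.
    """
    if mp_group is None:
        return result                          # nothing to reduce on a single rank
    if use_vllm:
        torch.distributed.all_reduce(result, group=mp_group)
    elif deepspeed_available:
        from deepspeed import comm as dist
        dist.inference_all_reduce(result, group=mp_group)
    else:
        raise RuntimeError("mp_group is not None, but no supported backend found")
    return result
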