Fix error while using pipeline parallelism (#11434)

This commit is contained in:
Guancheng Fu 2024-06-26 15:33:47 +08:00 committed by GitHub
parent a45ceac4e4
commit 99cd16ef9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -146,20 +146,6 @@ def is_linear_module(module):
global _VLLM_VERSION
if _VLLM_VERSION is None:
    _VLLM_VERSION = get_package_version('vllm')
if 'xpu' in _VLLM_VERSION:
# For vllm xpu
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size
)
if torch.distributed.is_initialized():
tp_size = get_tensor_model_parallel_world_size()
else:
tp_size = 1
else:
# For vllm cpu
tp_size = 1
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
)
@@ -168,6 +154,19 @@ def is_linear_module(module):
    ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
]
if is_module_in_classes(module, VLLM_LINEAR_LIST):
if 'xpu' in _VLLM_VERSION:
# For vllm xpu
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size
)
if torch.distributed.is_initialized():
tp_size = get_tensor_model_parallel_world_size()
else:
tp_size = 1
else:
# For vllm cpu
tp_size = 1
    in_features = module.input_size
    out_features = module.output_size
    result = True