Fix error while using pipeline parallism (#11434)

This commit is contained in:
Guancheng Fu 2024-06-26 15:33:47 +08:00 committed by GitHub
parent a45ceac4e4
commit 99cd16ef9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -146,6 +146,14 @@ def is_linear_module(module):
global _VLLM_VERSION
if _VLLM_VERSION is None:
_VLLM_VERSION = get_package_version('vllm')
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
)
VLLM_LINEAR_LIST = [
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
]
if is_module_in_classes(module, VLLM_LINEAR_LIST):
if 'xpu' in _VLLM_VERSION:
# For vllm xpu
from vllm.model_executor.parallel_utils.parallel_state import (
@ -159,15 +167,6 @@ def is_linear_module(module):
else:
# For vllm cpu
tp_size = 1
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
)
VLLM_LINEAR_LIST = [
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
]
if is_module_in_classes(module, VLLM_LINEAR_LIST):
in_features = module.input_size
out_features = module.output_size
result = True