Fix vllm tp (#11297)

This commit is contained in:
Guancheng Fu 2024-06-13 10:47:48 +08:00 committed by GitHub
parent 986af21896
commit 57a023aadc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -37,6 +37,7 @@
import platform
import torch
import torch.distributed
import torch.nn as nn
from accelerate import init_empty_weights
import warnings
@@ -151,7 +152,10 @@ def is_linear_module(module):
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size
)
tp_size = get_tensor_model_parallel_world_size()
if torch.distributed.is_initialized():
tp_size = get_tensor_model_parallel_world_size()
else:
tp_size = 1
else:
# For vllm cpu
tp_size = 1