Fix vllm tp (#11297)
parent 986af21896
commit 57a023aadc

1 changed file with 5 additions and 1 deletion
@@ -37,6 +37,7 @@
 import platform
 import torch
+import torch.distributed
 import torch.nn as nn
 from accelerate import init_empty_weights
 import warnings

@@ -151,7 +152,10 @@ def is_linear_module(module):
             get_tensor_model_parallel_group,
             get_tensor_model_parallel_world_size
         )
-        tp_size = get_tensor_model_parallel_world_size()
+        if torch.distributed.is_initialized():
+            tp_size = get_tensor_model_parallel_world_size()
+        else:
+            tp_size = 1
     else:
         # For vllm cpu
         tp_size = 1
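Why the guard matters: get_tensor_model_parallel_world_size() reads vLLM's tensor-parallel state, which only exists once a torch.distributed process group has been initialized, so the old unconditional call broke single-process (tp_size = 1) runs. Below is a minimal sketch of the pattern; compute_tp_size is a hypothetical wrapper for illustration, not part of the patched file, and the vllm.distributed import path is assumed from the context lines in the hunk above.

import torch
import torch.distributed


def compute_tp_size() -> int:
    # Hypothetical helper mirroring the guarded lookup in this diff.
    # vLLM's world-size query assumes a process group has been set up;
    # in a plain single-process run it has not, so fall back to 1
    # (tensor parallelism effectively off).
    if torch.distributed.is_initialized():
        # Import path assumed; the diff shows only the imported names.
        from vllm.distributed import get_tensor_model_parallel_world_size
        return get_tensor_model_parallel_world_size()
    return 1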