diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 55565a49..e631b674 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -37,6 +37,7 @@ import platform import torch +import torch.distributed import torch.nn as nn from accelerate import init_empty_weights import warnings @@ -151,7 +152,10 @@ def is_linear_module(module): get_tensor_model_parallel_group, get_tensor_model_parallel_world_size ) - tp_size = get_tensor_model_parallel_world_size() + if torch.distributed.is_initialized(): + tp_size = get_tensor_model_parallel_world_size() + else: + tp_size = 1 else: # For vllm cpu tp_size = 1