From 57a023aadc15b7d2aa5e48ad9a84c98152ca8acc Mon Sep 17 00:00:00 2001
From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>
Date: Thu, 13 Jun 2024 10:47:48 +0800
Subject: [PATCH] Fix vllm tp (#11297)

---
 python/llm/src/ipex_llm/transformers/convert.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 55565a49..e631b674 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -37,6 +37,7 @@
 import platform
 
 import torch
+import torch.distributed
 import torch.nn as nn
 from accelerate import init_empty_weights
 import warnings
@@ -151,7 +152,10 @@ def is_linear_module(module):
                 get_tensor_model_parallel_group,
                 get_tensor_model_parallel_world_size
             )
-            tp_size = get_tensor_model_parallel_world_size()
+            if torch.distributed.is_initialized():
+                tp_size = get_tensor_model_parallel_world_size()
+            else:
+                tp_size = 1
         else:
             # For vllm cpu
             tp_size = 1