diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 789fd67d..aa637577 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -153,8 +153,9 @@ def is_linear_module(module): VLLM_LINEAR_LIST = [ ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear, - ParallelLMHead ] + if 'xpu' in _VLLM_VERSION: + VLLM_LINEAR_LIST.append(ParallelLMHead) if is_module_in_classes(module, VLLM_LINEAR_LIST): if 'xpu' in _VLLM_VERSION: # For vllm xpu