diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py
index 20540c56..9d8880cc 100644
--- a/python/llm/src/ipex_llm/transformers/speculative.py
+++ b/python/llm/src/ipex_llm/transformers/speculative.py
@@ -27,8 +27,26 @@
 import logging
 import transformers
 from packaging import version
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-from transformers import top_k_top_p_filtering, GenerationConfig, \
+from transformers import GenerationConfig, \
     LogitsProcessorList, StoppingCriteriaList
+from ipex_llm.utils.common import log4Error
+
+# With transformers >= 4.39.0, top_k_top_p_filtering is imported from trl.core;
+# for older releases it is imported from transformers itself.
+trans_version = transformers.__version__
+if version.parse(trans_version) >= version.parse("4.39.0"):
+    try:
+        from trl.core import top_k_top_p_filtering
+    except ImportError:
+        # ImportError (parent of ModuleNotFoundError) also covers a trl install
+        # that is present but does not export top_k_top_p_filtering.
+        log4Error.invalidInputError(
+            False,
+            "For transformers version >= 4.39.0, pip install trl "
+            "(a release that provides trl.core.top_k_top_p_filtering)")
+else:
+    from transformers import top_k_top_p_filtering
+
 from ipex_llm.utils.common import invalidInputError
 from transformers.modeling_outputs import CausalLMOutputWithPast