Updated importing of top_k_top_p_filtering for transformers>=4.39.0 (#10794)

* In transformers>=4.39.0, the top_k_top_p_filtering function has been deprecated and moved to the hugging face package trl. Thus, for versions >= 4.39.0, import this function from trl.
This commit is contained in:
Ovo233 2024-04-19 15:34:39 +08:00 committed by GitHub
parent 07e8b045a9
commit 1a885020ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -27,8 +27,19 @@ import logging
import transformers import transformers
from packaging import version from packaging import version
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import top_k_top_p_filtering, GenerationConfig, \ from transformers import GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList LogitsProcessorList, StoppingCriteriaList
from ipex_llm.utils.common import log4Error
trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.39.0"):
try:
from trl.core import top_k_top_p_filtering
except ModuleNotFoundError:
log4Error.invalidInputError(False, "For transformers version >= 4.39.0, pip install trl")
else:
from transformers import top_k_top_p_filtering
from ipex_llm.utils.common import invalidInputError from ipex_llm.utils.common import invalidInputError
from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_outputs import CausalLMOutputWithPast