Updated importing of top_k_top_p_filtering for transformers>=4.39.0 (#10794)
* In transformers>=4.39.0, the top_k_top_p_filtering function has been deprecated and moved to the hugging face package trl. Thus, for versions >= 4.39.0, import this function from trl.
This commit is contained in:
parent
07e8b045a9
commit
1a885020ee
1 changed files with 12 additions and 1 deletions
|
|
@ -27,8 +27,19 @@ import logging
|
||||||
import transformers
|
import transformers
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
from transformers import top_k_top_p_filtering, GenerationConfig, \
|
from transformers import GenerationConfig, \
|
||||||
LogitsProcessorList, StoppingCriteriaList
|
LogitsProcessorList, StoppingCriteriaList
|
||||||
|
from ipex_llm.utils.common import log4Error
|
||||||
|
|
||||||
|
trans_version = transformers.__version__
|
||||||
|
if version.parse(trans_version) >= version.parse("4.39.0"):
|
||||||
|
try:
|
||||||
|
from trl.core import top_k_top_p_filtering
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
log4Error.invalidInputError(False, "For transformers version >= 4.39.0, pip install trl")
|
||||||
|
else:
|
||||||
|
from transformers import top_k_top_p_filtering
|
||||||
|
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue