fix trl import when not running speculative (#13187)
* fix trl import when not running speculative * fix style
This commit is contained in:
parent
c5d919b151
commit
9df610f80d
1 changed file with 9 additions and 8 deletions
|
|
@ -50,14 +50,6 @@ from transformers import GenerationConfig, \
|
||||||
from ipex_llm.utils.common import log4Error
|
from ipex_llm.utils.common import log4Error
|
||||||
|
|
||||||
trans_version = transformers.__version__
|
trans_version = transformers.__version__
|
||||||
if version.parse(trans_version) >= version.parse("4.39.0"):
|
|
||||||
try:
|
|
||||||
from trl.core import top_k_top_p_filtering
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
log4Error.invalidInputError(False,
|
|
||||||
"For transformers version >= 4.39.0, pip install trl==0.11.0")
|
|
||||||
else:
|
|
||||||
from transformers import top_k_top_p_filtering
|
|
||||||
|
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
@ -164,6 +156,15 @@ def deepmind_sample(logits, return_probs: bool=False, top_k: int=50,
|
||||||
def logits_to_probs(logits, top_k: int=50, top_p: float=0.7, temperature: float=0.7):
|
def logits_to_probs(logits, top_k: int=50, top_p: float=0.7, temperature: float=0.7):
|
||||||
invalidInputError(top_k != 1 and top_p != 0.0 and temperature != 0.0,
|
invalidInputError(top_k != 1 and top_p != 0.0 and temperature != 0.0,
|
||||||
"top_k != 1 and top_p != 0.0 and temperature != 0.0 if do_sample=True")
|
"top_k != 1 and top_p != 0.0 and temperature != 0.0 if do_sample=True")
|
||||||
|
if version.parse(trans_version) >= version.parse("4.39.0"):
|
||||||
|
try:
|
||||||
|
from trl.core import top_k_top_p_filtering
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
log4Error.invalidInputError(False,
|
||||||
|
"For transformers version >= 4.39.0, "
|
||||||
|
"pip install trl==0.11.0")
|
||||||
|
else:
|
||||||
|
from transformers import top_k_top_p_filtering
|
||||||
_logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
|
_logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
|
||||||
top_k=top_k, top_p=top_p)
|
top_k=top_k, top_p=top_p)
|
||||||
prob_list = _logits.softmax(-1)
|
prob_list = _logits.softmax(-1)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue