fix trl import when not running speculative (#13187)
* fix trl import when not running speculative * fix style
This commit is contained in:
		
							parent
							
								
									c5d919b151
								
							
						
					
					
						commit
						9df610f80d
					
				
					 1 changed files with 9 additions and 8 deletions
				
			
		| 
						 | 
					@ -50,14 +50,6 @@ from transformers import GenerationConfig, \
 | 
				
			||||||
from ipex_llm.utils.common import log4Error
 | 
					from ipex_llm.utils.common import log4Error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
trans_version = transformers.__version__
 | 
					trans_version = transformers.__version__
 | 
				
			||||||
if version.parse(trans_version) >= version.parse("4.39.0"):
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        from trl.core import top_k_top_p_filtering
 | 
					 | 
				
			||||||
    except ModuleNotFoundError:
 | 
					 | 
				
			||||||
        log4Error.invalidInputError(False,
 | 
					 | 
				
			||||||
                                    "For transformers version >= 4.39.0, pip install trl==0.11.0")
 | 
					 | 
				
			||||||
else:
 | 
					 | 
				
			||||||
    from transformers import top_k_top_p_filtering
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
 | 
					from transformers.modeling_outputs import CausalLMOutputWithPast
 | 
				
			||||||
| 
						 | 
					@ -164,6 +156,15 @@ def deepmind_sample(logits, return_probs: bool=False, top_k: int=50,
 | 
				
			||||||
def logits_to_probs(logits, top_k: int=50, top_p: float=0.7, temperature: float=0.7):
 | 
					def logits_to_probs(logits, top_k: int=50, top_p: float=0.7, temperature: float=0.7):
 | 
				
			||||||
    invalidInputError(top_k != 1 and top_p != 0.0 and temperature != 0.0,
 | 
					    invalidInputError(top_k != 1 and top_p != 0.0 and temperature != 0.0,
 | 
				
			||||||
                      "top_k != 1 and top_p != 0.0 and temperature != 0.0 if do_sample=True")
 | 
					                      "top_k != 1 and top_p != 0.0 and temperature != 0.0 if do_sample=True")
 | 
				
			||||||
 | 
					    if version.parse(trans_version) >= version.parse("4.39.0"):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            from trl.core import top_k_top_p_filtering
 | 
				
			||||||
 | 
					        except ModuleNotFoundError:
 | 
				
			||||||
 | 
					            log4Error.invalidInputError(False,
 | 
				
			||||||
 | 
					                                        "For transformers version >= 4.39.0, "
 | 
				
			||||||
 | 
					                                        "pip install trl==0.11.0")
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        from transformers import top_k_top_p_filtering
 | 
				
			||||||
    _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
 | 
					    _logits = top_k_top_p_filtering(logits.view(-1, logits.size(-1)) / temperature,
 | 
				
			||||||
                                    top_k=top_k, top_p=top_p)
 | 
					                                    top_k=top_k, top_p=top_p)
 | 
				
			||||||
    prob_list = _logits.softmax(-1)
 | 
					    prob_list = _logits.softmax(-1)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue