LLM: Add length check for IPEX-CPU speculative decoding (#10529)
Add length check for IPEX-CPU speculative decoding.
This commit is contained in:
		
							parent
							
								
									a3b007f3b1
								
							
						
					
					
						commit
						11550d3f25
					
				
					 1 changed files with 22 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -53,6 +53,28 @@ def generate(
 | 
			
		|||
    **kwargs,
 | 
			
		||||
):
 | 
			
		||||
    if hasattr(self, "draft_model"):
 | 
			
		||||
        from ipex_llm.llm.transformers.convert import get_enable_ipex
 | 
			
		||||
        _enable_ipex = get_enable_ipex()
 | 
			
		||||
        if _enable_ipex and inputs.size(1) < 256:
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                "IPEX_CPU optimized models have issues for speculative decoding with short prompts"
 | 
			
		||||
                "(length < 256). Using normal generate() method instead."
 | 
			
		||||
            )
 | 
			
		||||
            for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
 | 
			
		||||
                        'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
 | 
			
		||||
                        'th_batch_num']:
 | 
			
		||||
                value = kwargs.pop(var, None)
 | 
			
		||||
            del self.draft_model
 | 
			
		||||
            return original_generate(self,
 | 
			
		||||
                                     inputs=inputs,
 | 
			
		||||
                                     generation_config=generation_config,
 | 
			
		||||
                                     logits_processor=logits_processor,
 | 
			
		||||
                                     stopping_criteria=stopping_criteria,
 | 
			
		||||
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
 | 
			
		||||
                                     synced_gpus=synced_gpus,
 | 
			
		||||
                                     assistant_model=assistant_model,
 | 
			
		||||
                                     streamer=streamer,
 | 
			
		||||
                                     **kwargs)
 | 
			
		||||
        # do speculative decoding
 | 
			
		||||
        # TODO: maybe add other way to double check
 | 
			
		||||
        new_speculative_kwargs = {}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue