[Fix] LLM: Fix condition check error for speculative decoding on CPU (#10402)
Fix condition check error for speculative decoding on CPU
This commit is contained in:
		
							parent
							
								
									f158b49835
								
							
						
					
					
						commit
						e10de2c42d
					
				
					 1 changed files with 13 additions and 10 deletions
				
			
		| 
						 | 
				
			
			@ -609,17 +609,20 @@ def speculative_generate(self,
 | 
			
		|||
            draft_current_input_ids = current_input_ids
 | 
			
		||||
            # Target model KV cache to draft model
 | 
			
		||||
 | 
			
		||||
            if self.device.type == 'cpu' and (not _enable_ipex):
 | 
			
		||||
            if self.device.type == 'cpu':
 | 
			
		||||
                # init past_key_values_storage and assign initial fp32 value
 | 
			
		||||
                if step == 1:
 | 
			
		||||
                    past_key_values_storage = \
 | 
			
		||||
                        _prepare_past_key_values_storage_cpu(self, past_key_values,
 | 
			
		||||
                                                             max_new_tokens, _enable_ipex)
 | 
			
		||||
                # each iter cut off cur_len kv_cache from past_key_values1
 | 
			
		||||
                draft_past_key_values = \
 | 
			
		||||
                    _prepare_draft_past_key_values_cpu(self, past_key_values,
 | 
			
		||||
                                                       past_key_values_storage,  _enable_ipex)
 | 
			
		||||
                original_draft_past_key_values = draft_past_key_values
 | 
			
		||||
                if _enable_ipex:
 | 
			
		||||
                    draft_past_key_values = past_key_values
 | 
			
		||||
                else:
 | 
			
		||||
                    if step == 1:
 | 
			
		||||
                        past_key_values_storage = \
 | 
			
		||||
                            _prepare_past_key_values_storage_cpu(self, past_key_values,
 | 
			
		||||
                                                                 max_new_tokens, _enable_ipex)
 | 
			
		||||
                    # each iter cut off cur_len kv_cache from past_key_values1
 | 
			
		||||
                    draft_past_key_values = \
 | 
			
		||||
                        _prepare_draft_past_key_values_cpu(self, past_key_values,
 | 
			
		||||
                                                           past_key_values_storage,  _enable_ipex)
 | 
			
		||||
                    original_draft_past_key_values = draft_past_key_values
 | 
			
		||||
            else:
 | 
			
		||||
                past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
 | 
			
		||||
                                                                        max_step_draft,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue