Support streaming for lookup generation (#11922)
* Support streaming for lookup generation
* Small update
* Style fixes
* Add original-generate fallback for batch inference and beam search; support input length threshold judgement for direct input with input_ids
* Fix lookup stream generate with eos token
* Small fixes
* Small fix
* Index fix
* Small fix
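For context, a minimal usage sketch of the path this commit touches, assuming an ipex-llm INT4 model on an XPU device and the stock Hugging Face TextStreamer (model path, prompt, and token counts are placeholders, not part of the commit):

import torch
from transformers import AutoTokenizer, TextStreamer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                             trust_remote_code=True).to("xpu")

streamer = TextStreamer(tokenizer, skip_prompt=True)
input_ids = tokenizer("What is Intel Arc?", return_tensors="pt").input_ids.to("xpu")

with torch.inference_mode():
    # lookahead=2 routes generation through lookup_generate(); with this
    # commit the streamer receives tokens as they are accepted instead of
    # only the final sequence.
    output = model.generate(input_ids,
                            max_new_tokens=64,
                            lookahead=2,
                            streamer=streamer)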
parent a0bbd8e28d
commit c1d07bc626

1 changed file with 39 additions and 7 deletions
@@ -59,15 +59,24 @@ def generate(
 ):
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
-    if perf_mode == "1" and lookahead is None:
-        if inputs is not None:
-            if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                lookahead = 2  # default to 2 now
+
+    input_ids_shape = None
+    if inputs is not None:
+        input_ids_shape = inputs.shape
+    else:
+        input_ids = kwargs.get("input_ids", None)
+        if input_ids is not None:
+            input_ids_shape = input_ids.shape
+        else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
-                if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                    lookahead = 2  # default to 2 now
+                input_ids_shape = inputs_embeds.shape
+
+    if perf_mode == "1" and lookahead is None:
+        if input_ids_shape is not None and \
+                input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+            lookahead = 2  # default to 2 now
 
     if lookahead:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()
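With the shape resolved into input_ids_shape, the IPEX_LLM_PERFORMANCE_MODE threshold check now also covers prompts passed via the input_ids or inputs_embeds keywords, not only the positional inputs tensor. A hedged sketch, reusing model and input_ids from the earlier example:

import os

# With performance mode on and no explicit lookahead, prompts at or above
# PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD tokens default to lookahead=2,
# now also when they arrive via the input_ids keyword.
os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"
output = model.generate(input_ids=input_ids, max_new_tokens=32)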
@@ -75,7 +84,15 @@ def generate(
         if self.device.type == "cpu" and _enable_ipex:
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
-            kwargs.pop("max_matching_ngram_size")
+            kwargs.pop("max_matching_ngram_size", None)
+        elif input_ids_shape is not None and input_ids_shape[0] > 1:
+            logger.warning("Prompt lookup is currently not supported with batch inference, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
+        elif kwargs.get("num_beams", None) not in [None, 1]:
+            logger.warning("Prompt lookup is currently not supported with num_beams != 1, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
         else:
             # Do prompt lookup generation
             # If lookahead is provided, we will use lookup_generate instead of
@@ -94,6 +111,7 @@ def generate(
             return self.lookup_generate(inputs=inputs,
                                         num_output_tokens=lookahead,
                                         generation_config=generation_config,
+                                        streamer=streamer,
                                         logits_processor=logits_processor,
                                         stopping_criteria=stopping_criteria,
                                         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
@@ -254,12 +272,19 @@ def lookup_generate(self,
                     num_output_tokens: int = 10,
                     max_matching_ngram_size: int = None,
                     generation_config: Optional[GenerationConfig] = None,
+                    streamer: Optional["BaseStreamer"] = None,
                     attention_mask=None,
                     **sampling_kwargs):
     input_ids, generation_config, logits_processor, stopping_criteria, \
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                               **sampling_kwargs)
+
+    invalidInputError(input_ids.shape[0] == 1,
+                      "Prompt lookup is currently not supported with batch inference.")
+
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
 
     device_name = get_xpu_device_type(input_ids)
 
     candidates_generator = PromptLookupCandidateGenerator(
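lookup_generate now accepts the streamer and feeds it the prompt ids up front, following the transformers streamer contract: put() receives token chunks, end() signals completion. A minimal sketch of a custom streamer that would fit here, assuming the BaseStreamer interface from transformers:

from transformers.generation.streamers import BaseStreamer

class PrintIdsStreamer(BaseStreamer):
    """Toy streamer: prints each chunk of token ids as it arrives."""

    def put(self, value):
        # Receives the prompt ids first, then each newly accepted chunk,
        # always as CPU tensors (see the .cpu() calls in this diff).
        print("chunk:", value.tolist())

    def end(self):
        print("<generation finished>")

# model.generate(input_ids, lookahead=2, streamer=PrintIdsStreamer(), max_new_tokens=64)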
@@ -406,12 +431,19 @@ def lookup_generate(self,
                     first_eos_idx = out_idx
                     break
             if first_eos_idx > -1:
+                if streamer is not None:
+                    streamer.put(output_ids[:(first_eos_idx + 1)].cpu())
                 step -= (len(output_ids_list) - first_eos_idx - 1)
                 break
+        if streamer is not None:
+            streamer.put(output_ids.cpu())
 
     step = min(step, max_new_tokens)
     e2e_toc = time.time()
     self.n_token_generated = step
     self.e2e_time_without_first = e2e_toc - e2e_tic
+
+    if streamer is not None:
+        streamer.end()
 
     return input_ids[:, : input_len + step]
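When an EOS token shows up inside an accepted chunk, only the tokens up to and including the first EOS are pushed to the streamer and counted in step; the discarded tail is subtracted before the loop exits, and streamer.end() then closes the stream. A small worked example with made-up token ids:

output_ids_list = [1012, 2, 887, 415]   # hypothetical accepted chunk; 2 is the eos id
first_eos_idx = 1                       # first EOS found at index 1
streamed = output_ids_list[:first_eos_idx + 1]             # [1012, 2] goes to the streamer
step_reduction = len(output_ids_list) - first_eos_idx - 1  # 2 trailing tokens dropped from step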