Support streaming for lookup generation (#11922)
* Support streaming for lookup generation
* Small update
* Style fixes
* Add original generate fallback for batch inference and beam search; support input length threshold judgement for direct input with input_ids
* Fix lookup stream generate with eos token
* Small fixes
* Small fix
* Index fix
* Small fix
parent a0bbd8e28d
commit c1d07bc626

1 changed file with 39 additions and 7 deletions
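For context, a hedged usage sketch of what this commit enables: passing a streamer through the prompt-lookup path. The model path and prompt are placeholders, TextStreamer is the stock streamer shipped with transformers, and lookahead is the kwarg that generate() pops below to enable the lookup path; treat this as an illustration, not documented API.

# Illustrative only: model path and prompt are placeholders.
from transformers import AutoTokenizer, TextStreamer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)

input_ids = tokenizer("Tell me a story.", return_tensors="pt").input_ids

# skip_prompt=True matters here: lookup_generate now put()s the prompt ids
# before the first generated tokens, and we only want new text printed.
streamer = TextStreamer(tokenizer, skip_prompt=True)
model.generate(input_ids,
               lookahead=2,        # enables the prompt lookup path
               max_new_tokens=128,
               streamer=streamer)  # forwarded to lookup_generate by this commit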
@@ -59,15 +59,24 @@ def generate(
 ):
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
-    if perf_mode == "1" and lookahead is None:
-        if inputs is not None:
-            if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                lookahead = 2  # default to 2 now
+
+    input_ids_shape = None
+    if inputs is not None:
+        input_ids_shape = inputs.shape
+    else:
+        input_ids = kwargs.get("input_ids", None)
+        if input_ids is not None:
+            input_ids_shape = input_ids.shape
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
-                if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                    lookahead = 2  # default to 2 now
+                input_ids_shape = inputs_embeds.shape
+
+    if perf_mode == "1" and lookahead is None:
+        if input_ids_shape is not None and \
+                input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+            lookahead = 2  # default to 2 now
+
     if lookahead:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()
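A note on the hunk above: it hoists input-shape detection out of the performance-mode branch, so the threshold check sees a shape whether the caller passed inputs, input_ids, or inputs_embeds. A small sketch of driving that mode, reusing the model and tokenizer from the sketch above; the actual threshold value lives in PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD in this file.

import os

# With performance mode on and no explicit lookahead kwarg, generate()
# now enables lookahead = 2 automatically once the prompt length reaches
# PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD, for any of the three input forms.
os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"

long_prompt = "some sufficiently long prompt " * 64  # placeholder
inputs = tokenizer(long_prompt, return_tensors="pt").input_ids
output = model.generate(inputs, max_new_tokens=32)  # lookahead picked automatically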
@@ -75,7 +84,15 @@ def generate(
         if self.device.type == "cpu" and _enable_ipex:
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
-            kwargs.pop("max_matching_ngram_size")
+            kwargs.pop("max_matching_ngram_size", None)
+        elif input_ids_shape is not None and input_ids_shape[0] > 1:
+            logger.warning("Prompt lookup is currently not supported with batch inference, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
+        elif kwargs.get("num_beams", None) not in [None, 1]:
+            logger.warning("Prompt lookup is currently not supported with num_beams != 1, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
         else:
             # Do prompt lookup generation
             # If lookahead is provided, we will use lookup_generate instead of
@@ -94,6 +111,7 @@ def generate(
             return self.lookup_generate(inputs=inputs,
                                         num_output_tokens=lookahead,
                                         generation_config=generation_config,
+                                        streamer=streamer,
                                         logits_processor=logits_processor,
                                         stopping_criteria=stopping_criteria,
                                         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
@@ -254,12 +272,19 @@ def lookup_generate(self,
                     num_output_tokens: int = 10,
                     max_matching_ngram_size: int = None,
                     generation_config: Optional[GenerationConfig] = None,
+                    streamer: Optional["BaseStreamer"] = None,
                     attention_mask=None,
                     **sampling_kwargs):
     input_ids, generation_config, logits_processor, stopping_criteria, \
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                               **sampling_kwargs)
 
+    invalidInputError(input_ids.shape[0] == 1,
+                      "Prompt lookup is currently not supported with batch inference.")
+
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
+
     device_name = get_xpu_device_type(input_ids)
 
     candidates_generator = PromptLookupCandidateGenerator(
@@ -406,12 +431,19 @@ def lookup_generate(self,
                     first_eos_idx = out_idx
                     break
             if first_eos_idx > -1:
+                if streamer is not None:
+                    streamer.put(output_ids[:(first_eos_idx + 1)].cpu())
                 step -= (len(output_ids_list) - first_eos_idx - 1)
                 break
+            if streamer is not None:
+                streamer.put(output_ids.cpu())
 
     step = min(step, max_new_tokens)
     e2e_toc = time.time()
     self.n_token_generated = step
     self.e2e_time_without_first = e2e_toc - e2e_tic
 
+    if streamer is not None:
+        streamer.end()
+
     return input_ids[:, : input_len + step]
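Finally, the contract the new streamer calls rely on is the one defined by transformers' BaseStreamer: put() receives the prompt ids once, then each batch of newly accepted tokens (including the truncated batch when an eos token is found), and end() signals completion. A minimal custom streamer compatible with the calls added above; the class below is illustrative and not part of this commit.

from transformers.generation.streamers import BaseStreamer


class PrintTokenStreamer(BaseStreamer):
    """Illustrative streamer matching the put()/end() calls added above."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.prompt_seen = False

    def put(self, value):
        # lookup_generate sends the prompt ids in the first call; skip them.
        if not self.prompt_seen:
            self.prompt_seen = True
            return
        # Subsequent calls carry the newly accepted tokens for this step.
        print(self.tokenizer.decode(value.flatten().tolist()), end="", flush=True)

    def end(self):
        print()  # emit a trailing newline once generation finishes

TextStreamer and TextIteratorStreamer from transformers implement this same contract and can be passed to generate() directly.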