From 68f2873bd38b6173415c155a2440bb17652eeca3 Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Wed, 11 Dec 2024 14:55:25 +0800
Subject: [PATCH] [NPU] Support repetition penalty for simple generate, Python (cpp backend) (#12522)

* Initial support of repetition penalty on NPU (cpp backend) for simple generate

* Bug fix for generation config and others

* Remove unnecessary print and style fix

* Remove unnecessary print

* Fix based on comments
---
 .../transformers/npu_models/convert.py     | 39 ++++++++++++-------
 .../transformers/npu_models/npu_llm_cpp.py | 19 ++++++++-
 2 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index 2842799b..9d8c20ce 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -355,7 +355,7 @@ def simple_generate(
     do_print = kwargs.pop("do_print", False)
     time_t1 = time.perf_counter()
     new_generate_kwargs = {}
-    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
+    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id', 'repetition_penalty']:
         value = kwargs.pop(var, None)
         if value is not None:
             new_generate_kwargs[var] = value
@@ -376,39 +376,48 @@ def simple_generate(
     invalidInputError(input_length + new_tokens <= self.kv_len + 1,
                       "Input plus output tokens should not exceed max_context_len.")
 
-    if "eos_token_id" not in new_generate_kwargs:
-        generation_config = GenerationConfig.from_model_config(self.config)
-        if hasattr(generation_config, "eos_token_id"):
-            eos = generation_config.eos_token_id
-        else:
-            eos = 0xffffffff
-    else:
+    if "eos_token_id" in new_generate_kwargs:
         eos = new_generate_kwargs["eos_token_id"]
+    elif hasattr(self.generation_config, "eos_token_id"):
+        eos = self.generation_config.eos_token_id
+    else:
+        eos = 0xffffffff
 
     if not isinstance(eos, list):
         eos = [eos]
 
-    output_tokens = []
+    if "repetition_penalty" in new_generate_kwargs:
+        repetition_penalty = new_generate_kwargs['repetition_penalty']
+    elif hasattr(self.generation_config, "repetition_penalty"):
+        repetition_penalty = self.generation_config.repetition_penalty
+    else:
+        repetition_penalty = 1
+
+    invalidInputError(repetition_penalty > 0,
+                      "repetition_penalty should be a positive value. "
+                      f"But you have: {repetition_penalty}")
+
     from .npu_llm_cpp import run_decode, run_prefill, reset
 
-    token = run_prefill(self.model_ptr, input_list, self.vocab_size)
+    token = run_prefill(self.model_ptr, input_list, self.vocab_size,
+                        repetition_penalty)
     if streamer is not None:
         # 1st tokens
         streamer.put(torch.tensor([token]))
 
     idx = 1
     time_t2 = time.perf_counter()
-    output_tokens.append(torch.tensor([token]))
+    input_list.append(token)
     for i in range(new_tokens - 1):
         if token in eos:
             break
-        token = run_decode(self.model_ptr, token, self.vocab_size)
+        token = run_decode(self.model_ptr, token, self.vocab_size,
+                           input_list, repetition_penalty)
         if streamer is not None:
             # rest tokens
             streamer.put(torch.tensor([token]))
         idx += 1
-        output_tokens.append(torch.tensor([token]))
-    output = torch.stack(output_tokens, dim=1)
-    output = torch.cat((inputs, output), dim=1)
+        input_list.append(token)
+    output = torch.tensor([input_list])
     time_t3 = time.perf_counter()
 
     if streamer is not None:
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py b/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
index 8dca443f..509003cf 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
@@ -57,6 +57,11 @@ _lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)
 _lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
 _lib.llm_sample_token.restype = ctypes.c_int
 
+_lib.process_logits.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int,
+                                ctypes.POINTER(ctypes.c_int), ctypes.c_int,
+                                ctypes.c_float]
+_lib.process_logits.restype = ctypes.POINTER(ctypes.c_float)
+
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None
 
@@ -73,16 +78,26 @@ def load_model_from_file(model_dir: str):
     return _lib.load_model_from_file(model_dir.encode('utf-8'))
 
 
-def run_prefill(model_ptr, input_ids, vocab_size):
+def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
     input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
     input_len = len(input_ids)
     plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
+    if repetition_penalty != 1:
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      input_ptr, input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
 
 
-def run_decode(model_ptr, input_id, vocab_size):
+def run_decode(model_ptr, input_id, vocab_size, updated_input_ids, repetition_penalty=1.0):
     plogits = _lib.run_decode(model_ptr, input_id)
+    if repetition_penalty != 1:
+        updated_input_ptr = (ctypes.c_int32 * len(updated_input_ids))(*updated_input_ids)
+        updated_input_len = len(updated_input_ids)
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      updated_input_ptr, updated_input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token