[NPU] Support repetition penalty for simple generate, Python (cpp backend) (#12522)

* Initial support for repetition penalty in simple generate on NPU (cpp backend)

* Fix bugs in generation config handling and related issues

* Remove unnecessary prints and fix style

* Remove unnecessary print

* Address review comments
Yuwen Hu, 2024-12-11 14:55:25 +08:00, committed by GitHub
parent 77404d2a63
commit 68f2873bd3
2 changed files with 41 additions and 17 deletions
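
With this change, repetition_penalty can be passed directly to simple_generate and is forwarded to the C++ backend. A minimal usage sketch (hypothetical: model and tokenizer are assumed to be set up beforehand via the ipex-llm NPU pipeline flow, which is outside this commit):

    # Hypothetical usage; only the simple_generate kwargs below are
    # taken from this commit.
    input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
    # Values > 1.0 make already-seen tokens less likely; the commit
    # validates that the value is positive.
    output = model.simple_generate(input_ids,
                                   max_new_tokens=64,
                                   repetition_penalty=1.2)
    print(tokenizer.decode(output[0], skip_special_tokens=True))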

@@ -355,7 +355,7 @@ def simple_generate(
         do_print = kwargs.pop("do_print", False)
         time_t1 = time.perf_counter()
         new_generate_kwargs = {}
-        for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
+        for var in ['max_new_tokens', 'attention_mask', 'eos_token_id', 'repetition_penalty']:
             value = kwargs.pop(var, None)
             if value is not None:
                 new_generate_kwargs[var] = value
@@ -376,39 +376,48 @@ def simple_generate(
         invalidInputError(input_length + new_tokens <= self.kv_len + 1,
                           "Input plus output tokens should not exceed max_context_len.")
-        if "eos_token_id" not in new_generate_kwargs:
-            generation_config = GenerationConfig.from_model_config(self.config)
-            if hasattr(generation_config, "eos_token_id"):
-                eos = generation_config.eos_token_id
-            else:
-                eos = 0xffffffff
-        else:
+        if "eos_token_id" in new_generate_kwargs:
             eos = new_generate_kwargs["eos_token_id"]
+        elif hasattr(self.generation_config, "eos_token_id"):
+            eos = self.generation_config.eos_token_id
+        else:
+            eos = 0xffffffff
         if not isinstance(eos, list):
             eos = [eos]
-        output_tokens = []
+        if "repetition_penalty" in new_generate_kwargs:
+            repetition_penalty = new_generate_kwargs['repetition_penalty']
+        elif hasattr(self.generation_config, "repetition_penalty"):
+            repetition_penalty = self.generation_config.repetition_penalty
+        else:
+            repetition_penalty = 1
+        invalidInputError(repetition_penalty > 0,
+                          "repetition_penalty should be a positive value. "
+                          f"But you have: {repetition_penalty}")
         from .npu_llm_cpp import run_decode, run_prefill, reset
-        token = run_prefill(self.model_ptr, input_list, self.vocab_size)
+        token = run_prefill(self.model_ptr, input_list, self.vocab_size,
+                            repetition_penalty)
         if streamer is not None:
             # 1st tokens
             streamer.put(torch.tensor([token]))
         idx = 1
         time_t2 = time.perf_counter()
-        output_tokens.append(torch.tensor([token]))
+        input_list.append(token)
         for i in range(new_tokens - 1):
             if token in eos:
                 break
-            token = run_decode(self.model_ptr, token, self.vocab_size)
+            token = run_decode(self.model_ptr, token, self.vocab_size,
+                               input_list, repetition_penalty)
             if streamer is not None:
                 # rest tokens
                 streamer.put(torch.tensor([token]))
             idx += 1
-            output_tokens.append(torch.tensor([token]))
-        output = torch.stack(output_tokens, dim=1)
-        output = torch.cat((inputs, output), dim=1)
+            input_list.append(token)
+        output = torch.tensor([input_list])
         time_t3 = time.perf_counter()
         if streamer is not None:
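
The penalty math itself lives in the C++ backend's process_logits, which receives the logits buffer together with every token id seen so far. Assuming it follows the standard CTRL-style formulation (the same one used by transformers' RepetitionPenaltyLogitsProcessor), a Python reference sketch of the expected math:

    import numpy as np

    def apply_repetition_penalty(logits: np.ndarray,
                                 seen_token_ids: list,
                                 penalty: float) -> np.ndarray:
        # Reference sketch only -- the real work is done in C++ by
        # _lib.process_logits (see the bindings below).
        for tok in set(seen_token_ids):
            # Make seen tokens less likely: shrink positive logits,
            # push negative logits further down.
            if logits[tok] > 0:
                logits[tok] /= penalty
            else:
                logits[tok] *= penalty
        return logits

This is also why the bookkeeping above changes: generated tokens are appended to input_list instead of a separate output_tokens list, so each decode step has the full prompt-plus-generation history to penalize, and the final output is simply torch.tensor([input_list]).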

@@ -57,6 +57,11 @@ _lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)
 _lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
 _lib.llm_sample_token.restype = ctypes.c_int

+_lib.process_logits.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int,
+                                ctypes.POINTER(ctypes.c_int), ctypes.c_int,
+                                ctypes.c_float]
+_lib.process_logits.restype = ctypes.POINTER(ctypes.c_float)
+
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None
@@ -73,16 +78,26 @@ def load_model_from_file(model_dir: str):
     return _lib.load_model_from_file(model_dir.encode('utf-8'))


-def run_prefill(model_ptr, input_ids, vocab_size):
+def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
     input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
     input_len = len(input_ids)
     plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
+    if repetition_penalty != 1:
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      input_ptr, input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token


-def run_decode(model_ptr, input_id, vocab_size):
+def run_decode(model_ptr, input_id, vocab_size, updated_input_ids, repetition_penalty=1.0):
     plogits = _lib.run_decode(model_ptr, input_id)
+    if repetition_penalty != 1:
+        updated_input_ptr = (ctypes.c_int32 * len(updated_input_ids))(*updated_input_ids)
+        updated_input_len = len(updated_input_ids)
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      updated_input_ptr, updated_input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
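
Note that run_decode now takes the full, growing token history (updated_input_ids), so the penalty covers the prompt as well as everything generated so far, and the extra ctypes marshalling is skipped entirely when the penalty is left at the default of 1.0. A minimal driver sketch over these wrappers, mirroring what simple_generate does (model_ptr, prompt_ids, vocab_size, max_new_tokens, and eos_id are assumed to come from the surrounding pipeline code):

    # Hypothetical driver loop over the wrappers above.
    tokens = list(prompt_ids)
    token = run_prefill(model_ptr, tokens, vocab_size, repetition_penalty=1.2)
    tokens.append(token)
    for _ in range(max_new_tokens - 1):
        if token == eos_id:
            break
        # Pass the full history so process_logits can penalize every
        # token seen so far, not just the most recent one.
        token = run_decode(model_ptr, token, vocab_size,
                           tokens, repetition_penalty=1.2)
        tokens.append(token)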