[NPU] Support repetition penalty for simple generate, Python (cpp backend) (#12522)
* Initial support of repetition penalty on NPU (cpp backend) for simple generate
* Bug fix for generation config and others
* Remove unnecessary print and style fix
* Remove unnecessary print
* Fix based on comments
This commit is contained in:
parent
77404d2a63
commit
68f2873bd3
2 changed files with 41 additions and 17 deletions
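For context: repetition penalty (as used in Hugging Face transformers generation) rescales the logits of tokens that have already appeared in the sequence before sampling, which is presumably what the new native process_logits routine added in this patch implements. A minimal Python sketch of the standard formulation follows; the helper name and the NumPy usage are illustrative only, not part of this patch:

    import numpy as np

    def apply_repetition_penalty(logits: np.ndarray, seen_ids, penalty: float) -> np.ndarray:
        # Standard (CTRL-style) repetition penalty: tokens that already occur in the
        # sequence are made less likely. Positive logits are divided by the penalty,
        # negative logits are multiplied by it; penalty == 1.0 is a no-op.
        out = logits.copy()
        for token_id in set(seen_ids):
            if out[token_id] > 0:
                out[token_id] /= penalty
            else:
                out[token_id] *= penalty
        return out

A penalty greater than 1.0 discourages repeats, which is also why the Python wrappers below skip the native call entirely when repetition_penalty == 1.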
@@ -355,7 +355,7 @@ def simple_generate(
     do_print = kwargs.pop("do_print", False)
     time_t1 = time.perf_counter()
     new_generate_kwargs = {}
-    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
+    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id', 'repetition_penalty']:
         value = kwargs.pop(var, None)
         if value is not None:
             new_generate_kwargs[var] = value
@@ -376,39 +376,48 @@ def simple_generate(
     invalidInputError(input_length + new_tokens <= self.kv_len + 1,
                       "Input plus output tokens should not exceed max_context_len.")

-    if "eos_token_id" not in new_generate_kwargs:
-        generation_config = GenerationConfig.from_model_config(self.config)
-        if hasattr(generation_config, "eos_token_id"):
-            eos = generation_config.eos_token_id
-        else:
-            eos = 0xffffffff
-    else:
+    if "eos_token_id" in new_generate_kwargs:
         eos = new_generate_kwargs["eos_token_id"]
+    elif hasattr(self.generation_config, "eos_token_id"):
+        eos = self.generation_config.eos_token_id
+    else:
+        eos = 0xffffffff

     if not isinstance(eos, list):
         eos = [eos]

-    output_tokens = []
+    if "repetition_penalty" in new_generate_kwargs:
+        repetition_penalty = new_generate_kwargs['repetition_penalty']
+    elif hasattr(self.generation_config, "repetition_penalty"):
+        repetition_penalty = self.generation_config.repetition_penalty
+    else:
+        repetition_penalty = 1
+
+    invalidInputError(repetition_penalty > 0,
+                      "repetition_penalty should be a positive value. "
+                      f"But you have: {repetition_penalty}")
+
     from .npu_llm_cpp import run_decode, run_prefill, reset

-    token = run_prefill(self.model_ptr, input_list, self.vocab_size)
+    token = run_prefill(self.model_ptr, input_list, self.vocab_size,
+                        repetition_penalty)
     if streamer is not None:
         # 1st tokens
         streamer.put(torch.tensor([token]))
     idx = 1
     time_t2 = time.perf_counter()
-    output_tokens.append(torch.tensor([token]))
+    input_list.append(token)
     for i in range(new_tokens - 1):
         if token in eos:
             break
-        token = run_decode(self.model_ptr, token, self.vocab_size)
+        token = run_decode(self.model_ptr, token, self.vocab_size,
+                           input_list, repetition_penalty)
         if streamer is not None:
             # rest tokens
             streamer.put(torch.tensor([token]))
         idx += 1
-        output_tokens.append(torch.tensor([token]))
-    output = torch.stack(output_tokens, dim=1)
-    output = torch.cat((inputs, output), dim=1)
+        input_list.append(token)
+    output = torch.tensor([input_list])
     time_t3 = time.perf_counter()

     if streamer is not None:
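With the simple-generate path now popping repetition_penalty from kwargs (falling back to self.generation_config.repetition_penalty, else 1), a caller can presumably request the penalty the same way as in standard transformers generation. The surrounding setup is omitted here and the values are illustrative only, not taken from this patch:

    # Illustrative call into the patched simple-generate path; `model` and
    # `input_ids` are assumed to come from the usual ipex-llm NPU loading flow.
    output = model.generate(input_ids,
                            max_new_tokens=32,
                            repetition_penalty=1.2)  # values > 1.0 penalize repeated tokens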
@@ -57,6 +57,11 @@ _lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)
 _lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
 _lib.llm_sample_token.restype = ctypes.c_int

+_lib.process_logits.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int,
+                                ctypes.POINTER(ctypes.c_int), ctypes.c_int,
+                                ctypes.c_float]
+_lib.process_logits.restype = ctypes.POINTER(ctypes.c_float)
+
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None

@@ -73,16 +78,26 @@ def load_model_from_file(model_dir: str):
     return _lib.load_model_from_file(model_dir.encode('utf-8'))


-def run_prefill(model_ptr, input_ids, vocab_size):
+def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
     input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
     input_len = len(input_ids)
     plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
+    if repetition_penalty != 1:
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      input_ptr, input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token


-def run_decode(model_ptr, input_id, vocab_size):
+def run_decode(model_ptr, input_id, vocab_size, updated_input_ids, repetition_penalty=1.0):
     plogits = _lib.run_decode(model_ptr, input_id)
+    if repetition_penalty != 1:
+        updated_input_ptr = (ctypes.c_int32 * len(updated_input_ids))(*updated_input_ids)
+        updated_input_len = len(updated_input_ids)
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      updated_input_ptr, updated_input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
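An aside on the ctypes idiom used above: `(ctypes.c_int32 * len(input_ids))(*input_ids)` builds a C-compatible int32 array that can be handed to the native side wherever an int pointer is expected. A small self-contained sketch (the token ids here are made up):

    import ctypes

    ids = [101, 7592, 102]
    # (ctypes.c_int32 * n) creates a fixed-size array type; calling it with *ids
    # fills the buffer, giving contiguous int32 storage suitable for an `int*` argument.
    arr = (ctypes.c_int32 * len(ids))(*ids)
    print(list(arr))           # [101, 7592, 102]
    print(ctypes.sizeof(arr))  # 12 -> 3 * 4 bytes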