[NPU] Support repetition penalty for simple generate, Python (cpp backend) (#12522)
* Initial support of repetition penalty on NPU (cpp backend) for simple generate
* Bug fix for generation config and others
* Remove unnecessary print and style fix
* Remove unnecessary print
* Fix based on comments
parent 77404d2a63
commit 68f2873bd3
2 changed files with 41 additions and 17 deletions
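The penalty itself is applied inside the C++ backend by the process_logits routine that the new bindings below expose. As a rough sketch (an assumption, since the C++ source is not part of this diff), such a routine conventionally applies the CTRL-style transform to every token that has already appeared: positive logits are divided by the penalty, negative ones multiplied by it.

# Illustrative Python sketch of the assumed transform; the real work happens in
# the C++ process_logits, this is only a mental model.
def apply_repetition_penalty(logits, seen_token_ids, penalty):
    logits = list(logits)
    for token_id in set(seen_token_ids):
        if logits[token_id] > 0:
            logits[token_id] /= penalty   # make repeated tokens less likely
        else:
            logits[token_id] *= penalty   # push already-negative logits further down
    return logits

# With penalty 1.2, tokens 0 and 3 (already seen) get dampened:
print(apply_repetition_penalty([2.0, -1.0, 0.5, 3.0], [0, 3], 1.2))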
@@ -355,7 +355,7 @@ def simple_generate(
     do_print = kwargs.pop("do_print", False)
     time_t1 = time.perf_counter()
     new_generate_kwargs = {}
-    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
+    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id', 'repetition_penalty']:
         value = kwargs.pop(var, None)
         if value is not None:
             new_generate_kwargs[var] = value

@@ -376,39 +376,48 @@ def simple_generate(
     invalidInputError(input_length + new_tokens <= self.kv_len + 1,
                       "Input plus output tokens should not exceed max_context_len.")
 
-    if "eos_token_id" not in new_generate_kwargs:
-        generation_config = GenerationConfig.from_model_config(self.config)
-        if hasattr(generation_config, "eos_token_id"):
-            eos = generation_config.eos_token_id
-        else:
-            eos = 0xffffffff
-    else:
-        eos = new_generate_kwargs["eos_token_id"]
+    if "eos_token_id" in new_generate_kwargs:
+        eos = new_generate_kwargs["eos_token_id"]
+    elif hasattr(self.generation_config, "eos_token_id"):
+        eos = self.generation_config.eos_token_id
+    else:
+        eos = 0xffffffff
 
     if not isinstance(eos, list):
         eos = [eos]
 
-    output_tokens = []
+    if "repetition_penalty" in new_generate_kwargs:
+        repetition_penalty = new_generate_kwargs['repetition_penalty']
+    elif hasattr(self.generation_config, "repetition_penalty"):
+        repetition_penalty = self.generation_config.repetition_penalty
+    else:
+        repetition_penalty = 1
+
+    invalidInputError(repetition_penalty > 0,
+                      "repetition_penalty should be a positive value. "
+                      f"But you have: {repetition_penalty}")
+
     from .npu_llm_cpp import run_decode, run_prefill, reset
 
-    token = run_prefill(self.model_ptr, input_list, self.vocab_size)
+    token = run_prefill(self.model_ptr, input_list, self.vocab_size,
+                        repetition_penalty)
     if streamer is not None:
         # 1st tokens
         streamer.put(torch.tensor([token]))
     idx = 1
     time_t2 = time.perf_counter()
-    output_tokens.append(torch.tensor([token]))
+    input_list.append(token)
     for i in range(new_tokens - 1):
         if token in eos:
             break
-        token = run_decode(self.model_ptr, token, self.vocab_size)
+        token = run_decode(self.model_ptr, token, self.vocab_size,
+                           input_list, repetition_penalty)
         if streamer is not None:
             # rest tokens
             streamer.put(torch.tensor([token]))
         idx += 1
-        output_tokens.append(torch.tensor([token]))
-    output = torch.stack(output_tokens, dim=1)
-    output = torch.cat((inputs, output), dim=1)
+        input_list.append(token)
+    output = torch.tensor([input_list])
     time_t3 = time.perf_counter()
 
     if streamer is not None:
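From the caller's side nothing changes except that the keyword is now honoured on the NPU cpp-backend path. A hypothetical call site (the model object and prompt encoding are assumptions; only the repetition_penalty keyword and its fallbacks come from this diff):

# `model` is an NPU (cpp backend) model whose generation goes through
# simple_generate, and `input_ids` is an already-encoded prompt -- both assumed.
output = model.simple_generate(input_ids,
                               max_new_tokens=32,
                               repetition_penalty=1.2)
# If the keyword is omitted, the value falls back to
# model.generation_config.repetition_penalty and finally to 1 (no penalty);
# non-positive values are rejected via invalidInputError.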
@@ -57,6 +57,11 @@ _lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)
 _lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
 _lib.llm_sample_token.restype = ctypes.c_int
 
+_lib.process_logits.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int,
+                                ctypes.POINTER(ctypes.c_int), ctypes.c_int,
+                                ctypes.c_float]
+_lib.process_logits.restype = ctypes.POINTER(ctypes.c_float)
+
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None
 

@@ -73,16 +78,26 @@ def load_model_from_file(model_dir: str):
     return _lib.load_model_from_file(model_dir.encode('utf-8'))
 
 
-def run_prefill(model_ptr, input_ids, vocab_size):
+def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
     input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
     input_len = len(input_ids)
     plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
+    if repetition_penalty != 1:
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      input_ptr, input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
 
 
-def run_decode(model_ptr, input_id, vocab_size):
+def run_decode(model_ptr, input_id, vocab_size, updated_input_ids, repetition_penalty=1.0):
     plogits = _lib.run_decode(model_ptr, input_id)
+    if repetition_penalty != 1:
+        updated_input_ptr = (ctypes.c_int32 * len(updated_input_ids))(*updated_input_ids)
+        updated_input_len = len(updated_input_ids)
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      updated_input_ptr, updated_input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
 
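Put together, the updated wrappers can drive a bare greedy loop like the one simple_generate implements above. A minimal sketch, assuming model_ptr, vocab_size and the prompt token ids are already available and ignoring eos/streamer handling:

# run_prefill / run_decode are the wrappers changed in this hunk; everything
# else (model_ptr, vocab_size, prompt_ids) is assumed to come from the caller.
def greedy_generate(model_ptr, prompt_ids, vocab_size, max_new_tokens,
                    repetition_penalty=1.2):
    input_list = list(prompt_ids)
    token = run_prefill(model_ptr, input_list, vocab_size, repetition_penalty)
    input_list.append(token)
    for _ in range(max_new_tokens - 1):
        # the full token history is passed so process_logits can penalize
        # every id generated (or prompted) so far
        token = run_decode(model_ptr, token, vocab_size,
                           input_list, repetition_penalty)
        input_list.append(token)
    return input_list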