[NPU] Support repetition penalty for simple generate, Python (cpp backend) (#12522)
* Initial support of repetition penalty on NPU (cpp backend) for simple generate
* Bug fix for generation config and others
* Remove unnecessary print and style fix
* Remove unnecessary print
* Fix based on comments
parent 77404d2a63
commit 68f2873bd3

2 changed files with 41 additions and 17 deletions
@@ -355,7 +355,7 @@ def simple_generate(
     do_print = kwargs.pop("do_print", False)
     time_t1 = time.perf_counter()
     new_generate_kwargs = {}
-    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
+    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id', 'repetition_penalty']:
        value = kwargs.pop(var, None)
        if value is not None:
            new_generate_kwargs[var] = value
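For context, a minimal usage sketch of the new option. The model path, tokenizer, and load arguments below are placeholders and not part of this commit; only the repetition_penalty kwarg exercises the new path.

    from transformers import AutoTokenizer
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model_path = "path/to/your/model"  # placeholder
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit="sym_int4",  # assumed load option
                                                 trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
    # repetition_penalty is now forwarded into new_generate_kwargs above
    output = model.generate(input_ids,
                            max_new_tokens=32,
                            repetition_penalty=1.2)
    print(tokenizer.decode(output[0], skip_special_tokens=True))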
@@ -376,39 +376,48 @@ def simple_generate(
     invalidInputError(input_length + new_tokens <= self.kv_len + 1,
                       "Input plus output tokens should not exceed max_context_len.")
 
-    if "eos_token_id" not in new_generate_kwargs:
-        generation_config = GenerationConfig.from_model_config(self.config)
-        if hasattr(generation_config, "eos_token_id"):
-            eos = generation_config.eos_token_id
-        else:
-            eos = 0xffffffff
-    else:
+    if "eos_token_id" in new_generate_kwargs:
         eos = new_generate_kwargs["eos_token_id"]
+    elif hasattr(self.generation_config, "eos_token_id"):
+        eos = self.generation_config.eos_token_id
+    else:
+        eos = 0xffffffff
 
     if not isinstance(eos, list):
         eos = [eos]
 
-    output_tokens = []
+    if "repetition_penalty" in new_generate_kwargs:
+        repetition_penalty = new_generate_kwargs['repetition_penalty']
+    elif hasattr(self.generation_config, "repetition_penalty"):
+        repetition_penalty = self.generation_config.repetition_penalty
+    else:
+        repetition_penalty = 1
+
+    invalidInputError(repetition_penalty > 0,
+                      "repetition_penalty should be a positive value. "
+                      f"But you have: {repetition_penalty}")
+
     from .npu_llm_cpp import run_decode, run_prefill, reset
 
-    token = run_prefill(self.model_ptr, input_list, self.vocab_size)
+    token = run_prefill(self.model_ptr, input_list, self.vocab_size,
+                        repetition_penalty)
     if streamer is not None:
         # 1st tokens
         streamer.put(torch.tensor([token]))
     idx = 1
     time_t2 = time.perf_counter()
-    output_tokens.append(torch.tensor([token]))
+    input_list.append(token)
     for i in range(new_tokens - 1):
         if token in eos:
             break
-        token = run_decode(self.model_ptr, token, self.vocab_size)
+        token = run_decode(self.model_ptr, token, self.vocab_size,
+                           input_list, repetition_penalty)
         if streamer is not None:
             # rest tokens
             streamer.put(torch.tensor([token]))
         idx += 1
-        output_tokens.append(torch.tensor([token]))
-    output = torch.stack(output_tokens, dim=1)
-    output = torch.cat((inputs, output), dim=1)
+        input_list.append(token)
+    output = torch.tensor([input_list])
     time_t3 = time.perf_counter()
 
     if streamer is not None:
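For reference, a plain-Python sketch of the transform the cpp-side process_logits is presumably applying, following the usual Hugging Face repetition-penalty convention (logits of already-seen tokens are divided by the penalty when positive and multiplied by it when negative). This is an assumption about the C++ code, not something contained in this commit; apply_repetition_penalty is a hypothetical helper.

    import numpy as np

    def apply_repetition_penalty(logits, seen_ids, penalty):
        # Assumed mirror of the cpp backend's process_logits: penalize every token
        # that already appears in the (prompt + generated) history.
        out = np.array(logits, dtype=np.float32, copy=True)
        for tid in set(seen_ids):
            if out[tid] > 0:
                out[tid] /= penalty   # already-seen token with positive logit: made less likely
            else:
                out[tid] *= penalty   # negative logit: pushed further down
        return out

Under this convention a penalty of 1 is a no-op and non-positive values would flip or zero the logits, which is why the hunk above validates repetition_penalty > 0 and why input_list now keeps the full prompt-plus-generated history that is handed to run_decode.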
@@ -57,6 +57,11 @@ _lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)
 _lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
 _lib.llm_sample_token.restype = ctypes.c_int
 
+_lib.process_logits.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int,
+                                ctypes.POINTER(ctypes.c_int), ctypes.c_int,
+                                ctypes.c_float]
+_lib.process_logits.restype = ctypes.POINTER(ctypes.c_float)
+
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None
 
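These hunks appear to be in the npu_llm_cpp binding module imported above. Assuming a conventional ctypes-to-C mapping, the new declarations imply a native signature roughly like float* process_logits(float* logits, int vocab_size, const int* input_ids, int input_len, float repetition_penalty); the C prototype itself is not shown in this commit. A direct-call sketch (penalize is a hypothetical helper; _lib and plogits are assumed to come from this module):

    import ctypes

    def penalize(plogits, vocab_size, history, repetition_penalty):
        # history is a Python list of token ids; plogits is the pointer returned
        # by _lib.run_prefill / _lib.run_decode
        ids = (ctypes.c_int32 * len(history))(*history)
        return _lib.process_logits(plogits, vocab_size,
                                   ids, len(history),
                                   repetition_penalty)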
@@ -73,16 +78,26 @@ def load_model_from_file(model_dir: str):
     return _lib.load_model_from_file(model_dir.encode('utf-8'))
 
 
-def run_prefill(model_ptr, input_ids, vocab_size):
+def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
     input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
     input_len = len(input_ids)
     plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
+    if repetition_penalty != 1:
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      input_ptr, input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
 
 
-def run_decode(model_ptr, input_id, vocab_size):
+def run_decode(model_ptr, input_id, vocab_size, updated_input_ids, repetition_penalty=1.0):
     plogits = _lib.run_decode(model_ptr, input_id)
+    if repetition_penalty != 1:
+        updated_input_ptr = (ctypes.c_int32 * len(updated_input_ids))(*updated_input_ids)
+        updated_input_len = len(updated_input_ids)
+        plogits = _lib.process_logits(plogits, vocab_size,
+                                      updated_input_ptr, updated_input_len,
+                                      repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
 
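Putting the two bindings together, a minimal greedy loop in the spirit of simple_generate could look like the sketch below. greedy_generate is a hypothetical helper; model_ptr, vocab_size, and prompt_ids are assumed to come from the caller, run_prefill/run_decode are the functions patched above, and eos handling is simplified to a single id.

    def greedy_generate(model_ptr, prompt_ids, vocab_size,
                        max_new_tokens=32, repetition_penalty=1.2, eos_id=None):
        history = list(prompt_ids)
        # prefill applies the penalty over the prompt tokens
        token = run_prefill(model_ptr, history, vocab_size, repetition_penalty)
        history.append(token)
        for _ in range(max_new_tokens - 1):
            if eos_id is not None and token == eos_id:
                break
            # decode applies the penalty over prompt + generated history
            token = run_decode(model_ptr, token, vocab_size, history, repetition_penalty)
            history.append(token)
        return history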