[NPU] Update C++ example with repetition_penalty & update Python code accordingly (#12528)
* Update C++ NPU examples with repetition penalty
* Fit Python code to the updated C++ API
* Style fix
* Small fix
* Small fix
Parent: 2cce89691a
Commit: dbaf4abcb3

3 changed files with 21 additions and 28 deletions
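For context, repetition_penalty discourages the model from re-emitting tokens it has already produced by rescaling their logits before sampling. The sketch below shows the conventional CTRL/Transformers-style formulation of that rescaling; it is only an illustration of the idea, not necessarily the exact logic the NPU-side run_prefill/run_decode now apply internally (apply_repetition_penalty and the numpy usage are assumptions made for this example).

import numpy as np

def apply_repetition_penalty(logits, generated_ids, penalty):
    # penalty == 1.0 means "disabled": return the logits unchanged.
    if penalty == 1.0:
        return logits
    logits = logits.copy()
    for token_id in set(generated_ids):
        score = logits[token_id]
        # Positive scores are divided and negative scores are multiplied,
        # so already-seen tokens always become less likely.
        logits[token_id] = score / penalty if score > 0 else score * penalty
    return logits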
				
			
@@ -98,11 +98,11 @@ std::string add_chat_history(npu_model_params model_params,
     return prompt;
 }
 
-
 std::string run_generate(void* void_model, int32_t* embd_inp_ptr, int32_t embd_inp_size,
-                         npu_model_params model_params, tokenizer_params tok_params, int32_t max_new_token, bool do_print){
+                         npu_model_params model_params, tokenizer_params tok_params, npu_generation_params generation_params, bool do_print){
     auto start = std::chrono::high_resolution_clock::now();
-    float* logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size);
+    float* logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size,
+                                generation_params.repetition_penalty);
     int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
     auto end = std::chrono::high_resolution_clock::now();
     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
@@ -115,8 +115,9 @@ std::string run_generate(void* void_model, int32_t* embd_inp_ptr, int32_t embd_i
 
     int token_nums = 0;
     start = std::chrono::high_resolution_clock::now();
-    for (int i = 1; i < max_new_token; i++){
-        auto logits = run_decode(void_model, embd[i-1]);
+    for (int i = 1; i < generation_params.max_new_token; i++){
+        auto logits = run_decode(void_model, embd[i-1],
+                                 generation_params.repetition_penalty);
         int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
         if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
             embd.push_back(token);
@@ -207,6 +208,10 @@ int main(int argc, char ** argv) {
     tokenizer_params tok_params;
     load_tokenizer(tok_params, params.model);
 
+    npu_generation_params generation_params;
+    load_generation_config_from_file(generation_params, params.model);
+    generation_params.max_new_token = n_predict;
+
     if (cnv_mode){
         std::string prompt;
         std::string history = "";
@@ -228,9 +233,11 @@ int main(int argc, char ** argv) {
                     full_prompt = add_chat_history(model_params, prompt, "", true);
                     embd_inp = llm_tokenize(full_prompt, false);
                 }
+
+                generation_params.max_new_token = model_params.kv_len - embd_inp.size();
                 
                 response = run_generate(model, embd_inp.data(), embd_inp.size(),
-                                        model_params, tok_params, model_params.kv_len - embd_inp.size(), false);
+                                        model_params, tok_params, generation_params, false);
 
                 std::cout << "Assistant:";
                 std::cout << response << std::endl;
@@ -251,7 +258,7 @@ int main(int argc, char ** argv) {
 
         // single text generation
         std::string output = run_generate(model, embd_inp.data(), embd_inp.size(),
-                                          model_params, tok_params, params.n_predict, true);
+                                          model_params, tok_params, generation_params, true);
 
         std::cout << "Output: " << std::endl;
         std::cout << output << std::endl;
@@ -413,7 +413,7 @@ def simple_generate(
         if token in eos:
             break
         token = run_decode(self.model_ptr, token, self.vocab_size,
-                           input_list, repetition_penalty)
+                           repetition_penalty)
         if streamer is not None:
             # rest tokens
             streamer.put(torch.tensor([token]))
@@ -48,20 +48,16 @@ _lib = ctypes.cdll.LoadLibrary(_lib_path)
 _lib.load_model_from_file.argtypes = [ctypes.c_char_p]
 _lib.load_model_from_file.restype = ctypes.c_void_p
 
-_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int]
+_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int,
+                             ctypes.c_float]
 _lib.run_prefill.restype = ctypes.POINTER(ctypes.c_float)
 
-_lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int]
+_lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_float]
 _lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)
 
 _lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
 _lib.llm_sample_token.restype = ctypes.c_int
 
-_lib.process_logits.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int,
-                                ctypes.POINTER(ctypes.c_int), ctypes.c_int,
-                                ctypes.c_float]
-_lib.process_logits.restype = ctypes.POINTER(ctypes.c_float)
-
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None
 
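Because ctypes.c_float now appears in the argtypes lists, plain Python floats are converted automatically at call time, which is why the wrappers in the next hunk can simply forward repetition_penalty. Below is a minimal, self-contained illustration of that declare-then-call pattern against libm's pow() on Linux/macOS; it uses the standard C math library only as a stand-in, not the NPU library from this repo.

import ctypes
import ctypes.util

# Load the C math library purely to demonstrate the binding pattern.
libm = ctypes.CDLL(ctypes.util.find_library("m") or "libm.so.6")
libm.pow.argtypes = [ctypes.c_double, ctypes.c_double]
libm.pow.restype = ctypes.c_double

# Plain Python floats are converted according to argtypes; no manual
# ctypes.c_double(...) wrapping is needed at the call site.
print(libm.pow(2.0, 10.0))  # 1024.0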
@@ -81,23 +77,13 @@ def load_model_from_file(model_dir: str):
 def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
     input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
     input_len = len(input_ids)
-    plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
-    if repetition_penalty != 1:
-        plogits = _lib.process_logits(plogits, vocab_size,
-                                      input_ptr, input_len,
-                                      repetition_penalty)
+    plogits = _lib.run_prefill(model_ptr, input_ptr, input_len, repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
 
 
-def run_decode(model_ptr, input_id, vocab_size, updated_input_ids, repetition_penalty=1.0):
-    plogits = _lib.run_decode(model_ptr, input_id)
-    if repetition_penalty != 1:
-        updated_input_ptr = (ctypes.c_int32 * len(updated_input_ids))(*updated_input_ids)
-        updated_input_len = len(updated_input_ids)
-        plogits = _lib.process_logits(plogits, vocab_size,
-                                      updated_input_ptr, updated_input_len,
-                                      repetition_penalty)
+def run_decode(model_ptr, input_id, vocab_size, repetition_penalty=1.0):
+    plogits = _lib.run_decode(model_ptr, input_id, repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
 
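As a usage note, a caller of the simplified run_prefill/run_decode helpers above no longer needs to keep a running token list for the penalty. The following is a minimal sketch under that assumption; generate, eos_token_id, and max_new_tokens are illustrative names, and the streamer and KV-cache handling from the real example are omitted.

def generate(model_ptr, input_ids, vocab_size, eos_token_id,
             max_new_tokens=32, repetition_penalty=1.1):
    output = list(input_ids)
    # Prefill consumes the whole prompt and returns the first sampled token;
    # the repetition penalty is now applied on the native side.
    token = run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty)
    output.append(token)
    for _ in range(max_new_tokens - 1):
        if token == eos_token_id:
            break
        # Only the previous token is passed; compare the removed
        # updated_input_ids parameter in the diff above.
        token = run_decode(model_ptr, token, vocab_size, repetition_penalty)
        output.append(token)
    return output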