[NPU C++] Update example with conversation mode support (#12510)

parent 0918d3baca
commit 12c78978dd

2 changed files with 179 additions and 35 deletions
First changed file — the example's README (run instructions and sample output):

@@ -92,16 +92,22 @@ cd Release
 With built `llm-npu-cli`, you can run the example with specified parameters. For example,
 
 ```cmd
+# Run simple text completion
 llm-npu-cli.exe -m <converted_model_path> -n 64 "AI是什么?"
+
+# Run in conversation mode
+llm-npu-cli.exe -m <converted_model_path> -cnv
 ```
 
 Arguments info:
 - `-m` : argument defining the path of saved converted model.
+- `-cnv` : argument to enable conversation mode.
 - `-n` : argument defining how many tokens will be generated.
 - Last argument is your input prompt.
 
 ### 5. Sample Output
 #### [`Qwen/Qwen2.5-7B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct)
+##### Text Completion
 ```cmd
 Input:
 <|im_start|>system
@@ -112,10 +118,23 @@ AI是什么?<|im_end|>
 
 
 Prefill 22 tokens cost xxxx ms.
-Output:
-AI是"人工智能"的缩写,是英文"Artificial Intelligence"的翻译。它是研究如何使计算机也具有智能的一种技术和理论。简而言之,人工智能就是让计算机能够模仿人智能行为的一项技术。
 
-Decode 46 tokens cost xxxx ms (avg xx.xx ms each token).
+Decode 63 tokens cost xxxx ms (avg xx.xx ms each token).
+Output:
+AI是"人工智能"的缩写,它是一门研究计算机如何能够完成与人类智能相关任务的学科,包括学习、推理、自我修正等能力。简而言之,人工智能就是让计算机模拟或执行人类智能行为的理论、技术和方法。
+
+它涵盖了机器学习、深度学习、自然
+```
+
+##### Conversation
+```cmd
+User:你好
+Assistant:你好!很高兴能为你提供帮助。有什么问题或需要聊天可以找我哦。
+User:AI是什么?
+Assistant: AI代表的是"Artificial Intelligence",中文翻译为人工智能。它是指由计算机或信息技术实现的智能行为。广义的人工智能可以指任何表现出智能行为的计算机或软件系统。狭义的人工智能则指的是模拟、学习、推理、理解自然语言以及自我生成的人工智能系统。
+
+简而言之,人工智能是一种利用计算机和机器来模仿、模拟或扩展人类智能的技术或系统。它包括机器学习、深度学习、自然语言处理等多个子领域。
+User:exit
 ```
 
 ### Troubleshooting
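The conversation mode shown above is implemented in the second changed file below: each turn, the example appends the user message and the previous model reply to one growing history string using the model's chat template, re-tokenizes the whole string, and runs prefill and decode on it again. Below is a minimal, self-contained sketch of that prompt-accumulation pattern for the Qwen2 template used in the sample output; the `format_turn` helper and the hard-coded replies are illustrative only and are not part of the example.

```cpp
#include <cstdio>
#include <string>

// Qwen2 chat template copied from the diff; %s is replaced by the first user message.
const std::string qwen2_template =
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n";

// Illustrative helper (not in the example): append one formatted turn to the history.
static std::string format_turn(const std::string& history, const std::string& text,
                               bool is_user) {
    char buf[8092];
    if (history.empty()) {
        // First user turn: instantiate the full template.
        std::snprintf(buf, sizeof(buf), qwen2_template.c_str(), text.c_str());
    } else if (is_user) {
        // Later user turns: append the message and reopen the assistant block.
        std::snprintf(buf, sizeof(buf), "%s%s<|im_end|>\n<|im_start|>assistant",
                      history.c_str(), text.c_str());
    } else {
        // Model replies: append the reply and reopen the user block.
        std::snprintf(buf, sizeof(buf), "%s%s<|im_end|>\n<|im_start|>user\n",
                      history.c_str(), text.c_str());
    }
    return buf;
}

int main() {
    std::string history;
    history = format_turn(history, "你好", true);                   // user turn 1
    history = format_turn(history, "你好!有什么可以帮你?", false);   // placeholder reply
    history = format_turn(history, "AI是什么?", true);               // user turn 2
    std::printf("%s\n", history.c_str());  // the prompt prefilled on the next turn
    return 0;
}
```

The example itself does the same accumulation through `add_chat_history`, then re-tokenizes the whole string, calls `reset(model)`, and prefills again for the next turn.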
Second changed file — the example's C++ source (the `llm-npu-cli` main program):

@@ -25,11 +25,119 @@
 
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
-    printf("\n    %s -m npu_model_dir [-n n_predict] [prompt]\n", argv[0]);
+    printf("\n    %s -m npu_model_dir [-cnv] [-n n_predict] [prompt]\n", argv[0]);
     printf("\n");
 }
 
 
+const std::string llama2_template = "<s>[INST] <<SYS>>\n\n<</SYS>>\n\n%s [/INST]";
+const std::string llama3_template = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
+const std::string minicpm_template = "<用户>%s<AI>";
+const std::string qwen2_template = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n";
+
+
+std::string add_chat_history(npu_model_params model_params,
+                             std::string new_prompt, std::string chat_history, bool is_input) {
+    char prompt[8092];
+    if (model_params.model_type == std::string("llama") && model_params.vocab_size == 32000) {
+        if (chat_history == ""){
+            sprintf_s(prompt, llama2_template.c_str(), new_prompt.c_str());
+        }else{
+            if (is_input){
+                std::string input_template = "%s%s [/INST]";
+                sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+            else{
+                std::string res_template = "%s%s </s><s>[INST]";
+                sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+        }
+
+    } else if (model_params.model_type == std::string("llama")) {
+        if (chat_history == ""){
+            sprintf_s(prompt, llama3_template.c_str(), new_prompt.c_str());
+        }else{
+            if (is_input){
+                std::string input_template = "%s<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>";
+                sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+            else{
+                std::string res_template = "%s<|start_header_id|>assistant<|end_header_id|>\n\n%s<|eot_id|>";
+                sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+        }
+    } else if (model_params.model_type == std::string("qwen2")) {
+        if (chat_history == ""){
+            sprintf_s(prompt, qwen2_template.c_str(), new_prompt.c_str());
+        }else{
+            if (is_input){
+                std::string input_template = "%s%s<|im_end|>\n<|im_start|>assistant";
+                sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+            else{
+                std::string res_template = "%s%s<|im_end|>\n<|im_start|>user\n";
+                sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+        }
+    } else if (model_params.model_type == std::string("minicpm")) {
+        if (chat_history == ""){
+            sprintf_s(prompt, minicpm_template.c_str(), new_prompt.c_str());
+        }else{
+            if (is_input){
+                std::string input_template = "%s%s<AI>";
+                sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+            else{
+                std::string res_template = "%s%s<用户>";
+                sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+        }
+    } else {
+        sprintf_s(prompt, chat_history.c_str(), new_prompt.c_str());
+    }
+    return prompt;
+}
+
+
+std::string run_generate(void* void_model, int32_t* embd_inp_ptr, int32_t embd_inp_size,
+                         npu_model_params model_params, tokenizer_params tok_params, int32_t max_new_token, bool do_print){
+    auto start = std::chrono::high_resolution_clock::now();
+    float* logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size);
+    int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
+    auto end = std::chrono::high_resolution_clock::now();
+    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+    if (do_print){
+        printf("\nPrefill %d tokens cost %d ms.\n", embd_inp_size, duration.count());
+    }
+
+    std::vector<int32_t> embd;  // output ids
+    embd.push_back(token);
+
+    int token_nums = 0;
+    start = std::chrono::high_resolution_clock::now();
+    for (int i = 1; i < max_new_token; i++){
+        auto logits = run_decode(void_model, embd[i-1]);
+        int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
+        if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
+            embd.push_back(token);
+            token_nums ++;
+        } else {
+            break;
+        }
+    }
+    end = std::chrono::high_resolution_clock::now();
+    duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+
+    std::string output = llm_decode(embd);
+
+    if (do_print){
+        printf("\nDecode %d tokens cost %d ms (avg %f ms each token).\n", token_nums, duration.count(), (float)duration.count() / token_nums);
+    }
+
+    return output;
+}
+
+
 int main(int argc, char ** argv) {
     common_params params;
 
@@ -39,6 +147,7 @@ int main(int argc, char ** argv) {
     std::string prompt = "AI是什么?";
     // number of tokens to predict
     int n_predict = 32;
+    bool cnv_mode = false;
 
     // parse command line arguments
 
@@ -52,7 +161,11 @@ int main(int argc, char ** argv) {
                     print_usage(argc, argv);
                     return 1;
                 }
-            } else if (strcmp(argv[i], "-n") == 0) {
+            } else if (strcmp(argv[i], "-cnv") == 0){
+                // multi-round conversation mode
+                cnv_mode = true;
+                break;
+            }else if (strcmp(argv[i], "-n") == 0) {
                 if (i + 1 < argc) {
                     try {
                         n_predict = std::stoi(argv[++i]);
@@ -86,6 +199,7 @@ int main(int argc, char ** argv) {
     params.model = model_dir;
     params.prompt = prompt;
 
+    // npu_model_params model_params;
     void* model = load_model_from_file(params.model);
     npu_model_params model_params;
     load_config_from_file(model_params, params.model);
@@ -93,43 +207,54 @@ int main(int argc, char ** argv) {
     tokenizer_params tok_params;
     load_tokenizer(tok_params, params.model);
 
-    std::string full_prompt = add_chat_template(model_params, params.prompt);
-    std::cout << "Input: " << std::endl;
-    std::cout << full_prompt << std::endl;
-
-    // tokenize input
-    std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);
-
-    std::vector<int32_t> embd;  // output ids
-    auto start = std::chrono::high_resolution_clock::now();
-    float* logits = run_prefill(model, embd_inp.data(), embd_inp.size());
-    int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
-    auto end = std::chrono::high_resolution_clock::now();
-    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
-    printf("\nPrefill %d tokens cost %d ms.\n", embd_inp.size(), duration.count());
-    embd.push_back(token);
-
-    int token_nums = 0;
-    start = std::chrono::high_resolution_clock::now();
-    for (int i = 1; i < params.n_predict; i++){
-        auto logits = run_decode(model, embd[i-1]);
-        int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
-        if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
-            embd.push_back(token);
-            token_nums ++;
-        } else {
-            break;
-        }
-    }
-    end = std::chrono::high_resolution_clock::now();
-    duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
-
-    std::string output = llm_decode(embd);
-
-    std::cout << "Output: " << std::endl;
-    std::cout << output << std::endl;
-
-    printf("\nDecode %d tokens cost %d ms (avg %f ms each token).\n", token_nums, duration.count(), (float)duration.count() / token_nums);
-
+    if (cnv_mode){
+        std::string prompt;
+        std::string history = "";
+        std::string response;
+        while(true){
+            std::cout << "User:";
+            std::getline(std::cin, prompt);
+            if (prompt == "exit"){
+                break;
+            }
+            else{
+                // process prompt with chat history
+                std::string full_prompt = add_chat_history(model_params, prompt, history, true);
+
+                // tokenize input
+                std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);
+                if (embd_inp.size() > model_params.max_prompt_len){
+                    // empty chat history
+                    full_prompt = add_chat_history(model_params, prompt, "", true);
+                    embd_inp = llm_tokenize(full_prompt, false);
+                }
+
+                response = run_generate(model, embd_inp.data(), embd_inp.size(),
+                                        model_params, tok_params, model_params.kv_len - embd_inp.size(), false);
+
+                std::cout << "Assistant:";
+                std::cout << response << std::endl;
+
+                history = add_chat_history(model_params, response, full_prompt, false);
+
+                reset(model);
+            }
+        }
+    }
+    else{
+        std::string full_prompt = add_chat_template(model_params, params.prompt);
+        std::cout << "Input: " << std::endl;
+        std::cout << full_prompt << std::endl;
+
+        // tokenize input
+        std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);
+
+        // single text generation
+        std::string output = run_generate(model, embd_inp.data(), embd_inp.size(),
+                                          model_params, tok_params, params.n_predict, true);
+
+        std::cout << "Output: " << std::endl;
+        std::cout << output << std::endl;
+    }
     return 0;
 }
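Two details of the new conversation loop are easy to miss when reading the diff: the per-turn generation budget passed to `run_generate` is `model_params.kv_len - embd_inp.size()` rather than `n_predict`, and when the templated history no longer fits in `model_params.max_prompt_len` the loop silently rebuilds the prompt from the latest user message alone. A small sketch of that budgeting logic, with made-up numbers standing in for the real tokenizer output and model limits:

```cpp
#include <cstdio>

int main() {
    // Illustrative numbers only; the real values come from model_params
    // (kv_len, max_prompt_len) and from llm_tokenize() in the example.
    const int kv_len = 1024;         // KV-cache capacity assumed for this sketch
    const int max_prompt_len = 512;  // longest prompt the compiled graph accepts

    int prompt_tokens = 600;         // pretend the templated history tokenized to this

    if (prompt_tokens > max_prompt_len) {
        // Same fallback as the conversation loop: drop the history and
        // rebuild the prompt from the latest user message alone.
        prompt_tokens = 40;          // pretend size of the single-turn prompt
        std::printf("history dropped, prompt is now %d tokens\n", prompt_tokens);
    }

    // Per-turn budget the loop passes to run_generate():
    // whatever KV-cache room is left after the prompt.
    const int max_new_tokens = kv_len - prompt_tokens;
    std::printf("up to %d new tokens can be generated this turn\n", max_new_tokens);
    return 0;
}
```

In the non-conversation path the budget is simply `params.n_predict`, which is why the README's text-completion run is capped at the value given with `-n`.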