[NPU] Update prompt format for baichuan2-pipeline (#12625)
This commit is contained in:

parent 34dbdb8ee3
commit 5f04ed7254
1 changed file with 7 additions and 16 deletions
@@ -25,19 +25,6 @@ from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
-def get_prompt(message: str, chat_history: list[tuple[str, str]],
-               system_prompt: str) -> str:
-    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-    # The first user input is _not_ stripped
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-    message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Predict Tokens using `generate()` API for npu model"
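For reference, the deleted get_prompt helper hard-coded Llama-2-style [INST]/<<SYS>> markup instead of Baichuan2's own chat format, which is what this commit corrects. A minimal sketch of what a single-turn call to the removed helper produced (the system prompt and message values are placeholders, not taken from the example script):

# Sketch only: single-turn output of the removed get_prompt helper
# (system prompt and message are placeholder values).
system_prompt = "You are a helpful assistant."
message = "What is AI?"
prompt = f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'
print(prompt)  # Llama-2-style markup, replaced below by the tokenizer's chat template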
@@ -108,11 +95,15 @@ if __name__ == "__main__":
     with torch.inference_mode():
         print("finish to load")
         for i in range(3):
-            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
-            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
+            messages = [{"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": args.prompt}]
+            text = tokenizer.apply_chat_template(messages,
+                                                 tokenize=False,
+                                                 add_generation_prompt=True)
+            _input_ids = tokenizer([text], return_tensors="pt").input_ids
             print("-" * 20, "Input", "-" * 20)
             print("input length:", len(_input_ids[0]))
-            print(prompt)
+            print(args.prompt)
             print("-" * 20, "Output", "-" * 20)
             st = time.time()
             output = model.generate(
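After this change, prompt construction is delegated to the tokenizer's chat template. Below is a minimal, self-contained sketch of the new flow; the model id, trust_remote_code=True, and the prompt text are assumptions for illustration, since the real script reads them from command-line arguments:

# Sketch of the updated prompt flow, assuming a Baichuan2 chat checkpoint whose
# tokenizer provides a chat template (model id and prompt are placeholders).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat",
                                          trust_remote_code=True)

messages = [{"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is AI?"}]

# Render the conversation with the template bundled with the tokenizer;
# add_generation_prompt=True appends the marker that starts the assistant turn.
text = tokenizer.apply_chat_template(messages,
                                     tokenize=False,
                                     add_generation_prompt=True)

input_ids = tokenizer([text], return_tensors="pt").input_ids
print("input length:", len(input_ids[0]))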