NPU] Update prompt format for baichuan2-pipeline (#12625)
This commit is contained in:
parent
34dbdb8ee3
commit
5f04ed7254
1 changed files with 7 additions and 16 deletions
|
|
@ -25,19 +25,6 @@ from transformers.utils import logging
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
def get_prompt(message: str, chat_history: list[tuple[str, str]],
|
|
||||||
system_prompt: str) -> str:
|
|
||||||
texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
|
|
||||||
# The first user input is _not_ stripped
|
|
||||||
do_strip = False
|
|
||||||
for user_input, response in chat_history:
|
|
||||||
user_input = user_input.strip() if do_strip else user_input
|
|
||||||
do_strip = True
|
|
||||||
texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
|
|
||||||
message = message.strip() if do_strip else message
|
|
||||||
texts.append(f'{message} [/INST]')
|
|
||||||
return ''.join(texts)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Predict Tokens using `generate()` API for npu model"
|
description="Predict Tokens using `generate()` API for npu model"
|
||||||
|
|
@ -108,11 +95,15 @@ if __name__ == "__main__":
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
print("finish to load")
|
print("finish to load")
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
|
messages = [{"role": "system", "content": "You are a helpful assistant."},
|
||||||
_input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
{"role": "user", "content": args.prompt}]
|
||||||
|
text = tokenizer.apply_chat_template(messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True)
|
||||||
|
_input_ids = tokenizer([text], return_tensors="pt").input_ids
|
||||||
print("-" * 20, "Input", "-" * 20)
|
print("-" * 20, "Input", "-" * 20)
|
||||||
print("input length:", len(_input_ids[0]))
|
print("input length:", len(_input_ids[0]))
|
||||||
print(prompt)
|
print(args.prompt)
|
||||||
print("-" * 20, "Output", "-" * 20)
|
print("-" * 20, "Output", "-" * 20)
|
||||||
st = time.time()
|
st = time.time()
|
||||||
output = model.generate(
|
output = model.generate(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue