[NPU L0] Update streaming mode of example (#12312)

This commit is contained in:
binbin Deng 2024-11-01 15:38:10 +08:00 committed by GitHub
parent 126f95be80
commit d409d9d0eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 83 additions and 62 deletions

View file

@ -72,28 +72,21 @@ Arguments info:
- `--max-context-len MAX_CONTEXT_LEN`: Defines the maximum sequence length for both input and output tokens. Defaults to `1024`. - `--max-context-len MAX_CONTEXT_LEN`: Defines the maximum sequence length for both input and output tokens. Defaults to `1024`.
- `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. Defaults to `512`. - `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. Defaults to `512`.
- `--disable-transpose-value-cache`: Disable the optimization of transposing value cache. - `--disable-transpose-value-cache`: Disable the optimization of transposing value cache.
- `--disable-streaming`: Disable streaming mode of generation.
### Sample Output ### Sample Output of Streaming Mode
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) #### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
```log ```log
Number of input tokens: 28
Generated tokens: 32
First token generation time: xxxx s
Generation average latency: xxxx ms, (xxxx token/s)
Generation time: xxxx s
Inference time: xxxx s
-------------------- Input -------------------- -------------------- Input --------------------
<s><s> [INST] <<SYS>> input length: 28
<s>[INST] <<SYS>>
<</SYS>> <</SYS>>
What is AI? [/INST] What is AI? [/INST]
-------------------- Output -------------------- -------------------- Output --------------------
<s><s> [INST] <<SYS>> AI (Artificial Intelligence) is a field of computer science and technology that focuses on the development of intelligent machines that can perform
<</SYS>> Inference time: xxxx s
What is AI? [/INST] AI (Artificial Intelligence) is a field of computer science and technology that focuses on the development of intelligent machines that can perform
``` ```

View file

@ -20,7 +20,7 @@ import torch
import time import time
import argparse import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -61,6 +61,7 @@ if __name__ == "__main__":
parser.add_argument("--max-context-len", type=int, default=1024) parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512) parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
model_path = args.repo_id_or_model_path model_path = args.repo_id_or_model_path
@ -92,6 +93,11 @@ if __name__ == "__main__":
if args.lowbit_path and not os.path.exists(args.lowbit_path): if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path) model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
DEFAULT_SYSTEM_PROMPT = """\ DEFAULT_SYSTEM_PROMPT = """\
""" """
@ -99,22 +105,22 @@ if __name__ == "__main__":
print("done") print("done")
with torch.inference_mode(): with torch.inference_mode():
print("finish to load") print("finish to load")
for i in range(5): for i in range(3):
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
_input_ids = tokenizer.encode(prompt, return_tensors="pt") _input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("-" * 20, "Input", "-" * 20)
print("input length:", len(_input_ids[0])) print("input length:", len(_input_ids[0]))
print(prompt)
print("-" * 20, "Output", "-" * 20)
st = time.time() st = time.time()
output = model.generate( output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True _input_ids, max_new_tokens=args.n_predict, streamer=streamer
) )
end = time.time() end = time.time()
print(f"Inference time: {end-st} s") if args.disable_streaming:
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False) output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str) print(output_str)
print(f"Inference time: {end-st} s")
print("-" * 80) print("-" * 80)
print("done") print("done")

View file

@ -20,7 +20,7 @@ import torch
import time import time
import argparse import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -62,6 +62,7 @@ if __name__ == "__main__":
parser.add_argument("--max-prompt-len", type=int, default=512) parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--quantization_group_size", type=int, default=0) parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
model_path = args.repo_id_or_model_path model_path = args.repo_id_or_model_path
@ -92,6 +93,11 @@ if __name__ == "__main__":
if args.lowbit_path and not os.path.exists(args.lowbit_path): if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path) model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
DEFAULT_SYSTEM_PROMPT = """\ DEFAULT_SYSTEM_PROMPT = """\
""" """
@ -99,22 +105,22 @@ if __name__ == "__main__":
print("done") print("done")
with torch.inference_mode(): with torch.inference_mode():
print("finish to load") print("finish to load")
for i in range(5): for i in range(3):
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
_input_ids = tokenizer.encode(prompt, return_tensors="pt") _input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("-" * 20, "Input", "-" * 20)
print("input length:", len(_input_ids[0])) print("input length:", len(_input_ids[0]))
print(prompt)
print("-" * 20, "Output", "-" * 20)
st = time.time() st = time.time()
output = model.generate( output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True _input_ids, max_new_tokens=args.n_predict, streamer=streamer
) )
end = time.time() end = time.time()
print(f"Inference time: {end-st} s") if args.disable_streaming:
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False) output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str) print(output_str)
print(f"Inference time: {end-st} s")
print("-" * 80) print("-" * 80)
print("done") print("done")

View file

@ -20,7 +20,7 @@ import torch
import time import time
import argparse import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -68,6 +68,7 @@ if __name__ == "__main__":
parser.add_argument("--max-prompt-len", type=int, default=512) parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--quantization_group_size", type=int, default=0) parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
model_path = args.repo_id_or_model_path model_path = args.repo_id_or_model_path
@ -98,26 +99,31 @@ if __name__ == "__main__":
if args.lowbit_path and not os.path.exists(args.lowbit_path): if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path) model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
print("-" * 80) print("-" * 80)
print("done") print("done")
with torch.inference_mode(): with torch.inference_mode():
print("finish to load") print("finish to load")
for i in range(5): for i in range(3):
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
_input_ids = tokenizer.encode(prompt, return_tensors="pt") _input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("-" * 20, "Input", "-" * 20)
print("input length:", len(_input_ids[0])) print("input length:", len(_input_ids[0]))
print(prompt)
print("-" * 20, "Output", "-" * 20)
st = time.time() st = time.time()
output = model.generate( output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True _input_ids, max_new_tokens=args.n_predict, streamer=streamer
) )
end = time.time() end = time.time()
print(f"Inference time: {end-st} s") if args.disable_streaming:
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False) output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str) print(output_str)
print(f"Inference time: {end-st} s")
print("-" * 80) print("-" * 80)
print("done") print("done")

View file

@ -19,7 +19,7 @@ import torch
import time import time
import argparse import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging from transformers.utils import logging
import os import os
@ -48,6 +48,7 @@ if __name__ == "__main__":
parser.add_argument("--max-context-len", type=int, default=1024) parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512) parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
model_path = args.repo_id_or_model_path model_path = args.repo_id_or_model_path
@ -79,26 +80,31 @@ if __name__ == "__main__":
if args.lowbit_path and not os.path.exists(args.lowbit_path): if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path) model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
print("-" * 80) print("-" * 80)
print("done") print("done")
with torch.inference_mode(): with torch.inference_mode():
print("finish to load") print("finish to load")
for i in range(5): for i in range(3):
prompt = "<用户>{}<AI>".format(args.prompt) prompt = "<用户>{}<AI>".format(args.prompt)
_input_ids = tokenizer.encode(prompt, return_tensors="pt") _input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("-" * 20, "Input", "-" * 20)
print("input length:", len(_input_ids[0])) print("input length:", len(_input_ids[0]))
print(prompt)
print("-" * 20, "Output", "-" * 20)
st = time.time() st = time.time()
output = model.generate( output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True _input_ids, max_new_tokens=args.n_predict, streamer=streamer
) )
end = time.time() end = time.time()
print(f"Inference time: {end-st} s") if args.disable_streaming:
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False) output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str) print(output_str)
print(f"Inference time: {end-st} s")
print("-" * 80) print("-" * 80)
print("done") print("done")

View file

@ -20,7 +20,7 @@ import torch
import time import time
import argparse import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -50,6 +50,7 @@ if __name__ == "__main__":
parser.add_argument('--load_in_low_bit', type=str, default="sym_int4", parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
help='Load in low bit to use') help='Load in low bit to use')
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
model_path = args.repo_id_or_model_path model_path = args.repo_id_or_model_path
@ -81,6 +82,11 @@ if __name__ == "__main__":
if args.lowbit_path and not os.path.exists(args.lowbit_path): if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path) model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
print("-" * 80) print("-" * 80)
print("done") print("done")
messages = [{"role": "system", "content": "You are a helpful assistant."}, messages = [{"role": "system", "content": "You are a helpful assistant."},
@ -90,21 +96,21 @@ if __name__ == "__main__":
add_generation_prompt=True) add_generation_prompt=True)
with torch.inference_mode(): with torch.inference_mode():
print("finish to load") print("finish to load")
for i in range(5): for i in range(3):
_input_ids = tokenizer([text], return_tensors="pt").input_ids _input_ids = tokenizer([text], return_tensors="pt").input_ids
print("-" * 20, "Input", "-" * 20)
print("input length:", len(_input_ids[0])) print("input length:", len(_input_ids[0]))
print(text)
print("-" * 20, "Output", "-" * 20)
st = time.time() st = time.time()
output = model.generate( output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True _input_ids, max_new_tokens=args.n_predict, streamer=streamer
) )
end = time.time() end = time.time()
print(f"Inference time: {end-st} s") if args.disable_streaming:
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False) output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str) print(output_str)
print(f"Inference time: {end-st} s")
print("-" * 80) print("-" * 80)
print("done") print("done")

View file

@ -134,7 +134,6 @@ def generate(
try: try:
input_pipe = open(in_pipe_path, "wb") input_pipe = open(in_pipe_path, "wb")
except: except:
print('Waiting for input pipe')
time.sleep(1) time.sleep(1)
else: else:
break break
@ -143,7 +142,6 @@ def generate(
try: try:
output_pipe = open(out_pipe_path, "rb") output_pipe = open(out_pipe_path, "rb")
except: except:
print('Waiting for output pipe')
time.sleep(1) time.sleep(1)
else: else:
break break
@ -152,7 +150,7 @@ def generate(
bdata = str.encode(str(temp_dir)) bdata = str.encode(str(temp_dir))
invalidInputError(len(bdata) <= 2000, invalidInputError(len(bdata) <= 2000,
f"Leng of input directory is too long ({len(bdata)}), " f"Length of input directory is too long ({len(bdata)}), "
"which may cause read error.") "which may cause read error.")
input_pipe.write(bdata) input_pipe.write(bdata)
input_pipe.flush() input_pipe.flush()