[NPU L0] Update streaming mode of example (#12312)
This commit is contained in:
		
							parent
							
								
									126f95be80
								
							
						
					
					
						commit
						d409d9d0eb
					
				
					 7 changed files with 83 additions and 62 deletions
				
			
		| 
						 | 
					@ -72,28 +72,21 @@ Arguments info:
 | 
				
			||||||
- `--max-context-len MAX_CONTEXT_LEN`: Defines the maximum sequence length for both input and output tokens. It is default to be `1024`.
 | 
					- `--max-context-len MAX_CONTEXT_LEN`: Defines the maximum sequence length for both input and output tokens. It is default to be `1024`.
 | 
				
			||||||
- `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. It is default to be `512`.
 | 
					- `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. It is default to be `512`.
 | 
				
			||||||
- `--disable-transpose-value-cache`: Disable the optimization of transposing value cache.
 | 
					- `--disable-transpose-value-cache`: Disable the optimization of transposing value cache.
 | 
				
			||||||
 | 
					- `--disable-streaming`: Disable streaming mode of generation.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Sample Output
 | 
					### Sample Output of Streaming Mode
 | 
				
			||||||
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
 | 
					#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
```log
 | 
					```log
 | 
				
			||||||
 Number of input tokens: 28
 | 
					 | 
				
			||||||
 Generated tokens: 32
 | 
					 | 
				
			||||||
 First token generation time: xxxx s
 | 
					 | 
				
			||||||
 Generation average latency: xxxx ms, (xxxx token/s)
 | 
					 | 
				
			||||||
 Generation time: xxxx s
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Inference time: xxxx s
 | 
					 | 
				
			||||||
-------------------- Input --------------------
 | 
					-------------------- Input --------------------
 | 
				
			||||||
<s><s> [INST] <<SYS>>
 | 
					input length: 28
 | 
				
			||||||
 | 
					<s>[INST] <<SYS>>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
<</SYS>>
 | 
					<</SYS>>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
What is AI? [/INST]
 | 
					What is AI? [/INST]
 | 
				
			||||||
-------------------- Output --------------------
 | 
					-------------------- Output --------------------
 | 
				
			||||||
<s><s> [INST] <<SYS>>
 | 
					 AI (Artificial Intelligence) is a field of computer science and technology that focuses on the development of intelligent machines that can perform
 | 
				
			||||||
 | 
					
 | 
				
			||||||
<</SYS>>
 | 
					Inference time: xxxx s
 | 
				
			||||||
 | 
					 | 
				
			||||||
What is AI? [/INST]  AI (Artificial Intelligence) is a field of computer science and technology that focuses on the development of intelligent machines that can perform
 | 
					 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,7 +20,7 @@ import torch
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
					from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
				
			||||||
from transformers import AutoTokenizer
 | 
					from transformers import AutoTokenizer, TextStreamer
 | 
				
			||||||
from transformers.utils import logging
 | 
					from transformers.utils import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.get_logger(__name__)
 | 
					logger = logging.get_logger(__name__)
 | 
				
			||||||
| 
						 | 
					@ -61,6 +61,7 @@ if __name__ == "__main__":
 | 
				
			||||||
    parser.add_argument("--max-context-len", type=int, default=1024)
 | 
					    parser.add_argument("--max-context-len", type=int, default=1024)
 | 
				
			||||||
    parser.add_argument("--max-prompt-len", type=int, default=512)
 | 
					    parser.add_argument("--max-prompt-len", type=int, default=512)
 | 
				
			||||||
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
					    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
				
			||||||
 | 
					    parser.add_argument("--disable-streaming", action="store_true", default=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
    model_path = args.repo_id_or_model_path
 | 
					    model_path = args.repo_id_or_model_path
 | 
				
			||||||
| 
						 | 
					@ -92,6 +93,11 @@ if __name__ == "__main__":
 | 
				
			||||||
    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
					    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
				
			||||||
        model.save_low_bit(args.lowbit_path)
 | 
					        model.save_low_bit(args.lowbit_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if args.disable_streaming:
 | 
				
			||||||
 | 
					        streamer = None
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DEFAULT_SYSTEM_PROMPT = """\
 | 
					    DEFAULT_SYSTEM_PROMPT = """\
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -99,22 +105,22 @@ if __name__ == "__main__":
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
    with torch.inference_mode():
 | 
					    with torch.inference_mode():
 | 
				
			||||||
        print("finish to load")
 | 
					        print("finish to load")
 | 
				
			||||||
        for i in range(5):
 | 
					        for i in range(3):
 | 
				
			||||||
            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
					            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
				
			||||||
            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
					            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
				
			||||||
 | 
					            print("-" * 20, "Input", "-" * 20)
 | 
				
			||||||
            print("input length:", len(_input_ids[0]))
 | 
					            print("input length:", len(_input_ids[0]))
 | 
				
			||||||
 | 
					            print(prompt)
 | 
				
			||||||
 | 
					            print("-" * 20, "Output", "-" * 20)
 | 
				
			||||||
            st = time.time()
 | 
					            st = time.time()
 | 
				
			||||||
            output = model.generate(
 | 
					            output = model.generate(
 | 
				
			||||||
                _input_ids, max_new_tokens=args.n_predict, do_print=True
 | 
					                _input_ids, max_new_tokens=args.n_predict, streamer=streamer
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            end = time.time()
 | 
					            end = time.time()
 | 
				
			||||||
 | 
					            if args.disable_streaming:
 | 
				
			||||||
 | 
					                output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
				
			||||||
 | 
					                print(output_str)
 | 
				
			||||||
            print(f"Inference time: {end-st} s")
 | 
					            print(f"Inference time: {end-st} s")
 | 
				
			||||||
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Input", "-" * 20)
 | 
					 | 
				
			||||||
            print(input_str)
 | 
					 | 
				
			||||||
            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Output", "-" * 20)
 | 
					 | 
				
			||||||
            print(output_str)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("-" * 80)
 | 
					    print("-" * 80)
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,7 +20,7 @@ import torch
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
					from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
				
			||||||
from transformers import AutoTokenizer
 | 
					from transformers import AutoTokenizer, TextStreamer
 | 
				
			||||||
from transformers.utils import logging
 | 
					from transformers.utils import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.get_logger(__name__)
 | 
					logger = logging.get_logger(__name__)
 | 
				
			||||||
| 
						 | 
					@ -62,6 +62,7 @@ if __name__ == "__main__":
 | 
				
			||||||
    parser.add_argument("--max-prompt-len", type=int, default=512)
 | 
					    parser.add_argument("--max-prompt-len", type=int, default=512)
 | 
				
			||||||
    parser.add_argument("--quantization_group_size", type=int, default=0)
 | 
					    parser.add_argument("--quantization_group_size", type=int, default=0)
 | 
				
			||||||
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
					    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
				
			||||||
 | 
					    parser.add_argument("--disable-streaming", action="store_true", default=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
    model_path = args.repo_id_or_model_path
 | 
					    model_path = args.repo_id_or_model_path
 | 
				
			||||||
| 
						 | 
					@ -91,6 +92,11 @@ if __name__ == "__main__":
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
					    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
				
			||||||
        model.save_low_bit(args.lowbit_path)
 | 
					        model.save_low_bit(args.lowbit_path)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    if args.disable_streaming:
 | 
				
			||||||
 | 
					        streamer = None
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DEFAULT_SYSTEM_PROMPT = """\
 | 
					    DEFAULT_SYSTEM_PROMPT = """\
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
| 
						 | 
					@ -99,22 +105,22 @@ if __name__ == "__main__":
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
    with torch.inference_mode():
 | 
					    with torch.inference_mode():
 | 
				
			||||||
        print("finish to load")
 | 
					        print("finish to load")
 | 
				
			||||||
        for i in range(5):
 | 
					        for i in range(3):
 | 
				
			||||||
            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
					            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
				
			||||||
            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
					            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
				
			||||||
 | 
					            print("-" * 20, "Input", "-" * 20)
 | 
				
			||||||
            print("input length:", len(_input_ids[0]))
 | 
					            print("input length:", len(_input_ids[0]))
 | 
				
			||||||
 | 
					            print(prompt)
 | 
				
			||||||
 | 
					            print("-" * 20, "Output", "-" * 20)
 | 
				
			||||||
            st = time.time()
 | 
					            st = time.time()
 | 
				
			||||||
            output = model.generate(
 | 
					            output = model.generate(
 | 
				
			||||||
                _input_ids, max_new_tokens=args.n_predict, do_print=True
 | 
					                _input_ids, max_new_tokens=args.n_predict, streamer=streamer
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            end = time.time()
 | 
					            end = time.time()
 | 
				
			||||||
 | 
					            if args.disable_streaming:
 | 
				
			||||||
 | 
					                output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
				
			||||||
 | 
					                print(output_str)
 | 
				
			||||||
            print(f"Inference time: {end-st} s")
 | 
					            print(f"Inference time: {end-st} s")
 | 
				
			||||||
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Input", "-" * 20)
 | 
					 | 
				
			||||||
            print(input_str)
 | 
					 | 
				
			||||||
            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Output", "-" * 20)
 | 
					 | 
				
			||||||
            print(output_str)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("-" * 80)
 | 
					    print("-" * 80)
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,7 +20,7 @@ import torch
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
					from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
				
			||||||
from transformers import AutoTokenizer
 | 
					from transformers import AutoTokenizer, TextStreamer
 | 
				
			||||||
from transformers.utils import logging
 | 
					from transformers.utils import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.get_logger(__name__)
 | 
					logger = logging.get_logger(__name__)
 | 
				
			||||||
| 
						 | 
					@ -68,6 +68,7 @@ if __name__ == "__main__":
 | 
				
			||||||
    parser.add_argument("--max-prompt-len", type=int, default=512)
 | 
					    parser.add_argument("--max-prompt-len", type=int, default=512)
 | 
				
			||||||
    parser.add_argument("--quantization_group_size", type=int, default=0)
 | 
					    parser.add_argument("--quantization_group_size", type=int, default=0)
 | 
				
			||||||
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
					    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
				
			||||||
 | 
					    parser.add_argument("--disable-streaming", action="store_true", default=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
    model_path = args.repo_id_or_model_path
 | 
					    model_path = args.repo_id_or_model_path
 | 
				
			||||||
| 
						 | 
					@ -98,26 +99,31 @@ if __name__ == "__main__":
 | 
				
			||||||
    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
					    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
				
			||||||
        model.save_low_bit(args.lowbit_path)
 | 
					        model.save_low_bit(args.lowbit_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if args.disable_streaming:
 | 
				
			||||||
 | 
					        streamer = None
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("-" * 80)
 | 
					    print("-" * 80)
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
    with torch.inference_mode():
 | 
					    with torch.inference_mode():
 | 
				
			||||||
        print("finish to load")
 | 
					        print("finish to load")
 | 
				
			||||||
        for i in range(5):
 | 
					        for i in range(3):
 | 
				
			||||||
            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
					            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
				
			||||||
            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
					            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
				
			||||||
 | 
					            print("-" * 20, "Input", "-" * 20)
 | 
				
			||||||
            print("input length:", len(_input_ids[0]))
 | 
					            print("input length:", len(_input_ids[0]))
 | 
				
			||||||
 | 
					            print(prompt)
 | 
				
			||||||
 | 
					            print("-" * 20, "Output", "-" * 20)
 | 
				
			||||||
            st = time.time()
 | 
					            st = time.time()
 | 
				
			||||||
            output = model.generate(
 | 
					            output = model.generate(
 | 
				
			||||||
                _input_ids, max_new_tokens=args.n_predict, do_print=True
 | 
					                _input_ids, max_new_tokens=args.n_predict, streamer=streamer
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            end = time.time()
 | 
					            end = time.time()
 | 
				
			||||||
 | 
					            if args.disable_streaming:
 | 
				
			||||||
 | 
					                output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
				
			||||||
 | 
					                print(output_str)
 | 
				
			||||||
            print(f"Inference time: {end-st} s")
 | 
					            print(f"Inference time: {end-st} s")
 | 
				
			||||||
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Input", "-" * 20)
 | 
					 | 
				
			||||||
            print(input_str)
 | 
					 | 
				
			||||||
            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Output", "-" * 20)
 | 
					 | 
				
			||||||
            print(output_str)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("-" * 80)
 | 
					    print("-" * 80)
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,7 +19,7 @@ import torch
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
					from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
				
			||||||
from transformers import AutoTokenizer
 | 
					from transformers import AutoTokenizer, TextStreamer
 | 
				
			||||||
from transformers.utils import logging
 | 
					from transformers.utils import logging
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -48,6 +48,7 @@ if __name__ == "__main__":
 | 
				
			||||||
    parser.add_argument("--max-context-len", type=int, default=1024)
 | 
					    parser.add_argument("--max-context-len", type=int, default=1024)
 | 
				
			||||||
    parser.add_argument("--max-prompt-len", type=int, default=512)
 | 
					    parser.add_argument("--max-prompt-len", type=int, default=512)
 | 
				
			||||||
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
					    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
				
			||||||
 | 
					    parser.add_argument("--disable-streaming", action="store_true", default=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
    model_path = args.repo_id_or_model_path
 | 
					    model_path = args.repo_id_or_model_path
 | 
				
			||||||
| 
						 | 
					@ -79,26 +80,31 @@ if __name__ == "__main__":
 | 
				
			||||||
    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
					    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
				
			||||||
        model.save_low_bit(args.lowbit_path)
 | 
					        model.save_low_bit(args.lowbit_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if args.disable_streaming:
 | 
				
			||||||
 | 
					        streamer = None
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("-" * 80)
 | 
					    print("-" * 80)
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
    with torch.inference_mode():
 | 
					    with torch.inference_mode():
 | 
				
			||||||
        print("finish to load")
 | 
					        print("finish to load")
 | 
				
			||||||
        for i in range(5):
 | 
					        for i in range(3):
 | 
				
			||||||
            prompt = "<用户>{}<AI>".format(args.prompt)
 | 
					            prompt = "<用户>{}<AI>".format(args.prompt)
 | 
				
			||||||
            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
					            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
				
			||||||
 | 
					            print("-" * 20, "Input", "-" * 20)
 | 
				
			||||||
            print("input length:", len(_input_ids[0]))
 | 
					            print("input length:", len(_input_ids[0]))
 | 
				
			||||||
 | 
					            print(prompt)
 | 
				
			||||||
 | 
					            print("-" * 20, "Output", "-" * 20)
 | 
				
			||||||
            st = time.time()
 | 
					            st = time.time()
 | 
				
			||||||
            output = model.generate(
 | 
					            output = model.generate(
 | 
				
			||||||
                _input_ids, max_new_tokens=args.n_predict, do_print=True
 | 
					                _input_ids, max_new_tokens=args.n_predict, streamer=streamer
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            end = time.time()
 | 
					            end = time.time()
 | 
				
			||||||
 | 
					            if args.disable_streaming:
 | 
				
			||||||
 | 
					                output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
				
			||||||
 | 
					                print(output_str)
 | 
				
			||||||
            print(f"Inference time: {end-st} s")
 | 
					            print(f"Inference time: {end-st} s")
 | 
				
			||||||
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Input", "-" * 20)
 | 
					 | 
				
			||||||
            print(input_str)
 | 
					 | 
				
			||||||
            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Output", "-" * 20)
 | 
					 | 
				
			||||||
            print(output_str)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("-" * 80)
 | 
					    print("-" * 80)
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,7 +20,7 @@ import torch
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
					from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 | 
				
			||||||
from transformers import AutoTokenizer
 | 
					from transformers import AutoTokenizer, TextStreamer
 | 
				
			||||||
from transformers.utils import logging
 | 
					from transformers.utils import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.get_logger(__name__)
 | 
					logger = logging.get_logger(__name__)
 | 
				
			||||||
| 
						 | 
					@ -50,6 +50,7 @@ if __name__ == "__main__":
 | 
				
			||||||
    parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
 | 
					    parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
 | 
				
			||||||
                        help='Load in low bit to use')
 | 
					                        help='Load in low bit to use')
 | 
				
			||||||
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
					    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
				
			||||||
 | 
					    parser.add_argument("--disable-streaming", action="store_true", default=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
    model_path = args.repo_id_or_model_path
 | 
					    model_path = args.repo_id_or_model_path
 | 
				
			||||||
| 
						 | 
					@ -81,6 +82,11 @@ if __name__ == "__main__":
 | 
				
			||||||
    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
					    if args.lowbit_path and not os.path.exists(args.lowbit_path):
 | 
				
			||||||
        model.save_low_bit(args.lowbit_path)
 | 
					        model.save_low_bit(args.lowbit_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if args.disable_streaming:
 | 
				
			||||||
 | 
					        streamer = None
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("-" * 80)
 | 
					    print("-" * 80)
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
    messages = [{"role": "system", "content": "You are a helpful assistant."},
 | 
					    messages = [{"role": "system", "content": "You are a helpful assistant."},
 | 
				
			||||||
| 
						 | 
					@ -90,21 +96,21 @@ if __name__ == "__main__":
 | 
				
			||||||
                                         add_generation_prompt=True)
 | 
					                                         add_generation_prompt=True)
 | 
				
			||||||
    with torch.inference_mode():
 | 
					    with torch.inference_mode():
 | 
				
			||||||
        print("finish to load")
 | 
					        print("finish to load")
 | 
				
			||||||
        for i in range(5):
 | 
					        for i in range(3):
 | 
				
			||||||
            _input_ids = tokenizer([text], return_tensors="pt").input_ids
 | 
					            _input_ids = tokenizer([text], return_tensors="pt").input_ids
 | 
				
			||||||
 | 
					            print("-" * 20, "Input", "-" * 20)
 | 
				
			||||||
            print("input length:", len(_input_ids[0]))
 | 
					            print("input length:", len(_input_ids[0]))
 | 
				
			||||||
 | 
					            print(text)
 | 
				
			||||||
 | 
					            print("-" * 20, "Output", "-" * 20)
 | 
				
			||||||
            st = time.time()
 | 
					            st = time.time()
 | 
				
			||||||
            output = model.generate(
 | 
					            output = model.generate(
 | 
				
			||||||
                _input_ids, max_new_tokens=args.n_predict, do_print=True
 | 
					                _input_ids, max_new_tokens=args.n_predict, streamer=streamer
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            end = time.time()
 | 
					            end = time.time()
 | 
				
			||||||
 | 
					            if args.disable_streaming:
 | 
				
			||||||
 | 
					                output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
				
			||||||
 | 
					                print(output_str)
 | 
				
			||||||
            print(f"Inference time: {end-st} s")
 | 
					            print(f"Inference time: {end-st} s")
 | 
				
			||||||
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Input", "-" * 20)
 | 
					 | 
				
			||||||
            print(input_str)
 | 
					 | 
				
			||||||
            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
 | 
					 | 
				
			||||||
            print("-" * 20, "Output", "-" * 20)
 | 
					 | 
				
			||||||
            print(output_str)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("-" * 80)
 | 
					    print("-" * 80)
 | 
				
			||||||
    print("done")
 | 
					    print("done")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -134,7 +134,6 @@ def generate(
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                input_pipe = open(in_pipe_path, "wb")
 | 
					                input_pipe = open(in_pipe_path, "wb")
 | 
				
			||||||
            except:
 | 
					            except:
 | 
				
			||||||
                print('Waiting for input pipe')
 | 
					 | 
				
			||||||
                time.sleep(1)
 | 
					                time.sleep(1)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
| 
						 | 
					@ -143,7 +142,6 @@ def generate(
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                output_pipe = open(out_pipe_path, "rb")
 | 
					                output_pipe = open(out_pipe_path, "rb")
 | 
				
			||||||
            except:
 | 
					            except:
 | 
				
			||||||
                print('Waiting for output pipe')
 | 
					 | 
				
			||||||
                time.sleep(1)
 | 
					                time.sleep(1)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
| 
						 | 
					@ -152,7 +150,7 @@ def generate(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        bdata = str.encode(str(temp_dir))
 | 
					        bdata = str.encode(str(temp_dir))
 | 
				
			||||||
        invalidInputError(len(bdata) <= 2000,
 | 
					        invalidInputError(len(bdata) <= 2000,
 | 
				
			||||||
                          f"Leng of input directory is too long ({len(bdata)}), "
 | 
					                          f"Length of input directory is too long ({len(bdata)}), "
 | 
				
			||||||
                          "which may cause read error.")
 | 
					                          "which may cause read error.")
 | 
				
			||||||
        input_pipe.write(bdata)
 | 
					        input_pipe.write(bdata)
 | 
				
			||||||
        input_pipe.flush()
 | 
					        input_pipe.flush()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue