diff --git a/python/llm/dev/benchmark/perplexity/run_wikitext.py b/python/llm/dev/benchmark/perplexity/run_wikitext.py
index 58e5f587..50991558 100644
--- a/python/llm/dev/benchmark/perplexity/run_wikitext.py
+++ b/python/llm/dev/benchmark/perplexity/run_wikitext.py
@@ -20,7 +20,7 @@
 import argparse
 import torch
 from tqdm import tqdm
-from datasets import concatenate_datasets, load_dataset
+from datasets import load_dataset
 
 from ipex_llm.utils.common import invalidInputError
 
@@ -34,6 +34,7 @@ parser.add_argument("--device", type=str, default="xpu")
 parser.add_argument("--precision", type=str, default="sym_int4")
 parser.add_argument("--use-cache", action="store_true")
 parser.add_argument("--max_length", type=int, default=None)
+parser.add_argument("--mixed_precision", action="store_true")
 args = parser.parse_args()
 
 if args.precision == "fp16": # ipex fp16
@@ -43,7 +44,7 @@ if args.precision == "fp16": # ipex fp16
 else: # ipex-llm
     from ipex_llm.transformers import AutoModelForCausalLM
     model = AutoModelForCausalLM.from_pretrained(args.model_path, load_in_low_bit=args.precision,
-                                                 use_cache=args.use_cache, trust_remote_code=True)
+                                                 use_cache=args.use_cache, trust_remote_code=True, mixed_precision=args.mixed_precision)
     model = model.half()
     model = model.to(args.device)
     model = model.eval()