diff --git a/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py b/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
index dd859f42..c491be4b 100644
--- a/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
@@ -71,6 +71,7 @@ class BigDLLLMWorker(BaseModelWorker):
         speculative: bool = False,
         load_low_bit_model: bool = False,
         stream_interval: int = 4,
+        benchmark: str = "true",
     ):
         super().__init__(
             controller_addr,
@@ -103,6 +104,10 @@ class BigDLLLMWorker(BaseModelWorker):
             speculative,
             load_low_bit_model,
         )
+        if benchmark.lower() == "true":
+            from ipex_llm.utils.benchmark_util import BenchmarkWrapper
+            self.model = BenchmarkWrapper(self.model, do_print=True)
+            logger.info(f"enable benchmark successfully")
         self.stream_interval = stream_interval
         self.context_len = get_context_length(self.model.config)
         self.embed_in_truncate = embed_in_truncate
@@ -495,6 +500,9 @@ if __name__ == "__main__":
         default=False,
         help="To use self-speculative or not",
     )
+    parser.add_argument(
+        "--benchmark", type=str, default="true", help="To print model generation latency or not"
+    )
     parser.add_argument(
         "--trust-remote-code",
         action="store_true",
@@ -527,5 +535,6 @@ if __name__ == "__main__":
         args.speculative,
         args.load_low_bit_model,
         args.stream_interval,
+        args.benchmark,
     )
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
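
For context, a minimal usage sketch of the new flag, assuming the worker is launched as a Python module like other FastChat workers; the --model-path argument is an assumption from the FastChat worker convention and is not part of this diff:

    # Start the worker with generation-latency printing disabled (the flag defaults to "true",
    # which wraps self.model in BenchmarkWrapper as shown in the diff above).
    # --model-path is assumed from the FastChat worker convention, not shown in this diff.
    python -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path <your-model> --benchmark false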