Enable fastchat benchmark latency (#11017)
* enable fastchat benchmark
* add readme
* update readme
* update
This commit is contained in:
parent
93d40ab127
commit
2084ebe4ee
1 changed file with 9 additions and 0 deletions
|
|
@ -71,6 +71,7 @@ class BigDLLLMWorker(BaseModelWorker):
|
||||||
speculative: bool = False,
|
speculative: bool = False,
|
||||||
load_low_bit_model: bool = False,
|
load_low_bit_model: bool = False,
|
||||||
stream_interval: int = 4,
|
stream_interval: int = 4,
|
||||||
|
benchmark: str = "true",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
controller_addr,
|
controller_addr,
|
||||||
|
|
@ -103,6 +104,10 @@ class BigDLLLMWorker(BaseModelWorker):
|
||||||
speculative,
|
speculative,
|
||||||
load_low_bit_model,
|
load_low_bit_model,
|
||||||
)
|
)
|
||||||
|
if benchmark.lower() == "true":
|
||||||
|
from ipex_llm.utils.benchmark_util import BenchmarkWrapper
|
||||||
|
self.model = BenchmarkWrapper(self.model, do_print=True)
|
||||||
|
logger.info(f"enable benchmark successfully")
|
||||||
self.stream_interval = stream_interval
|
self.stream_interval = stream_interval
|
||||||
self.context_len = get_context_length(self.model.config)
|
self.context_len = get_context_length(self.model.config)
|
||||||
self.embed_in_truncate = embed_in_truncate
|
self.embed_in_truncate = embed_in_truncate
|
||||||
|
|
@ -495,6 +500,9 @@ if __name__ == "__main__":
|
||||||
default=False,
|
default=False,
|
||||||
help="To use self-speculative or not",
|
help="To use self-speculative or not",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--benchmark", type=str, default="true", help="To print model generation latency or not"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trust-remote-code",
|
"--trust-remote-code",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|
@ -527,5 +535,6 @@ if __name__ == "__main__":
|
||||||
args.speculative,
|
args.speculative,
|
||||||
args.load_low_bit_model,
|
args.load_low_bit_model,
|
||||||
args.stream_interval,
|
args.stream_interval,
|
||||||
|
args.benchmark,
|
||||||
)
|
)
|
||||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue