From 310f18c8afc4580dc6959bdd9c9ecbd52a5a24b9 Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Fri, 11 Oct 2024 17:39:20 +0800
Subject: [PATCH] update NPU pipeline generate (#12182)

* update

* fix style
---
 .../npu_pipeline_model/pipeline_model.py      | 22 ++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py
index 25bfd30a..5c1e5b89 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py
@@ -59,6 +59,9 @@ def generate(
     streamer: Optional["BaseStreamer"] = None,
     **kwargs,
 ):
+    # if do_print=True, output timing message
+    do_print = kwargs.pop("do_print", False)
+    time_start_all, time_t1, idx = time.perf_counter(), None, 0
     new_generate_kwargs = {}
     for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
         value = kwargs.pop(var, None)
@@ -79,7 +82,7 @@ def generate(
     thread = threading.Thread(target=generate_serve,
                               args=(self.kv_len, self.num_head,
                                     self.head_dim, self.num_layers,
-                                    new_tokens))
+                                    new_tokens - 1))
     thread.start()
 
     in_pipe_path = "\\\\.\\pipe\\llminputpipe"
@@ -115,6 +118,8 @@ def generate(
 
     bdata = bdata + eos.to_bytes(4, sys.byteorder)
 
+    time_start = time.perf_counter()
+
     input_pipe.write(bytearray(bdata))
     input_pipe.flush()
 
@@ -125,6 +130,9 @@ def generate(
         if len(data) == 0:
             break
         token = int.from_bytes(data, sys.byteorder)
+        idx += 1
+        if time_t1 is None:
+            time_t1 = time.perf_counter()
         output_tokens.append(torch.tensor([token]))
         if streamer is not None:
             streamer.put(torch.tensor([token]))
@@ -132,10 +140,22 @@ def generate(
             break
 
     output = torch.stack(output_tokens, dim=1)
+    output = torch.cat((inputs, output), dim=1)
     if streamer is not None:
         streamer.end()
 
     thread.join()
+    time_end = time.perf_counter()
+
+    if do_print:
+        print(f" Start the thread and connect the pipe time: {(time_start - time_start_all):.2f} s")
+        print(f" Number of input tokens: {input_length}")
+        print(f" Generated tokens: {idx}")
+        print(f" First token generation time: {(time_t1 - time_start):.2f} s")
+        print(f" Generation average latency: {(time_end - time_t1)*1000 /(idx - 1):.2f} ms, "
+              f"({(idx - 1)/(time_end - time_t1):.2f} token/s)")
+        print(f" Generation time: {(time_end - time_start):.2f} s\n")
+
     return output
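
Usage note: a minimal sketch of how the new do_print kwarg added by this patch could be exercised. It is not part of the patch; the import path and loading call mirror ipex-llm's NPU examples but are assumptions here, and "model/path" and the prompt are hypothetical placeholders. do_print=True is simply forwarded through **kwargs and popped by the patched generate().

    # Minimal sketch, assuming an NPU pipeline model loaded through ipex-llm.
    # "model/path" is a placeholder; the loading call is an assumption based
    # on ipex-llm's NPU examples, not something this patch defines.
    from transformers import AutoTokenizer
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("model/path")  # hypothetical
    tokenizer = AutoTokenizer.from_pretrained("model/path")

    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    # do_print=True is consumed by the patched generate() and enables the
    # timing summary (pipe-connect time, first-token latency, decode rate).
    output = model.generate(input_ids, max_new_tokens=32, do_print=True)
    # With this patch the returned tensor also includes the input tokens
    # (torch.cat((inputs, output), dim=1)), so decoding output[0] yields
    # the prompt followed by the generated text.
    print(tokenizer.decode(output[0], skip_special_tokens=True))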