[NPU L0] Update 1st token generation (#12314)
parent d409d9d0eb
commit f53bb4ea0b

1 changed file with 10 additions and 13 deletions
@@ -105,18 +105,12 @@ def generate(
             key.to(torch.float16).numpy().tofile(os.path.join(temp_dir, f"key_cache_{layer}.bin"))
             val.to(torch.float16).numpy().tofile(os.path.join(temp_dir, f"value_cache_{layer}.bin"))
 
-        token = input_id.to(torch.int32).item()
-        output_tokens.append(torch.tensor([token]))
-        if streamer is not None:
-            streamer.put(torch.tensor([token]))
-
         if "eos_token_id" not in new_generate_kwargs:
             eos = 0xffffffff
         else:
             eos = new_generate_kwargs["eos_token_id"]
 
         time_t1 = time.perf_counter()
-        idx += 1
 
         # start generate_serve by Thread
         thread = threading.Thread(target=generate_serve,
@@ -124,7 +118,7 @@ def generate(
                                         self.head_dim, self.num_layers,
                                         self.vocab_size,
                                         self.transpose_value_cache,
-                                        new_tokens - 2))
+                                        new_tokens - 1))
         thread.start()
 
         in_pipe_path = "\\\\.\\pipe\\llminputpipe"
@@ -146,7 +140,7 @@ def generate(
             else:
                 break
 
-        time_start = time.perf_counter()
+        time_t2 = time.perf_counter()
 
         bdata = str.encode(str(temp_dir))
         invalidInputError(len(bdata) <= 2000,
@@ -162,6 +156,8 @@ def generate(
                 break
             token = int.from_bytes(data, sys.byteorder)
             idx += 1
+            if idx == 1:
+                time_t3 = time.perf_counter()
             if token == eos:
                 break
             output_tokens.append(torch.tensor([token]))
@@ -177,13 +173,14 @@ def generate(
     time_end = time.perf_counter()
 
     if do_print:
-        print(f" Start the thread and connect the pipe time: {(time_start - time_t1):.2f} s")
+        print(f" Start the thread and connect the pipe time: {(time_t2 - time_t1):.2f} s")
         print(f" Number of input tokens: {input_length}")
         print(f" Generated tokens: {idx}")
-        print(f" First token generation time: {(time_t1 - time_start_all):.2f} s")
-        print(f" Generation average latency: {(time_end - time_start) * 1000 /(idx - 1):.2f} ms, "
-              f"({(idx - 1)/(time_end - time_start):.2f} token/s)")
-        print(f" Generation time: {(time_end - time_start_all - (time_start - time_t1)):.2f} s\n")
+        print(" First token generation time: "
+              f"{(time_t3 - time_start_all - (time_t2 - time_t1)):.2f} s")
+        print(f" Generation average latency: {(time_end - time_t3) * 1000 /(idx - 1):.2f} ms, "
+              f"({(idx - 1)/(time_end - time_t3):.2f} token/s)")
+        print(f" Generation time: {(time_end - time_start_all - (time_t2 - time_t1)):.2f} s\n")
 
     return output
 
 
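For reference, a minimal sketch of how the updated timestamps fit together (variable names follow this diff; the sleeps and loop count below are placeholders, not measurements from the real pipeline): time_t1 and time_t2 bracket starting the generate_serve thread and connecting the named pipes, time_t3 marks the moment the first token is read back from the pipe, and that pipe-setup interval is subtracted from both the first-token and total-generation figures.

import time

time_start_all = time.perf_counter()
time.sleep(0.05)                     # stands in for prefill + dumping the KV cache to temp_dir
time_t1 = time.perf_counter()
time.sleep(0.02)                     # stands in for starting generate_serve and connecting the pipes
time_t2 = time.perf_counter()
time.sleep(0.03)                     # stands in for waiting on the first token from the pipe
time_t3 = time.perf_counter()
idx = 1                              # first token has been read
for _ in range(7):                   # stands in for reading the remaining tokens
    time.sleep(0.01)
    idx += 1
time_end = time.perf_counter()

pipe_setup = time_t2 - time_t1                        # "Start the thread and connect the pipe time"
first_token = time_t3 - time_start_all - pipe_setup   # "First token generation time"
avg_ms = (time_end - time_t3) * 1000 / (idx - 1)      # "Generation average latency"
total = time_end - time_start_all - pipe_setup        # "Generation time"
print(f"pipe setup {pipe_setup:.2f} s, first token {first_token:.2f} s, "
      f"avg {avg_ms:.2f} ms/token, generation {total:.2f} s")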