[NPU] Support streaming in Python (cpp backend) (#12488)
* Support streaming in NPU Python (cpp backend) * Small fix
This commit is contained in:
		
							parent
							
								
									7082844f3f
								
							
						
					
					
						commit
						4ac66db034
					
				
					 1 changed files with 15 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -314,8 +314,14 @@ def generate(
 | 
			
		|||
            new_generate_kwargs[var] = value
 | 
			
		||||
 | 
			
		||||
    if isinstance(inputs[0], torch.Tensor):
 | 
			
		||||
        if streamer is not None:
 | 
			
		||||
            # input ids
 | 
			
		||||
            streamer.put(inputs[0])
 | 
			
		||||
        input_list = inputs[0].flatten().tolist()
 | 
			
		||||
    else:
 | 
			
		||||
        if streamer is not None:
 | 
			
		||||
            # input ids
 | 
			
		||||
            streamer.put(torch.Tensor(inputs[0]))
 | 
			
		||||
        input_list = inputs[0]
 | 
			
		||||
    input_length = len(input_list)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -335,6 +341,9 @@ def generate(
 | 
			
		|||
    from .npu_llm_cpp import run_decode, run_prefill, reset
 | 
			
		||||
 | 
			
		||||
    token = run_prefill(self.model_ptr, input_list, self.vocab_size)
 | 
			
		||||
    if streamer is not None:
 | 
			
		||||
        # 1st tokens
 | 
			
		||||
        streamer.put(torch.tensor([token]))
 | 
			
		||||
    idx = 1
 | 
			
		||||
    time_t2 = time.perf_counter()
 | 
			
		||||
    output_tokens.append(torch.tensor([token]))
 | 
			
		||||
| 
						 | 
				
			
			@ -342,12 +351,18 @@ def generate(
 | 
			
		|||
        if token == eos:
 | 
			
		||||
            break
 | 
			
		||||
        token = run_decode(self.model_ptr, token, self.vocab_size)
 | 
			
		||||
        if streamer is not None:
 | 
			
		||||
            # rest tokens
 | 
			
		||||
            streamer.put(torch.tensor([token]))
 | 
			
		||||
        idx += 1
 | 
			
		||||
        output_tokens.append(torch.tensor([token]))
 | 
			
		||||
    output = torch.stack(output_tokens, dim=1)
 | 
			
		||||
    output = torch.cat((inputs, output), dim=1)
 | 
			
		||||
    time_t3 = time.perf_counter()
 | 
			
		||||
 | 
			
		||||
    if streamer is not None:
 | 
			
		||||
        streamer.end()
 | 
			
		||||
 | 
			
		||||
    reset(self.model_ptr)
 | 
			
		||||
    self.first_cost = time_t2 - time_t1  # seconds
 | 
			
		||||
    self.rest_cost_mean = (time_t3 - time_t2) / (idx - 1)  # seconds
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue