Add input padding during prefill for qwen2-7b (#12033)
parent f61b1785fb
commit d2e1b9aaff

1 changed file with 15 additions and 2 deletions
@@ -934,8 +934,21 @@ class PrefillRunner:
                 " to max_prompt_len {self.max_prompt_len}"
             ),
         )
-        self.prefill_input_queue.put((hidden_states, position_ids, attention_mask, past_key_value))
-        return self.prefill_result_queue.get()
+        pad_len = self.max_prompt_len - seq_len
+        hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
+        position_ids = F.pad(position_ids, (0, pad_len), value=0)
+        attention_mask = F.pad(
+            attention_mask.to(torch.float16),
+            (0, pad_len, 0, pad_len),
+            value=torch.finfo(torch.float16).min,
+        )
+
+        args = (hidden_states, position_ids, attention_mask, past_key_value)
+        self.prefill_input_queue.put(args)
+        hidden_states, past_key_value = self.prefill_result_queue.get()
+        past_key_value.shrink(seq_len, self.transpose_value_cache)
+        hidden_states = hidden_states[:, :seq_len, :]
+        return hidden_states, past_key_value
 
     def shutdown(self):
         self.prefill_input_queue.put("stop")
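For readers skimming the hunk, here is a minimal, self-contained sketch of what the new padding step does. The tensor shapes used below (hidden_states as (B, S, H), position_ids as (B, S), and an additive (B, 1, S, S) float attention mask) are assumptions based on common transformer runner layouts, not taken from this file:

# Minimal sketch of the padding step above, with assumed shapes:
# hidden_states (B, S, H), position_ids (B, S), and an additive
# attention_mask (B, 1, S, S). These layouts are assumptions based on
# common transformer runners, not confirmed by this diff.
import torch
import torch.nn.functional as F

B, S, H, MAX_PROMPT_LEN = 1, 5, 8, 16  # MAX_PROMPT_LEN stands in for self.max_prompt_len
hidden_states = torch.randn(B, S, H, dtype=torch.float16)
position_ids = torch.arange(S).unsqueeze(0)
attention_mask = torch.zeros(B, 1, S, S, dtype=torch.float16)

pad_len = MAX_PROMPT_LEN - S

# F.pad reads the pad tuple from the last dimension backwards:
# (0, 0) leaves H untouched, (0, pad_len) appends pad_len positions
# to the end of the sequence dimension.
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len), value=0.0)
position_ids = F.pad(position_ids, (0, pad_len), value=0)

# Padded mask entries get the most negative fp16 value, so softmax
# assigns the fake positions effectively zero attention weight.
attention_mask = F.pad(
    attention_mask,
    (0, pad_len, 0, pad_len),
    value=torch.finfo(torch.float16).min,
)

print(hidden_states.shape)   # torch.Size([1, 16, 8])
print(position_ids.shape)    # torch.Size([1, 16])
print(attention_mask.shape)  # torch.Size([1, 1, 16, 16])

After the worker replies, the hunk undoes the padding: hidden_states is sliced back to the real seq_len, and the KV cache is trimmed with past_key_value.shrink(seq_len, self.transpose_value_cache), a project-specific method whose internals this diff does not show. Padding every prompt up to max_prompt_len gives the prefill worker a single static input shape, which is presumably the point of the change for a backend that runs fixed-shape compiled graphs.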