Add input padding during prefill for qwen2-7b (#12033)
parent f61b1785fb
commit d2e1b9aaff

1 changed file with 15 additions and 2 deletions
@@ -934,8 +934,21 @@ class PrefillRunner:
                 " to max_prompt_len {self.max_prompt_len}"
             ),
         )
-        self.prefill_input_queue.put((hidden_states, position_ids, attention_mask, past_key_value))
-        return self.prefill_result_queue.get()
+        pad_len = self.max_prompt_len - seq_len
+        hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
+        position_ids = F.pad(position_ids, (0, pad_len), value=0)
+        attention_mask = F.pad(
+            attention_mask.to(torch.float16),
+            (0, pad_len, 0, pad_len),
+            value=torch.finfo(torch.float16).min,
+        )
+
+        args = (hidden_states, position_ids, attention_mask, past_key_value)
+        self.prefill_input_queue.put(args)
+        hidden_states, past_key_value = self.prefill_result_queue.get()
+        past_key_value.shrink(seq_len, self.transpose_value_cache)
+        hidden_states = hidden_states[:, :seq_len, :]
+        return hidden_states, past_key_value
 
     def shutdown(self):
         self.prefill_input_queue.put("stop")