Add input padding during prefill for qwen2-7b (#12033)

This commit is contained in:
binbin Deng 2024-09-06 16:39:59 +08:00 committed by GitHub
parent f61b1785fb
commit d2e1b9aaff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -934,8 +934,21 @@ class PrefillRunner:
" to max_prompt_len {self.max_prompt_len}"
),
)
self.prefill_input_queue.put((hidden_states, position_ids, attention_mask, past_key_value))
return self.prefill_result_queue.get()
pad_len = self.max_prompt_len - seq_len
hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
position_ids = F.pad(position_ids, (0, pad_len), value=0)
attention_mask = F.pad(
attention_mask.to(torch.float16),
(0, pad_len, 0, pad_len),
value=torch.finfo(torch.float16).min,
)
args = (hidden_states, position_ids, attention_mask, past_key_value)
self.prefill_input_queue.put(args)
hidden_states, past_key_value = self.prefill_result_queue.get()
past_key_value.shrink(seq_len, self.transpose_value_cache)
hidden_states = hidden_states[:, :seq_len, :]
return hidden_states, past_key_value
def shutdown(self):
self.prefill_input_queue.put("stop")