diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index dc896236..ad639f67 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -934,8 +934,21 @@ class PrefillRunner:
                 " to max_prompt_len {self.max_prompt_len}"
             ),
         )
-        self.prefill_input_queue.put((hidden_states, position_ids, attention_mask, past_key_value))
-        return self.prefill_result_queue.get()
+        pad_len = self.max_prompt_len - seq_len
+        hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
+        position_ids = F.pad(position_ids, (0, pad_len), value=0)
+        attention_mask = F.pad(
+            attention_mask.to(torch.float16),
+            (0, pad_len, 0, pad_len),
+            value=torch.finfo(torch.float16).min,
+        )
+
+        args = (hidden_states, position_ids, attention_mask, past_key_value)
+        self.prefill_input_queue.put(args)
+        hidden_states, past_key_value = self.prefill_result_queue.get()
+        past_key_value.shrink(seq_len, self.transpose_value_cache)
+        hidden_states = hidden_states[:, :seq_len, :]
+        return hidden_states, past_key_value

     def shutdown(self):
         self.prefill_input_queue.put("stop")
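
A minimal standalone sketch of the padding scheme used above, with toy sizes (seq_len=3, max_prompt_len=8, hidden_size=4 are made up here, not taken from the PR): F.pad reads its pad tuple from the last dimension backwards, so (0, 0, 0, pad_len) pads only the sequence dimension of hidden_states, while (0, pad_len, 0, pad_len) pads the last two dimensions of the attention mask with the float16 minimum so padded positions stay masked; slicing back to seq_len afterwards mirrors how the result is trimmed before being returned.

    import torch
    import torch.nn.functional as F

    seq_len, max_prompt_len, hidden_size = 3, 8, 4  # toy sizes for illustration only
    pad_len = max_prompt_len - seq_len

    hidden_states = torch.randn(1, seq_len, hidden_size, dtype=torch.float16)
    attention_mask = torch.zeros(1, 1, seq_len, seq_len, dtype=torch.float16)

    # Pad the sequence dimension (dim -2) of hidden_states on the right.
    padded_states = F.pad(hidden_states, (0, 0, 0, pad_len), value=0.0)
    # Pad both the key and query length dims of the mask with the fp16 minimum.
    padded_mask = F.pad(attention_mask, (0, pad_len, 0, pad_len),
                        value=torch.finfo(torch.float16).min)

    print(padded_states.shape)  # torch.Size([1, 8, 4])
    print(padded_mask.shape)    # torch.Size([1, 1, 8, 8])

    # After prefill, the padded positions are dropped again.
    restored = padded_states[:, :seq_len, :]
    print(restored.shape)       # torch.Size([1, 3, 4])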