From d2c821d458d609cbbc131679f35bb2b257312fac Mon Sep 17 00:00:00 2001 From: Song Fuchang Date: Mon, 18 Nov 2024 13:50:18 +0800 Subject: [PATCH] Add missing arguments in pipeline parallel generate method (#12142) Add two arguments: negative_prompt_ids and negative_prompt_attention_mask to generate method in pipeline_parallel.py. --- python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 87167d81..8a9c0b79 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -245,6 +245,8 @@ def generate( synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, **kwargs, ): if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1: @@ -287,6 +289,8 @@ def generate( synced_gpus=synced_gpus, assistant_model=assistant_model, streamer=streamer, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, **kwargs) GenerationMixin.generate = generate