Add missing arguments in pipeline parallel generate method (#12142)

Add two arguments: negative_prompt_ids and negative_prompt_attention_mask to generate method in pipeline_parallel.py.
This commit is contained in:
Song Fuchang 2024-11-18 13:50:18 +08:00 committed by GitHub
parent 3d5fbf2069
commit d2c821d458
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -245,6 +245,8 @@ def generate(
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
@ -287,6 +289,8 @@ def generate(
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
**kwargs)
GenerationMixin.generate = generate