Add missing arguments in pipeline parallel generate method (#12142)
Add two arguments: negative_prompt_ids and negative_prompt_attention_mask to generate method in pipeline_parallel.py.
This commit is contained in:
parent
3d5fbf2069
commit
d2c821d458
1 changed files with 4 additions and 0 deletions
|
|
@ -245,6 +245,8 @@ def generate(
|
|||
synced_gpus: Optional[bool] = None,
|
||||
assistant_model: Optional["PreTrainedModel"] = None,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
|
||||
|
|
@ -287,6 +289,8 @@ def generate(
|
|||
synced_gpus=synced_gpus,
|
||||
assistant_model=assistant_model,
|
||||
streamer=streamer,
|
||||
negative_prompt_ids=negative_prompt_ids,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
**kwargs)
|
||||
|
||||
GenerationMixin.generate = generate
|
||||
|
|
|
|||
Loading…
Reference in a new issue