Add missing arguments in pipeline parallel generate method (#12142)
Add two arguments: negative_prompt_ids and negative_prompt_attention_mask to generate method in pipeline_parallel.py.
This commit is contained in:
parent
3d5fbf2069
commit
d2c821d458
1 changed files with 4 additions and 0 deletions
|
|
@ -245,6 +245,8 @@ def generate(
|
||||||
synced_gpus: Optional[bool] = None,
|
synced_gpus: Optional[bool] = None,
|
||||||
assistant_model: Optional["PreTrainedModel"] = None,
|
assistant_model: Optional["PreTrainedModel"] = None,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
|
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
|
if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
|
||||||
|
|
@ -287,6 +289,8 @@ def generate(
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
assistant_model=assistant_model,
|
assistant_model=assistant_model,
|
||||||
streamer=streamer,
|
streamer=streamer,
|
||||||
|
negative_prompt_ids=negative_prompt_ids,
|
||||||
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
GenerationMixin.generate = generate
|
GenerationMixin.generate = generate
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue