Fix mixtral-8x7b with transformers=4.37.0 (#11132)
parent ab476c7fe2
commit 367de141f2
1 changed file with 3 additions and 0 deletions
@@ -414,6 +414,9 @@ def mixtral_model_forward(
     output_router_logits: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, MoeModelOutputWithPast]:
+    # to be compatible with transformers>=4.37.0
+    self._use_flash_attention_2 = self.config._attn_implementation == "flash_attention_2"
+
     output_attentions = output_attentions if output_attentions is not None \
         else self.config.output_attentions
     output_router_logits = (
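For context, patched forwards that still read self._use_flash_attention_2 break on transformers>=4.37.0, where the chosen attention backend is recorded on config._attn_implementation instead of on the model. The commit re-derives that legacy attribute at the top of the forward. Below is a minimal, hypothetical sketch of the same idea as a version-guarded helper; the name restore_flash_attn_flag is not part of this commit and is shown only for illustration.

# Hypothetical helper (not part of this commit): restore the legacy
# _use_flash_attention_2 attribute that older patched forwards still read,
# assuming transformers>=4.37.0 records the backend on config._attn_implementation.
from packaging import version
import transformers


def restore_flash_attn_flag(model):
    if version.parse(transformers.__version__) >= version.parse("4.37.0"):
        model._use_flash_attention_2 = (
            model.config._attn_implementation == "flash_attention_2"
        )
    return model

Calling restore_flash_attn_flag(model) once after loading would mirror what the added line does on every forward call.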