Fix mixtral-8x7b with transformers=4.37.0 (#11132)

binbin Deng, 2024-05-27 09:50:54 +08:00, committed by GitHub
parent ab476c7fe2
commit 367de141f2

@@ -414,6 +414,9 @@ def mixtral_model_forward(
     output_router_logits: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, MoeModelOutputWithPast]:
+    # to be compatible with transformers>=4.37.0
+    self._use_flash_attention_2 = self.config._attn_implementation == "flash_attention_2"
+
     output_attentions = output_attentions if output_attentions is not None \
         else self.config.output_attentions
     output_router_logits = (
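
Context for the change: per the diff's own comment, the patched mixtral_model_forward still reads the private `_use_flash_attention_2` attribute, which newer transformers releases (>=4.37.0) appear to no longer set on the model, exposing the attention backend only through `config._attn_implementation`. The sketch below illustrates the shim with illustrative stand-in classes (DummyConfig and DummyModel are not real ipex-llm or transformers names); it is a minimal demonstration of the pattern, not the actual forward.

    # Minimal sketch of the compatibility shim added in this commit.
    # Assumption (from the diff's comment): transformers>=4.37.0 no longer
    # initializes `_use_flash_attention_2`, so a patched forward that still
    # reads it must re-derive it from `config._attn_implementation`.

    class DummyConfig:
        def __init__(self, attn_implementation="eager"):
            self._attn_implementation = attn_implementation


    class DummyModel:
        def __init__(self, config):
            # Mirrors transformers>=4.37.0: no `_use_flash_attention_2`
            # attribute is set at construction time.
            self.config = config

        def forward(self):
            # The shim from the diff: recompute the flag on every call so
            # the attribute exists regardless of the transformers version.
            self._use_flash_attention_2 = \
                self.config._attn_implementation == "flash_attention_2"
            return "flash path" if self._use_flash_attention_2 else "eager path"


    print(DummyModel(DummyConfig("flash_attention_2")).forward())  # flash path
    print(DummyModel(DummyConfig()).forward())                     # eager path

Recomputing the flag inside the forward (rather than at patch time) keeps the shim harmless on older transformers versions, where it simply overwrites the attribute with the same value.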