support transformers 4.34+ for llama (#9229)

Author: Yang Wang
Date:   2023-10-20 13:36:30 +08:00 (committed by GitHub)
parent 7e96d3e79a
commit 118249b011


@@ -79,6 +79,8 @@ def llama_attention_forward_4_31(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
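
For context: transformers 4.34 started passing an extra `padding_mask` keyword into the attention forward call, so a monkey-patched forward that kept the 4.31 signature would fail with a TypeError. Accepting `padding_mask` explicitly and absorbing anything newer via `**kwargs` lets one patched function work across both versions. Below is a minimal, self-contained sketch of the failure mode and the fix, using toy functions rather than the actual BigDL implementation:

import torch

def forward_4_31(hidden_states, attention_mask=None, use_cache=False):
    # Old signature: transformers 4.31 never passed `padding_mask`.
    return hidden_states

def forward_4_34(hidden_states, attention_mask=None, use_cache=False,
                 padding_mask=None, **kwargs):
    # New signature: tolerates `padding_mask` (passed by transformers
    # 4.34+) and swallows any keyword arguments a future release adds.
    return hidden_states

x = torch.zeros(1, 4, 8)  # (batch, seq_len, hidden)
try:
    forward_4_31(x, padding_mask=torch.ones(1, 4, dtype=torch.long))
except TypeError as e:
    print("4.31-style patch breaks under transformers 4.34+:", e)

out = forward_4_34(x, padding_mask=torch.ones(1, 4, dtype=torch.long))
print("widened signature works:", out.shape)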