LLM: fix llama 4.36 forward(#10047)

Author: Ruonan Wang (committed by GitHub)
Date:   2024-01-31 10:31:10 +08:00
Parent: 53a5140eff
Commit: 3685622f29


@@ -638,7 +638,7 @@ def llama_attention_forward_4_36(
             "Please make sure use `attention_mask` instead.`"
         )
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
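
A minimal sketch (not the committed code) of what this one-line change does: the last dimension of `hidden_states` is now bound to `hidden_size` instead of being discarded with `_`, on the assumption that the value is reused later in the forward pass (for example when reshaping the attention output) rather than being re-read from the model config. The tensor shapes below are illustrative only.

import torch

# Illustrative shapes only: (bsz, q_len, hidden_size)
hidden_states = torch.randn(2, 16, 4096)

# Before: the hidden size was unpacked into `_` and thrown away.
bsz, q_len, _ = hidden_states.size()

# After: keep it, so later code can reshape against the actual tensor shape.
bsz, q_len, hidden_size = hidden_states.size()
attn_output = hidden_states.reshape(bsz, q_len, hidden_size)
print(attn_output.shape)  # torch.Size([2, 16, 4096])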