diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index b773c424..f1e160a1 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -638,7 +638,7 @@ def llama_attention_forward_4_36(
             "Please make sure use `attention_mask` instead.`"
         )
 
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype