diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 5cba7f0e..9a7d5c0d 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -483,7 +483,7 @@ def mistral_attention_forward_original(
                                                      is_causal=True)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
@@ -492,7 +492,7 @@ def mistral_attention_forward_original(
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                   key_states,
@@ -855,7 +855,7 @@ def mistral_attention_forward_4_36_original(
                                                      is_causal=True)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                   key_states,
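
Context for the three hunks (not part of the patch itself): `hidden_size` is an attribute stored on the attention module, not a local variable of the forward function, so reshaping the attention output back to `[bsz, q_len, hidden_size]` has to read it through `self`. The snippet below is a minimal standalone sketch of that reshape path, not the real ipex-llm or Hugging Face Mistral attention; `ToyAttention` and its defaults are made up for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttention(nn.Module):
    """Minimal self-attention stub mirroring the reshape path touched by the patch."""

    def __init__(self, hidden_size: int = 64, num_heads: int = 4):
        super().__init__()
        self.hidden_size = hidden_size   # stored on the module, like config.hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.size()
        # [bsz, q_len, hidden] -> [bsz, num_heads, q_len, head_dim]
        q = hidden_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(q, q, q, is_causal=True)
        attn_output = attn_output.transpose(1, 2).contiguous()
        # The line the patch fixes: a bare `hidden_size` is undefined in this scope,
        # so the attribute must be accessed as `self.hidden_size`.
        return attn_output.reshape(bsz, q_len, self.hidden_size)


# Quick shape check of the reshape path.
out = ToyAttention()(torch.randn(2, 8, 64))
assert out.shape == (2, 8, 64)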