From 0a3e4e788fea1b7c2bfa0572e42a74d19defe6a7 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Tue, 26 Mar 2024 10:55:44 +0800
Subject: [PATCH] LLM: fix mistral hidden_size setting for deepspeed autotp
 (#10527)

---
 python/llm/src/ipex_llm/transformers/models/mistral.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 5cba7f0e..9a7d5c0d 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -483,7 +483,7 @@ def mistral_attention_forward_original(
                                                      is_causal=True)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
@@ -492,7 +492,7 @@ def mistral_attention_forward_original(
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                   key_states,
@@ -855,7 +855,7 @@ def mistral_attention_forward_4_36_original(
                                                      is_causal=True)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                   key_states,
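
A minimal standalone sketch (not part of the patch, and not ipex_llm or DeepSpeed code) of why the
reshape should use the module's own hidden_size under DeepSpeed AutoTP: tensor parallelism leaves
each rank with only a slice of the attention heads, so the per-rank attention output width is
num_local_heads * head_dim, which is what the module attribute reflects, while the full model hidden
size from the config stays unchanged. The ShardedAttention class, full_hidden_size, and world_size
below are hypothetical names used only for illustration.

import torch
import torch.nn as nn


class ShardedAttention(nn.Module):
    def __init__(self, full_hidden_size: int, num_heads: int, world_size: int):
        super().__init__()
        self.head_dim = full_hidden_size // num_heads
        self.num_heads = num_heads // world_size            # heads held by this rank
        self.hidden_size = self.num_heads * self.head_dim   # per-rank output width

    def forward(self, attn_output: torch.Tensor, bsz: int, q_len: int) -> torch.Tensor:
        # attn_output: (bsz, num_local_heads, q_len, head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        # Reshaping to the full model hidden size (4096 here) would fail whenever
        # world_size > 1; the module's per-rank attribute is the correct target.
        return attn_output.reshape(bsz, q_len, self.hidden_size)


full_hidden, heads, world = 4096, 32, 2
attn = ShardedAttention(full_hidden, heads, world)
local_out = torch.randn(2, heads // world, 8, full_hidden // heads)
print(attn(local_out, bsz=2, q_len=8).shape)  # torch.Size([2, 8, 2048])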