diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index b5046c64..cc69eacd 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -953,7 +953,7 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
-        if model.config.hidden_size == 4096:
+        if model.config.hidden_size in [4096, 2048]:
             # baichuan2-7B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index e5bd784b..1fc8e403 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -118,7 +118,7 @@ def baichuan_attention_forward_7b_quantized(
     device = hidden_states.device
 
     proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    proj = torch.chunk(proj, 3, -1)
     # batch_size x source_len x hidden_size
     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     # batch_size x target_len x head_size
@@ -176,7 +176,7 @@ def baichuan_attention_forward_7b_quantized(
                                           value_states.transpose(-1, -2))
         attn_output = attn_output.transpose(1, 2)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
     attn_output = self.o_proj(attn_output)
 
     if not output_attentions:
@@ -198,7 +198,7 @@ def baichuan_attention_forward_7b_origin(
     device = hidden_states.device
 
     proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    proj = torch.chunk(proj, 3, -1)
     # batch_size x source_len x hidden_size
     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     # batch_size x target_len x head_size
@@ -283,7 +283,7 @@ def baichuan_attention_forward_7b_origin(
     attn_output = torch.matmul(attn_output, value_states)
 
     attn_output = attn_output.transpose(1, 2)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
     attn_output = self.o_proj(attn_output)
 
     if not output_attentions:
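
For reference (not part of the patch): a minimal sketch of why replacing the `unflatten`/`transpose` chain with `torch.chunk(proj, 3, -1)` is behavior-preserving. The packed `W_pack` projection concatenates Q, K, and V along the last dimension, so both forms recover the same three slices. The tensor shapes below are illustrative only, not taken from any model config.

```python
import torch

# Illustrative shapes only (hypothetical; not from a real Baichuan2 config).
bsz, q_len, hidden_size, num_heads = 2, 5, 2048, 32
head_dim = hidden_size // num_heads

# Stand-in for the packed QKV output produced by self.W_pack(hidden_states).
proj = torch.randn(bsz, q_len, 3 * hidden_size)

# Old splitting: unflatten the last dim into (3, hidden_size) and move the
# 3-way axis to the front, yielding a tensor indexable as proj_old[0/1/2].
proj_old = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)

# New splitting: three contiguous chunks along the last dim.
proj_new = torch.chunk(proj, 3, -1)

# Both approaches recover the same Q/K/V slices.
for i in range(3):
    assert torch.equal(proj_old[i], proj_new[i])

# Each slice is then reshaped exactly as in the patched code.
query_states = proj_new[0].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(query_states.shape)  # torch.Size([2, 32, 5, 64])
```

For configurations where `num_heads * head_dim` equals `hidden_size`, the final `reshape` change is likewise shape-preserving; it just expresses the output width in terms of the attention geometry instead of relying on `self.hidden_size`.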