support new baichuan model (#10404)

Yishuo Wang 2024-03-13 17:45:50 +08:00 committed by GitHub
parent a90e9b6ec2
commit 06a851afa9
2 changed files with 5 additions and 5 deletions

@@ -953,7 +953,7 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
-        if model.config.hidden_size == 4096:
+        if model.config.hidden_size in [4096, 2048]:
             # baichuan2-7B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
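
For context, a minimal, self-contained sketch of the dispatch this hunk widens (FakeConfig is a stand-in for the Hugging Face config object; the values are the ones the check keys on). Baichuan2 is identified by its 125696-entry vocabulary, and the hidden_size test now routes a 2048-hidden-size variant through the same optimized 7B attention path as the 4096 one:

    class FakeConfig:
        model_type = "baichuan"
        vocab_size = 125696   # tokenizer size that marks baichuan2
        hidden_size = 2048    # newly accepted alongside 4096

    config = FakeConfig()
    if config.model_type == "baichuan" and config.vocab_size == 125696:
        # baichuan2
        if config.hidden_size in [4096, 2048]:
            print("patch in the optimized baichuan2-7B forward")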

@@ -118,7 +118,7 @@ def baichuan_attention_forward_7b_quantized(
     device = hidden_states.device
     proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    proj = torch.chunk(proj, 3, -1)
     # batch_size x source_len x hidden_size
     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     # batch_size x target_len x head_size
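
The old unflatten/unsqueeze/transpose/squeeze chain and the new torch.chunk call yield the same three Q/K/V slices for the proj[0], proj[1], proj[2] access pattern that follows; chunk is simpler and does not hard-code self.hidden_size as the split size. A quick sketch with arbitrary shapes checking the equivalence:

    import torch

    bsz, q_len, hidden_size = 2, 5, 64
    proj = torch.randn(bsz, q_len, 3 * hidden_size)  # packed QKV projection

    # old: fold the last dim into (3, H), then move the 3-axis to the front
    old = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
    # new: split the last dim into three equal chunks
    new = torch.chunk(proj, 3, -1)

    for i in range(3):  # Q, K, V come out identical either way
        assert torch.equal(old[i], new[i])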
@@ -176,7 +176,7 @@ def baichuan_attention_forward_7b_quantized(
                                  value_states.transpose(-1, -2))
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
         attn_output = self.o_proj(attn_output)
     if not output_attentions:
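
After the transpose, attn_output actually has shape (bsz, q_len, num_heads, head_dim), so flattening by num_heads * head_dim is always valid, whereas self.hidden_size only works when the two quantities coincide. A minimal sketch with illustrative dimensions:

    import torch

    bsz, q_len, num_heads, head_dim = 2, 5, 8, 16
    attn_output = torch.randn(bsz, num_heads, q_len, head_dim)

    attn_output = attn_output.transpose(1, 2)  # (bsz, q_len, num_heads, head_dim)
    # reshape by what the tensor really contains, not by config.hidden_size
    attn_output = attn_output.reshape(bsz, q_len, num_heads * head_dim)
    assert attn_output.shape == (2, 5, 8 * 16)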
@@ -198,7 +198,7 @@ def baichuan_attention_forward_7b_origin(
     device = hidden_states.device
     proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    proj = torch.chunk(proj, 3, -1)
     # batch_size x source_len x hidden_size
     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     # batch_size x target_len x head_size
@@ -283,7 +283,7 @@ def baichuan_attention_forward_7b_origin(
         attn_output = torch.matmul(attn_output, value_states)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
         attn_output = self.o_proj(attn_output)
     if not output_attentions:
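
The last two hunks apply the same two simplifications to the non-quantized (origin) forward path, so both code paths split the packed QKV projection and flatten the attention output in the same shape-agnostic way.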