support new baichuan model (#10404)
commit 06a851afa9
parent a90e9b6ec2
2 changed files with 5 additions and 5 deletions
@@ -953,7 +953,7 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
-        if model.config.hidden_size == 4096:
+        if model.config.hidden_size in [4096, 2048]:
             # baichuan2-7B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
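The condition above keys the baichuan2-7B-style optimizations on vocab_size 125696 and, after this change, on a hidden_size of either 4096 or 2048. A minimal sketch of that relaxed check, with a hypothetical helper name and stand-in configs that are not part of the patch:

    from types import SimpleNamespace

    # Hypothetical helper mirroring the dispatch condition in _optimize_post;
    # the function name and the SimpleNamespace configs are illustrative only.
    def takes_baichuan2_7b_path(config):
        return (config.model_type == "baichuan"
                and config.vocab_size == 125696
                and config.hidden_size in [4096, 2048])

    # A 7B-sized config, a 2048-wide config, and one that stays on another path.
    for hidden in (4096, 2048, 5120):
        cfg = SimpleNamespace(model_type="baichuan", vocab_size=125696, hidden_size=hidden)
        print(hidden, takes_baichuan2_7b_path(cfg))   # True, True, False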
@@ -118,7 +118,7 @@ def baichuan_attention_forward_7b_quantized(
     device = hidden_states.device

     proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    proj = torch.chunk(proj, 3, -1)
     # batch_size x source_len x hidden_size
     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     # batch_size x target_len x head_size
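The replaced line splits the packed W_pack output into Q, K and V. The old unflatten/transpose/squeeze chain and the new torch.chunk call yield the same three slices; the sketch below checks that equivalence on random data with toy sizes (all names are local to the example):

    import torch

    # Toy stand-in for a packed QKV projection output of shape
    # (batch, seq_len, 3 * hidden_size); sizes are illustrative.
    bsz, q_len, hidden_size = 2, 5, 64
    proj = torch.randn(bsz, q_len, 3 * hidden_size)

    # Old split: (bsz, q_len, 3, hidden) -> move the "3" axis to the front.
    proj_old = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)

    # New split: three equal chunks along the last dimension.
    proj_new = torch.chunk(proj, 3, -1)

    for i in range(3):
        assert torch.equal(proj_old[i], proj_new[i])  # same Q/K/V slices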
@@ -176,7 +176,7 @@ def baichuan_attention_forward_7b_quantized(
                                    value_states.transpose(-1, -2))
     attn_output = attn_output.transpose(1, 2)

-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
     attn_output = self.o_proj(attn_output)

     if not output_attentions:
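After transpose(1, 2) the attention output has shape (bsz, q_len, num_heads, head_dim), so its flattened width is num_heads * head_dim; reshaping to self.hidden_size only works when the two coincide. A short shape check with illustrative sizes:

    import torch

    # Illustrative sizes; only the shape arithmetic matters here.
    bsz, q_len, num_heads, head_dim = 2, 5, 32, 128
    attn_output = torch.randn(bsz, num_heads, q_len, head_dim)

    attn_output = attn_output.transpose(1, 2)   # (bsz, q_len, num_heads, head_dim)
    attn_output = attn_output.reshape(bsz, q_len, num_heads * head_dim)
    print(attn_output.shape)                    # torch.Size([2, 5, 4096])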
@@ -198,7 +198,7 @@ def baichuan_attention_forward_7b_origin(
     device = hidden_states.device

     proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    proj = torch.chunk(proj, 3, -1)
     # batch_size x source_len x hidden_size
     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     # batch_size x target_len x head_size
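The same chunk-based split is applied in the non-quantized path. Extending the query_states line shown above, key and value presumably follow the same view/transpose pattern; a self-contained sketch with toy sizes (the key/value lines are an assumption, not shown in this hunk):

    import torch

    bsz, q_len, num_heads, head_dim = 2, 5, 32, 64
    hidden_size = num_heads * head_dim
    proj = torch.chunk(torch.randn(bsz, q_len, 3 * hidden_size), 3, -1)

    # Query as in the hunk; key/value assumed to mirror it.
    query_states = proj[0].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    key_states = proj[1].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    value_states = proj[2].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    print(query_states.shape)  # torch.Size([2, 32, 5, 64]) -> (bsz, num_heads, q_len, head_dim)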
@@ -283,7 +283,7 @@ def baichuan_attention_forward_7b_origin(
     attn_output = torch.matmul(attn_output, value_states)

     attn_output = attn_output.transpose(1, 2)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
     attn_output = self.o_proj(attn_output)

     if not output_attentions:
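Taken together, the tail of the non-quantized attention path now reads matmul, transpose, reshape to num_heads * head_dim, then o_proj. A runnable sketch of that tail with a stand-in nn.Linear for o_proj (sizes illustrative, not taken from the patched module):

    import torch
    import torch.nn as nn

    bsz, q_len, num_heads, head_dim = 2, 5, 16, 128
    attn_weights = torch.softmax(torch.randn(bsz, num_heads, q_len, q_len), dim=-1)
    value_states = torch.randn(bsz, num_heads, q_len, head_dim)
    o_proj = nn.Linear(num_heads * head_dim, num_heads * head_dim, bias=False)  # stand-in

    attn_output = torch.matmul(attn_weights, value_states)   # (bsz, heads, q_len, head_dim)
    attn_output = attn_output.transpose(1, 2)                 # (bsz, q_len, heads, head_dim)
    attn_output = attn_output.reshape(bsz, q_len, num_heads * head_dim)
    attn_output = o_proj(attn_output)
    print(attn_output.shape)  # torch.Size([2, 5, 2048])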