[LLM] Split merged_qk into separate q/k linear layers (#10299)

* Modify merged_qk_linear into separate q/k linear layers

* update
SONG Ge, 2024-03-01 16:48:55 +08:00, committed by GitHub
parent f4d7dbcde2
commit 0ab40917fb
2 changed files with 22 additions and 12 deletions


@@ -555,17 +555,27 @@ def _optimize_pre(model):
         head_dim = module.head_dim
         hidden_size = module.hidden_size
-        merged_qk_proj = torch.nn.Linear(0, 0, False)
-        weight = torch.cat([
+        weight_q = torch.cat([
             q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
             k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+        ], dim=0).view(num_heads * head_dim, hidden_size)
+        weight_k = torch.cat([
             q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
             k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
-        ], dim=0).view(num_heads * head_dim * 2, hidden_size)
-        merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
-        merged_qk_proj.in_features = hidden_size
-        merged_qk_proj.out_features = num_heads * head_dim * 2
-        module.merged_qk_proj = merged_qk_proj
+        ], dim=0).view(num_heads * head_dim, hidden_size)
+        merged_q_proj = torch.nn.Linear(0, 0, False)
+        merged_q_proj.weight = torch.nn.Parameter(weight_q, requires_grad=False)
+        merged_q_proj.in_features = hidden_size
+        merged_q_proj.out_features = num_heads * head_dim
+        module.merged_q_proj = merged_q_proj
+        merged_k_proj = torch.nn.Linear(0, 0, False)
+        merged_k_proj.weight = torch.nn.Parameter(weight_k, requires_grad=False)
+        merged_k_proj.in_features = hidden_size
+        merged_k_proj.out_features = num_heads * head_dim
+        module.merged_k_proj = merged_k_proj
         del module.q_proj
         del module.k_proj
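
The split is numerically equivalent to the old merged projection: merged_q_proj takes the even-indexed head slices of q_proj and k_proj (the first half of the old merged weight), merged_k_proj the odd-indexed slices (the second half), so two half-sized GEMMs reproduce what torch.chunk previously extracted from the single merged GEMM. Below is a minimal standalone sketch of that check; the sizes (num_heads, head_dim, hidden_size, bsz, q_len) are placeholders for illustration, not values from an actual Yuan config.

import torch

# Hypothetical sizes for illustration only; real values come from the model config.
num_heads, head_dim, hidden_size = 4, 8, 32
bsz, q_len = 2, 5

q_weight = torch.randn(num_heads * head_dim, hidden_size)
k_weight = torch.randn(num_heads * head_dim, hidden_size)

# Old layout: one merged weight of shape (2 * num_heads * head_dim, hidden_size).
weight_qk = torch.cat([
    q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
    k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
    q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
    k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
], dim=0).view(num_heads * head_dim * 2, hidden_size)

# New layout: the same rows, split into two weights of shape (num_heads * head_dim, hidden_size).
weight_q = weight_qk[:num_heads * head_dim]
weight_k = weight_qk[num_heads * head_dim:]

hidden_states = torch.randn(bsz, q_len, hidden_size)

# Old path: one GEMM, then chunk along the last dimension.
qk_states = torch.nn.functional.linear(hidden_states, weight_qk)
query_old, key_old = torch.chunk(qk_states, 2, dim=-1)

# New path: two half-sized GEMMs.
query_new = torch.nn.functional.linear(hidden_states, weight_q)
key_new = torch.nn.functional.linear(hidden_states, weight_k)

assert torch.allclose(query_old, query_new) and torch.allclose(key_old, key_new)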


@@ -147,7 +147,7 @@ def yuan_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.merged_qk_proj, hidden_states):
+    if use_quantize_kv_cache(self.merged_q_proj, hidden_states):
         forward_function = yuan_attention_forward_quantized
     else:
         forward_function = yuan_attention_forward_origin
@@ -206,8 +206,8 @@ def yuan_attention_forward_quantized(
     else:
         hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
                                                          this_hidden_states, hidden_states.dtype)
-    qk_states = self.merged_qk_proj(hidden_states)
-    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = self.merged_q_proj(hidden_states)
+    key_states = self.merged_k_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -360,8 +360,8 @@ def yuan_attention_forward_origin(
     else:
         hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
                                                          this_hidden_states, hidden_states.dtype)
-    qk_states = self.merged_qk_proj(hidden_states)
-    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = self.merged_q_proj(hidden_states)
+    key_states = self.merged_k_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
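
Both the quantized and the origin forward paths now run two separate projections instead of one merged GEMM followed by a chunk; the downstream reshape to the per-head layout expected by attention is unchanged. A minimal standalone sketch of that step, assuming placeholder sizes and plain nn.Linear stand-ins for the merged_q_proj/merged_k_proj modules built in _optimize_pre (none of these values come from a real Yuan checkpoint):

import torch

# Illustrative sizes only.
bsz, q_len, num_heads, head_dim, hidden_size = 2, 5, 4, 8, 32
hidden_states = torch.randn(bsz, q_len, hidden_size)

# Stand-ins for module.merged_q_proj / module.merged_k_proj.
merged_q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
merged_k_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)

# New forward path: one GEMM per projection, then reshape to per-head layout.
query_states = merged_q_proj(hidden_states)   # (bsz, q_len, num_heads * head_dim)
key_states = merged_k_proj(hidden_states)     # (bsz, q_len, num_heads * head_dim)

query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

assert query_states.shape == (bsz, num_heads, q_len, head_dim)
assert key_states.shape == (bsz, num_heads, q_len, head_dim)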