[LLM] Split merged_qk into separate q/k linear layers (#10299)
* Modify merge_qk_linear into separate q/k linear layers
* Update
parent f4d7dbcde2
commit 0ab40917fb
2 changed files with 22 additions and 12 deletions
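The intent, in one standalone sketch (toy shapes, plain PyTorch, not code from this repo): chunking the old merged_qk_proj output in two gives the same tensors as the two new separate projections, because the split weights are exactly the two halves of the old merged weight.

import torch
import torch.nn.functional as F

# Toy sizes; the real values come from the Yuan attention module.
num_heads, head_dim, hidden_size = 4, 8, 32
q_weight = torch.randn(num_heads * head_dim, hidden_size)
k_weight = torch.randn(num_heads * head_dim, hidden_size)

def every_other_head(w, offset):
    # the [0::2, :, :] / [1::2, :, :] slices used in the diff below
    return w.view(num_heads, head_dim, hidden_size)[offset::2]

# old layout: one merged weight holding all four slices
merged = torch.cat([every_other_head(q_weight, 0), every_other_head(k_weight, 0),
                    every_other_head(q_weight, 1), every_other_head(k_weight, 1)],
                   dim=0).view(num_heads * head_dim * 2, hidden_size)
# new layout: the same rows, split into two weights
weight_q, weight_k = merged[:num_heads * head_dim], merged[num_heads * head_dim:]

x = torch.randn(2, 5, hidden_size)                            # (batch, seq, hidden)
q_old, k_old = torch.chunk(F.linear(x, merged), 2, dim=-1)    # old: project once, chunk
q_new, k_new = F.linear(x, weight_q), F.linear(x, weight_k)   # new: project twice
assert torch.allclose(q_old, q_new, atol=1e-6)
assert torch.allclose(k_old, k_new, atol=1e-6)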
@@ -555,17 +555,27 @@ def _optimize_pre(model):
                 head_dim = module.head_dim
                 hidden_size = module.hidden_size

-                merged_qk_proj = torch.nn.Linear(0, 0, False)
-                weight = torch.cat([
+                weight_q = torch.cat([
                     q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
                     k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                ], dim=0).view(num_heads * head_dim, hidden_size)
+
+                weight_k = torch.cat([
                     q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
                     k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
-                ], dim=0).view(num_heads * head_dim * 2, hidden_size)
-                merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
-                merged_qk_proj.in_features = hidden_size
-                merged_qk_proj.out_features = num_heads * head_dim * 2
-                module.merged_qk_proj = merged_qk_proj
+                ], dim=0).view(num_heads * head_dim, hidden_size)
+
+                merged_q_proj = torch.nn.Linear(0, 0, False)
+                merged_q_proj.weight = torch.nn.Parameter(weight_q, requires_grad=False)
+                merged_q_proj.in_features = hidden_size
+                merged_q_proj.out_features = num_heads * head_dim
+                module.merged_q_proj = merged_q_proj
+
+                merged_k_proj = torch.nn.Linear(0, 0, False)
+                merged_k_proj.weight = torch.nn.Parameter(weight_k, requires_grad=False)
+                merged_k_proj.in_features = hidden_size
+                merged_k_proj.out_features = num_heads * head_dim
+                module.merged_k_proj = merged_k_proj

                 del module.q_proj
                 del module.k_proj
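A note on the slicing kept from the old code: after the split, merged_q_proj holds the even-indexed heads of both the original q_proj and k_proj, and merged_k_proj the odd-indexed ones, i.e. exactly the first and second half of the rows of the old merged weight. A small sketch with tagged rows (toy sizes, illustrative only, not repo code) makes that layout visible:

import torch

# toy sizes; every row is tagged with its source head index so the layout is easy to trace
num_heads, head_dim, hidden_size = 4, 2, 3
q_weight = torch.arange(num_heads).float().repeat_interleave(head_dim).unsqueeze(1).repeat(1, hidden_size)
k_weight = q_weight + 100  # k heads tagged 100, 101, 102, 103

weight_q = torch.cat([
    q_weight.view(num_heads, head_dim, hidden_size)[0::2],
    k_weight.view(num_heads, head_dim, hidden_size)[0::2],
], dim=0).view(num_heads * head_dim, hidden_size)
weight_k = torch.cat([
    q_weight.view(num_heads, head_dim, hidden_size)[1::2],
    k_weight.view(num_heads, head_dim, hidden_size)[1::2],
], dim=0).view(num_heads * head_dim, hidden_size)

print(weight_q[:, 0].view(num_heads, head_dim)[:, 0])  # tensor([  0.,   2., 100., 102.]) -> even q/k heads
print(weight_k[:, 0].view(num_heads, head_dim)[:, 0])  # tensor([  1.,   3., 101., 103.]) -> odd q/k heads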
@@ -147,7 +147,7 @@ def yuan_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.merged_qk_proj, hidden_states):
+    if use_quantize_kv_cache(self.merged_q_proj, hidden_states):
         forward_function = yuan_attention_forward_quantized
     else:
         forward_function = yuan_attention_forward_origin
@@ -206,8 +206,8 @@ def yuan_attention_forward_quantized(
     else:
         hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
                                                          this_hidden_states, hidden_states.dtype)
-    qk_states = self.merged_qk_proj(hidden_states)
-    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = self.merged_q_proj(hidden_states)
+    key_states = self.merged_k_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

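For context, a minimal runnable sketch of the updated q/k path in the attention forward (toy sizes; merged_q_proj / merged_k_proj here are plain nn.Linear stand-ins for the modules attached in _optimize_pre above):

import torch

bsz, q_len, num_heads, head_dim, hidden_size = 2, 5, 4, 8, 32
merged_q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
merged_k_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)

hidden_states = torch.randn(bsz, q_len, hidden_size)
query_states = merged_q_proj(hidden_states)   # was: chunk 0 of merged_qk_proj output
key_states = merged_k_proj(hidden_states)     # was: chunk 1 of merged_qk_proj output
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(query_states.shape, key_states.shape)   # both (bsz, num_heads, q_len, head_dim)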
@@ -360,8 +360,8 @@ def yuan_attention_forward_origin(
     else:
         hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
                                                          this_hidden_states, hidden_states.dtype)
-    qk_states = self.merged_qk_proj(hidden_states)
-    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = self.merged_q_proj(hidden_states)
+    key_states = self.merged_k_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
