From 0ab40917fbdcb1740b27810425712e1a66554a30 Mon Sep 17 00:00:00 2001
From: SONG Ge <38711238+sgwhat@users.noreply.github.com>
Date: Fri, 1 Mar 2024 16:48:55 +0800
Subject: [PATCH] [LLM] Split merged_qk to separated q/k linear (#10299)

* modify merge_qk_linear to separated q/k linear

* update
---
 .../llm/src/bigdl/llm/transformers/convert.py | 24 +++++++++++++------
 .../src/bigdl/llm/transformers/models/yuan.py | 10 ++++----
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 2a3b781d..4e7c3807 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -555,17 +555,27 @@ def _optimize_pre(model):
                 head_dim = module.head_dim
                 hidden_size = module.hidden_size
-                merged_qk_proj = torch.nn.Linear(0, 0, False)
-                weight = torch.cat([
+                weight_q = torch.cat([
                     q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
                     k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                ], dim=0).view(num_heads * head_dim, hidden_size)
+
+                weight_k = torch.cat([
                     q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
                     k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
-                ], dim=0).view(num_heads * head_dim * 2, hidden_size)
-                merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
-                merged_qk_proj.in_features = hidden_size
-                merged_qk_proj.out_features = num_heads * head_dim * 2
-                module.merged_qk_proj = merged_qk_proj
+                ], dim=0).view(num_heads * head_dim, hidden_size)
+
+                merged_q_proj = torch.nn.Linear(0, 0, False)
+                merged_q_proj.weight = torch.nn.Parameter(weight_q, requires_grad=False)
+                merged_q_proj.in_features = hidden_size
+                merged_q_proj.out_features = num_heads * head_dim
+                module.merged_q_proj = merged_q_proj
+
+                merged_k_proj = torch.nn.Linear(0, 0, False)
+                merged_k_proj.weight = torch.nn.Parameter(weight_k, requires_grad=False)
+                merged_k_proj.in_features = hidden_size
+                merged_k_proj.out_features = num_heads * head_dim
+                module.merged_k_proj = merged_k_proj
                 del module.q_proj
                 del module.k_proj
diff --git a/python/llm/src/bigdl/llm/transformers/models/yuan.py b/python/llm/src/bigdl/llm/transformers/models/yuan.py
index 8dc4f990..9dc39be6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/yuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/yuan.py
@@ -147,7 +147,7 @@ def yuan_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.merged_qk_proj, hidden_states):
+    if use_quantize_kv_cache(self.merged_q_proj, hidden_states):
         forward_function = yuan_attention_forward_quantized
     else:
         forward_function = yuan_attention_forward_origin
@@ -206,8 +206,8 @@ def yuan_attention_forward_quantized(
     else:
         hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
                                                          this_hidden_states, hidden_states.dtype)
-    qk_states = self.merged_qk_proj(hidden_states)
-    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = self.merged_q_proj(hidden_states)
+    key_states = self.merged_k_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -360,8 +360,8 @@ def yuan_attention_forward_origin(
     else:
         hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
                                                          this_hidden_states, hidden_states.dtype)
-    qk_states = self.merged_qk_proj(hidden_states)
-    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = self.merged_q_proj(hidden_states)
+    key_states = self.merged_k_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
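
Note (not part of the patch): a minimal standalone sketch of why the split is behavior-preserving. The rows that used to form the single merged weight are simply divided into two independent projection weights, so applying merged_q_proj and merged_k_proj separately reproduces the two halves previously obtained with torch.chunk on the merged output. Shapes below are illustrative assumptions, not values taken from the Yuan config.

    import torch
    import torch.nn.functional as F

    # Illustrative shapes only; the real values come from the model's config.
    num_heads, head_dim, hidden_size = 4, 8, 32
    q_weight = torch.randn(num_heads * head_dim, hidden_size)
    k_weight = torch.randn(num_heads * head_dim, hidden_size)

    # Old merged layout: even-indexed q/k heads first, odd-indexed q/k heads second.
    merged = torch.cat([
        q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
        k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
        q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
        k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
    ], dim=0).view(num_heads * head_dim * 2, hidden_size)

    # New layout: the same rows, split into two independent projection weights.
    weight_q = merged[:num_heads * head_dim]
    weight_k = merged[num_heads * head_dim:]

    x = torch.randn(1, 5, hidden_size)
    # What the old code produced via one projection plus torch.chunk ...
    q_old, k_old = torch.chunk(F.linear(x, merged), 2, dim=-1)
    # ... matches what the two separated projections produce.
    assert torch.allclose(q_old, F.linear(x, weight_q))
    assert torch.allclose(k_old, F.linear(x, weight_k))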