diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 6a7c9099..4225dc68 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -1187,7 +1187,6 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.yuan import yuan_attention_forward
         from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
-        from bigdl.llm.transformers.models.yuan import yuan_localized_filtering_forward
         convert_forward(model,
                         module.YuanAttention,
                         yuan_attention_forward
@@ -1196,7 +1195,4 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.YuanMLP,
                         yuan_mlp_forward
                         )
-        convert_forward(model,
-                        module.LocalizedFiltering,
-                        yuan_localized_filtering_forward)
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/models/yuan.py b/python/llm/src/bigdl/llm/transformers/models/yuan.py
index a874925f..5c07b9a7 100644
--- a/python/llm/src/bigdl/llm/transformers/models/yuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/yuan.py
@@ -54,22 +54,39 @@ def yuan_localized_filtering_forward(
     self,
     inputs: torch.Tensor,
     before_hidden_states: torch.Tensor,
+    dtype: torch.dtype,
 ):
     if self.conv1.weight.dtype != torch.half:
         self.half()
-        inputs = inputs.half()
-        if before_hidden_states is not None:
-            before_hidden_states = before_hidden_states.half()
 
-    invalidInputError(self.lf_conv2d_num_pad == 1, "padding must be 1")
-    if self.training:
-        lf_output = self._train_forward(inputs)
+    invalidInputError(not self.training, ("training is not supported for now, "
+                                          "please call model.eval() before inference"))
+    if before_hidden_states is None:
+        inputs = inputs.half()
+        lf_output = self._inference_forward(inputs, None)
     else:
-        lf_output = self._inference_forward(inputs, before_hidden_states)
+        # only change next token logic
+        bsz, seq_len, embed_dim = inputs.size()
+        seq_len_before, _, _ = before_hidden_states.size()
+        invalidInputError(seq_len == 1 and seq_len_before == 3,
+                          f"wrong sequence length: {seq_len} {seq_len_before}")
 
-    lf_output = lf_output.to(inputs.dtype)
+        residual = before_hidden_states[-1:, :, :]
+        inputs = before_hidden_states.view(3, 1, bsz, embed_dim).permute(2, 3, 0, 1)
+        output1 = self.conv1(inputs)
+        output2 = self.conv2(output1[:, :, 1:-1, :])
+        output2 = output2[:, :, 1:-1, :]
+        output2 = output2.view(1, bsz, embed_dim)
+
+        invalidInputError(output2.shape == residual.shape,
+                          f"wrong shape: {output2.shape} {residual.shape}")
+
+        lf_output = self.output_layernorm(output2 + residual)
+        lf_output = lf_output.transpose(0, 1)
+
+    lf_output = lf_output.to(dtype)
     return lf_output
@@ -137,44 +154,38 @@ def yuan_attention_forward(
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
 
-    if use_cache:
-        if past_key_value is None:
-            inference_hidden_states_memory = torch.empty(bsz, 2,
-                                                         hidden_states.shape[2],
-                                                         dtype=hidden_states.dtype)
-            is_first_step = True
-        else:
-            before_hidden_states = past_key_value[2]
+    invalidInputError(use_cache, "use_cache=True is needed")
+    invalidInputError(not self.use_shareqk, "use_shareqk is not supported for now")
 
-    if use_cache:
-        if is_first_step:
-            if q_len >= 2:
-                inference_hidden_states_memory = hidden_states[:, -2:, :]
-            else:
-                inference_hidden_states_memory[:, :, :] = 0
-                inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
+    if past_key_value is None:
+        is_first_step = True
+        if q_len >= 2:
+            before_hidden_states = hidden_states[:, -2:, :].transpose(0, 1).half()
         else:
-            hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = torch.cat((hidden_states_tmp,
-                                                        hidden_states), dim=1)
+            before_hidden_states = torch.zeros(2, bsz, self.hidden_size,
+                                               dtype=torch.half, device=hidden_states.device)
+            before_hidden_states[-1:, :, :] = hidden_states[:, -1:, :].transpose(0, 1)
+    else:
+        before_hidden_states = past_key_value[2]
+        this_hidden_states = torch.cat([
+            before_hidden_states,
+            hidden_states.transpose(0, 1).half(),
+        ], dim=0)
+        before_hidden_states = this_hidden_states[-2:, :, ]
 
     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    if self.use_shareqk:
-        # use_shareqk is disabled for now
-        qk_states = self.qk_proj(hidden_states).view(bsz, q_len, self.num_heads*self.head_dim)
-        query_key = qk_states.unsqueeze(2) * self.qk_weight + self.qk_bias
-        query_states, key_states = torch.unbind(query_key, dim=2)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    if is_first_step:
+        hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
+                                                         None, hidden_states.dtype)
     else:
-        hidden_states = self.lf_gate(hidden_states, before_hidden_states)
-        qk_states = self.merged_qk_proj(hidden_states)
-        (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
+                                                         this_hidden_states, hidden_states.dtype)
+    qk_states = self.merged_qk_proj(hidden_states)
+    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
@@ -229,7 +240,7 @@ def yuan_attention_forward(
         value_states = new_value_states
 
     past_key_value = \
-        (key_states, value_states, inference_hidden_states_memory) if use_cache else None
+        (key_states, value_states, before_hidden_states) if use_cache else None
 
     attn_weights = \
         torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
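
For reviewers, here is a minimal standalone sketch (plain PyTorch; the helper name and the toy shapes are illustrative, not part of this patch) of the rolling two-token cache that the rewritten `yuan_attention_forward` stores in `past_key_value[2]`: the first step seeds a `(2, bsz, hidden_size)` fp16 buffer from the prompt, and each decode step concatenates the current token's hidden state and keeps only the last two rows, so `yuan_localized_filtering_forward` always receives a 3-row `this_hidden_states` in its decode branch (matching the `seq_len_before == 3` check).

```python
import torch


def roll_before_hidden_states(hidden_states, cached, hidden_size):
    # Illustrative only -- not part of the patch. Mirrors how the new
    # yuan_attention_forward keeps the last two token hidden states in a
    # (2, bsz, hidden_size) fp16 buffer carried as past_key_value[2].
    bsz, q_len, _ = hidden_states.shape
    if cached is None:
        # first step: seed the buffer from the prompt
        if q_len >= 2:
            cached = hidden_states[:, -2:, :].transpose(0, 1).half()
        else:
            cached = torch.zeros(2, bsz, hidden_size, dtype=torch.half,
                                 device=hidden_states.device)
            cached[-1:, :, :] = hidden_states[:, -1:, :].transpose(0, 1)
        return cached, None
    # decode step: append the new token's hidden state, keep the last two rows
    this_hidden_states = torch.cat(
        [cached, hidden_states.transpose(0, 1).half()], dim=0)
    return this_hidden_states[-2:], this_hidden_states


bsz, hidden_size = 1, 8
prompt = torch.randn(bsz, 5, hidden_size)
cache, _ = roll_before_hidden_states(prompt, None, hidden_size)        # cache: (2, 1, 8)
token = torch.randn(bsz, 1, hidden_size)
cache, this_hs = roll_before_hidden_states(token, cache, hidden_size)  # this_hs: (3, 1, 8)
```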