optimize yuan 2.0 again (#10252)
parent 03b9c4930a
commit b4fa4ab46f

2 changed files with 51 additions and 44 deletions
@@ -1187,7 +1187,6 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.yuan import yuan_attention_forward
         from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
-        from bigdl.llm.transformers.models.yuan import yuan_localized_filtering_forward
         convert_forward(model,
                         module.YuanAttention,
                         yuan_attention_forward
@@ -1196,7 +1195,4 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.YuanMLP,
                         yuan_mlp_forward
                         )
-        convert_forward(model,
-                        module.LocalizedFiltering,
-                        yuan_localized_filtering_forward)
     return model
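The two hunks above remove the separate LocalizedFiltering patch: with this commit, yuan_attention_forward calls yuan_localized_filtering_forward(self.lf_gate, ...) directly as a plain function, so the extra import and the third convert_forward registration are no longer needed. For readers unfamiliar with the pattern, a minimal sketch of what a convert_forward-style helper does is shown below; patch_forward is a hypothetical stand-in, not the bigdl.llm implementation.

# Hypothetical sketch of the forward-patching pattern behind convert_forward:
# rebind the forward method of every submodule of the target class.
import types

import torch.nn as nn


def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    for module in model.modules():
        if isinstance(module, target_cls):
            # bind new_forward as a method so `self` is the submodule instance
            module.forward = types.MethodType(new_forward, module)


# usage mirroring the calls above:
# patch_forward(model, module.YuanAttention, yuan_attention_forward)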
@@ -54,22 +54,39 @@ def yuan_localized_filtering_forward(
     self,
     inputs: torch.Tensor,
     before_hidden_states: torch.Tensor,
+    dtype: torch.dtype,
 ):
     if self.conv1.weight.dtype != torch.half:
         self.half()
 
-    inputs = inputs.half()
-    if before_hidden_states is not None:
-        before_hidden_states = before_hidden_states.half()
-
     invalidInputError(self.lf_conv2d_num_pad == 1, "padding must be 1")
-    if self.training:
-        lf_output = self._train_forward(inputs)
+    invalidInputError(not self.training, ("training is not supported for now, "
+                                          "please call model.eval() before inference"))
+    if before_hidden_states is None:
+        inputs = inputs.half()
+        lf_output = self._inference_forward(inputs, None)
     else:
-        lf_output = self._inference_forward(inputs, before_hidden_states)
-
-    lf_output = lf_output.to(inputs.dtype)
+        # only change next token logic
+        bsz, seq_len, embed_dim = inputs.size()
+        seq_len_before, _, _ = before_hidden_states.size()
+        invalidInputError(seq_len == 1 and seq_len_before == 3,
+                          f"wrong sequence length: {seq_len} {seq_len_before}")
+
+        residual = before_hidden_states[-1:, :, :]
+        inputs = before_hidden_states.view(3, 1, bsz, embed_dim).permute(2, 3, 0, 1)
+
+        output1 = self.conv1(inputs)
+        output2 = self.conv2(output1[:, :, 1:-1, :])
+        output2 = output2[:, :, 1:-1, :]
+        output2 = output2.view(1, bsz, embed_dim)
+
+        invalidInputError(output2.shape == residual.shape,
+                          f"wrong shape: {output2.shape} {residual.shape}")
+
+        lf_output = self.output_layernorm(output2 + residual)
+        lf_output = lf_output.transpose(0, 1)
+
+    lf_output = lf_output.to(dtype)
     return lf_output
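The rewritten single-token branch above runs the two convolutions over the cached 3-step window and adds the most recent cached step back as a residual, instead of going through _inference_forward. Below is a standalone shape check of that flow; the Conv2d hyper-parameters (kernel (2, 1), padding (1, 0), channels halved then restored) are assumptions meant to mirror Yuan's LocalizedFiltering layers, and dtypes are kept in float32 for simplicity.

import torch
import torch.nn as nn

bsz, embed_dim = 2, 16
# cached window of the last 3 hidden states, laid out as (seq, bsz, dim)
before_hidden_states = torch.randn(3, bsz, embed_dim)
conv1 = nn.Conv2d(embed_dim, embed_dim // 2, (2, 1), padding=(1, 0))  # assumed config
conv2 = nn.Conv2d(embed_dim // 2, embed_dim, (2, 1), padding=(1, 0))  # assumed config

residual = before_hidden_states[-1:, :, :]                                    # (1, bsz, dim)
inputs = before_hidden_states.view(3, 1, bsz, embed_dim).permute(2, 3, 0, 1)  # (bsz, dim, 3, 1)

output1 = conv1(inputs)                  # padding 1, kernel height 2 -> (bsz, dim // 2, 4, 1)
output2 = conv2(output1[:, :, 1:-1, :])  # crop to height 2, conv back -> (bsz, dim, 3, 1)
output2 = output2[:, :, 1:-1, :]         # keep the single centre step -> (bsz, dim, 1, 1)
output2 = output2.view(1, bsz, embed_dim)

assert output2.shape == residual.shape   # both (1, bsz, dim), as the invalidInputError check expects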
@@ -137,44 +154,38 @@ def yuan_attention_forward(
 
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
 
-    if use_cache:
-        if past_key_value is None:
-            inference_hidden_states_memory = torch.empty(bsz, 2,
-                                                         hidden_states.shape[2],
-                                                         dtype=hidden_states.dtype)
-            is_first_step = True
-        else:
-            before_hidden_states = past_key_value[2]
+    invalidInputError(use_cache, "use_cache=True is needed")
+    invalidInputError(not self.use_shareqk, "use_shareqk is not supported for now")
 
-    if use_cache:
-        if is_first_step:
-            if q_len >= 2:
-                inference_hidden_states_memory = hidden_states[:, -2:, :]
-            else:
-                inference_hidden_states_memory[:, :, :] = 0
-                inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
+    if past_key_value is None:
+        is_first_step = True
+        if q_len >= 2:
+            before_hidden_states = hidden_states[:, -2:, :].transpose(0, 1).half()
         else:
-            hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = torch.cat((hidden_states_tmp,
-                                                        hidden_states), dim=1)
+            before_hidden_states = torch.zeros(2, bsz, self.hidden_size,
+                                               dtype=torch.half, device=hidden_states.device)
+            before_hidden_states[-1:, :, :] = hidden_states[:, -1:, :].transpose(0, 1)
+    else:
+        before_hidden_states = past_key_value[2]
+        this_hidden_states = torch.cat([
+            before_hidden_states,
+            hidden_states.transpose(0, 1).half(),
+        ], dim=0)
+        before_hidden_states = this_hidden_states[-2:, :, ]
 
     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-    if self.use_shareqk:
-        # use_shareqk is disabled for now
-        qk_states = self.qk_proj(hidden_states).view(bsz, q_len, self.num_heads*self.head_dim)
-        query_key = qk_states.unsqueeze(2) * self.qk_weight + self.qk_bias
-        query_states, key_states = torch.unbind(query_key, dim=2)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    if is_first_step:
+        hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
+                                                         None, hidden_states.dtype)
     else:
-        hidden_states = self.lf_gate(hidden_states, before_hidden_states)
-        qk_states = self.merged_qk_proj(hidden_states)
-        (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
+                                                         this_hidden_states, hidden_states.dtype)
+    qk_states = self.merged_qk_proj(hidden_states)
+    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
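The attention hunk above replaces the old (bsz, 2, hidden) inference_hidden_states_memory with a rolling two-step window stored as (2, bsz, hidden) half-precision tensors and carried in the third slot of past_key_value. A small standalone illustration of that rolling cache follows; dtype and device handling is omitted and the names are illustrative.

import torch

bsz, hidden_size = 2, 16

# prefill step: keep the last two positions, stored as (2, bsz, hidden)
hidden_states = torch.randn(bsz, 5, hidden_size)                 # (bsz, q_len, hidden)
before_hidden_states = hidden_states[:, -2:, :].transpose(0, 1)  # (2, bsz, hidden)

# one decode step: append the new token and keep only the most recent two
new_token = torch.randn(bsz, 1, hidden_size)
this_hidden_states = torch.cat([before_hidden_states,
                                new_token.transpose(0, 1)], dim=0)  # (3, bsz, hidden)
before_hidden_states = this_hidden_states[-2:, :, :]                # cached for the next step

# the 3-step window is what yuan_localized_filtering_forward's decode branch expects
assert this_hidden_states.shape[0] == 3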
@@ -229,7 +240,7 @@ def yuan_attention_forward(
         value_states = new_value_states
 
     past_key_value = \
-        (key_states, value_states, inference_hidden_states_memory) if use_cache else None
+        (key_states, value_states, before_hidden_states) if use_cache else None
 
     attn_weights = \
         torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)