Python style fix (#10230)

parent e511bbd8f1
commit eeecd9fc08

1 changed file with 24 additions and 17 deletions
@@ -53,14 +53,14 @@ def yuan_attention_forward(
     if use_cache:
         if is_first_step:
             if q_len >= 2:
-                inference_hidden_states_memory = hidden_states[ :, -2:, :]
+                inference_hidden_states_memory = hidden_states[:, -2:, :]
             else:
                 inference_hidden_states_memory[:, :, :] = 0
                 inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
         else:
             hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states),
-                                                                     dim=1))
+            inference_hidden_states_memory = \
+                copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
 
     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
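The hunk above only rewraps lines, but the two-token memory cache it touches is worth seeing on its own. Below is a minimal, self-contained sketch of that logic, assuming hidden_states is (batch, seq_len, hidden_size); the standalone update_memory function is hypothetical, introduced here only for illustration.

import copy
import torch

def update_memory(hidden_states, before_hidden_states, is_first_step):
    # Hypothetical sketch of the memory update in yuan_attention_forward
    # (assumption: the cache keeps the last two positions of the hidden states).
    if is_first_step:
        if hidden_states.size(1) >= 2:
            # Prompt has at least two tokens: remember the last two positions.
            memory = hidden_states[:, -2:, :]
        else:
            # Single-token prompt: zero the memory, store the token in the last slot.
            memory = torch.zeros(hidden_states.size(0), 2, hidden_states.size(2),
                                 dtype=hidden_states.dtype)
            memory[:, -1:, :] = hidden_states[:, -1:, :]
    else:
        # Decoding steps: prepend the previously cached position to the new token(s).
        hidden_states_tmp = before_hidden_states[:, -1:, :]
        memory = copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
    return memory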
@@ -72,16 +72,17 @@ def yuan_attention_forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     else:
-        hidden_states = self.lf_gate(hidden_states,before_hidden_states)
+        hidden_states = self.lf_gate(hidden_states, before_hidden_states)
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         qk_states = torch.cat([query_states, key_states], dim=-1)
-        qk_states = qk_states.view(bsz,q_len,self.num_heads,int(qk_states.shape[-1]//self.num_heads))
-        (query_states,key_states) =  torch.chunk(qk_states, 2, dim=-1)
+        qk_states = qk_states.view(bsz, q_len,
+                                   self.num_heads,
+                                   int(qk_states.shape[-1]//self.num_heads))
+        (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
-
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
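The fused reshape-and-split being rewrapped in this hunk is compact enough to demo in isolation. A minimal sketch, assuming example dimensions (none of the sizes below come from the diff):

import torch

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
query_states = torch.randn(bsz, q_len, num_heads * head_dim)
key_states = torch.randn(bsz, q_len, num_heads * head_dim)

# Concatenate Q and K on the channel dim, view them per head, then split back.
qk_states = torch.cat([query_states, key_states], dim=-1)
qk_states = qk_states.view(bsz, q_len,
                           num_heads,
                           int(qk_states.shape[-1] // num_heads))
query_states, key_states = torch.chunk(qk_states, 2, dim=-1)

# Move to the (batch, heads, seq, head_dim) layout used by the attention math.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
assert query_states.shape == (bsz, num_heads, q_len, head_dim)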
@@ -97,8 +98,9 @@ def yuan_attention_forward(
         key_states = torch.cat([past_key_value[0], key_states], dim=2)
         value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-    past_key_value = (key_states, value_states,inference_hidden_states_memory) if use_cache else None
-    
+    past_key_value = \
+        (key_states, value_states, inference_hidden_states_memory) if use_cache else None
+
     if self.use_flash_attention:
         attn_weights = None
         query_states = query_states.transpose(1, 2)
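For context, a small sketch of what this reformatted cache step does, with illustrative shapes (all sizes below are assumptions, not values from the diff):

import torch

bsz, heads, past_len, new_len, head_dim = 1, 4, 5, 1, 16
use_cache = True
past_key_value = (torch.randn(bsz, heads, past_len, head_dim),
                  torch.randn(bsz, heads, past_len, head_dim))
key_states = torch.randn(bsz, heads, new_len, head_dim)
value_states = torch.randn(bsz, heads, new_len, head_dim)
inference_hidden_states_memory = torch.randn(bsz, 2, heads * head_dim)

# Append the new step's K/V to the cached K/V along the sequence dimension,
# then carry the hidden-state memory alongside them in the cache tuple.
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = \
    (key_states, value_states, inference_hidden_states_memory) if use_cache else None
assert past_key_value[0].shape[2] == past_len + new_len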
@@ -108,20 +110,23 @@ def yuan_attention_forward(
         batch_size, seqlen_q = query_states.shape[0], query_states.shape[1]
         seqlen_k = key_states.shape[1]
 
-        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
+        q, k, v = \
+            [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
 
-        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int,
-                                device=q.device)
-                                
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q,
+                                    step=seqlen_q,
+                                    dtype=torch.int,
+                                    device=q.device)
+
         if self.training:
             invalidInputError(seqlen_k == seqlen_q,
-                "`seqlen_k` should be equal to `seqlen_q`, but is not")
+                              "`seqlen_k` should be equal to `seqlen_q`, but is not")
             cu_seqlens_k = cu_seqlens_q
             is_causal = self.causal_mask
         else:
             is_causal = seqlen_q == seqlen_k
-            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, 
-                                        step=seqlen_k, 
+            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k,
+                                        step=seqlen_k,
                                         dtype=torch.int,
                                         device=q.device)
             self.dropout = 0
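The cu_seqlens tensors rewrapped here are the cumulative sequence boundaries that varlen flash-attention kernels take in place of padding masks. A minimal sketch for a fixed-length batch (the sizes are illustrative, not from the diff):

import torch

batch_size, seqlen_q = 3, 4
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q,
                            step=seqlen_q,
                            dtype=torch.int)
# tensor([ 0,  4,  8, 12], dtype=torch.int32): sequence i occupies
# [cu_seqlens_q[i], cu_seqlens_q[i + 1]) in the flattened (b s) layout
# produced by rearrange(x, 'b s ... -> (b s) ...').
print(cu_seqlens_q)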
@@ -150,7 +155,9 @@ def yuan_attention_forward(
 
         # upcast attention to fp32
         attn_weights = \
-            torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            torch.nn.functional.softmax(attn_weights,
+                                        dim=-1,
+                                        dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
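The line being rewrapped implements the usual fp32 softmax upcast: attention scores are normalized in float32 for numerical stability, then cast back to the working dtype before the value matmul. A runnable sketch with assumed shapes:

import torch

# Half-precision attention scores and values (shapes are illustrative).
attn_weights = torch.randn(1, 4, 8, 8, dtype=torch.float16)
value_states = torch.randn(1, 4, 8, 16, dtype=torch.float16)

# Softmax in float32, then cast back to the activations' dtype.
attn_weights = \
    torch.nn.functional.softmax(attn_weights,
                                dim=-1,
                                dtype=torch.float32).to(value_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)  # (1, 4, 8, 16)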