Python style fix (#10230)

Yuwen Hu 2024-02-23 17:21:23 +08:00 committed by GitHub
parent e511bbd8f1
commit eeecd9fc08

@@ -53,14 +53,14 @@ def yuan_attention_forward(
     if use_cache:
         if is_first_step:
             if q_len >= 2:
-                inference_hidden_states_memory = hidden_states[ :, -2:, :]
+                inference_hidden_states_memory = hidden_states[:, -2:, :]
             else:
                 inference_hidden_states_memory[:, :, :] = 0
                 inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
         else:
             hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states),
-                                                                     dim=1))
+            inference_hidden_states_memory = \
+                copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
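
Note: the inference_hidden_states_memory touched in this hunk is a rolling cache of the last two positions of hidden_states. A minimal, self-contained sketch of that caching pattern follows; the function name update_memory and the shapes are illustrative only, not part of the patched module.

import torch

def update_memory(hidden_states, memory=None, is_first_step=True):
    # Keep a (bsz, 2, hidden) window of the most recent hidden states,
    # mirroring the inference_hidden_states_memory logic in the hunk above.
    if is_first_step:
        if hidden_states.size(1) >= 2:
            return hidden_states[:, -2:, :].clone()
        # Prompt of length 1: zero-filled window, keep only the last token.
        memory = torch.zeros(hidden_states.size(0), 2, hidden_states.size(2))
        memory[:, -1:, :] = hidden_states[:, -1:, :]
        return memory
    # Decoding step: previous last token followed by the new token.
    return torch.cat((memory[:, -1:, :], hidden_states), dim=1)

bsz, hidden = 2, 8
memory = update_memory(torch.randn(bsz, 5, hidden))          # prefill
memory = update_memory(torch.randn(bsz, 1, hidden), memory,  # one decode step
                       is_first_step=False)
print(memory.shape)  # torch.Size([2, 2, 8])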
@@ -72,16 +72,17 @@ def yuan_attention_forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     else:
-        hidden_states = self.lf_gate(hidden_states,before_hidden_states)
+        hidden_states = self.lf_gate(hidden_states, before_hidden_states)
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         qk_states = torch.cat([query_states, key_states], dim=-1)
-        qk_states = qk_states.view(bsz,q_len,self.num_heads,int(qk_states.shape[-1]//self.num_heads))
-        (query_states,key_states) = torch.chunk(qk_states, 2, dim=-1)
+        qk_states = qk_states.view(bsz, q_len,
+                                   self.num_heads,
+                                   int(qk_states.shape[-1]//self.num_heads))
+        (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
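
Note: the else branch above projects q and k, concatenates them, reshapes per head, and chunks them back apart. A minimal sketch of that concat-then-chunk pattern with made-up sizes (q_proj/k_proj here are plain Linear layers, standing in for the module's projections):

import torch

bsz, q_len, num_heads, head_dim = 2, 4, 8, 16
hidden = num_heads * head_dim
q_proj = torch.nn.Linear(hidden, hidden, bias=False)
k_proj = torch.nn.Linear(hidden, hidden, bias=False)

hidden_states = torch.randn(bsz, q_len, hidden)
# Concatenate the two projections, split into heads, then chunk back into q/k.
qk_states = torch.cat([q_proj(hidden_states), k_proj(hidden_states)], dim=-1)
qk_states = qk_states.view(bsz, q_len, num_heads, qk_states.shape[-1] // num_heads)
query_states, key_states = torch.chunk(qk_states, 2, dim=-1)
print(query_states.transpose(1, 2).shape)  # torch.Size([2, 8, 4, 16])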
@@ -97,8 +98,9 @@ def yuan_attention_forward(
         key_states = torch.cat([past_key_value[0], key_states], dim=2)
         value_states = torch.cat([past_key_value[1], value_states], dim=2)
-    past_key_value = (key_states, value_states,inference_hidden_states_memory) if use_cache else None
+    past_key_value = \
+        (key_states, value_states, inference_hidden_states_memory) if use_cache else None
     if self.use_flash_attention:
         attn_weights = None
         query_states = query_states.transpose(1, 2)
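
Note: as this hunk shows, the returned cache is a three-element tuple: the usual key/value tensors plus the hidden-states memory. A minimal sketch of the append-and-pack step, with illustrative shapes only:

import torch

bsz, num_heads, head_dim = 1, 4, 8
past_k = torch.randn(bsz, num_heads, 10, head_dim)   # 10 cached positions
past_v = torch.randn(bsz, num_heads, 10, head_dim)
memory = torch.randn(bsz, 2, num_heads * head_dim)   # last-two-token memory

new_k = torch.randn(bsz, num_heads, 1, head_dim)     # one decoded token
new_v = torch.randn(bsz, num_heads, 1, head_dim)

# Append along the sequence dimension (dim=2), then pack the 3-tuple cache.
key_states = torch.cat([past_k, new_k], dim=2)
value_states = torch.cat([past_v, new_v], dim=2)
past_key_value = (key_states, value_states, memory)
print(key_states.shape)  # torch.Size([1, 4, 11, 8])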
@@ -108,20 +110,23 @@ def yuan_attention_forward(
         batch_size, seqlen_q = query_states.shape[0], query_states.shape[1]
         seqlen_k = key_states.shape[1]
-        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
-        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int,
-                                    device=q.device)
+        q, k, v = \
+            [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q,
+                                    step=seqlen_q,
+                                    dtype=torch.int,
+                                    device=q.device)
         if self.training:
             invalidInputError(seqlen_k == seqlen_q,
                               "`seqlen_k` should be equal to `seqlen_q`, but is not")
             cu_seqlens_k = cu_seqlens_q
             is_causal = self.causal_mask
         else:
             is_causal = seqlen_q == seqlen_k
             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k,
                                         step=seqlen_k,
                                         dtype=torch.int,
                                         device=q.device)
             self.dropout = 0
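
Note: the cu_seqlens tensors built in this hunk are the cumulative sequence boundaries that flash-attention's unpadded (varlen) interface expects once the batch has been flattened by rearrange. A minimal sketch of just that index construction, without invoking the flash-attention kernel itself, using made-up shapes:

import torch
from einops import rearrange

batch_size, seqlen_q, num_heads, head_dim = 3, 5, 2, 4
query_states = torch.randn(batch_size, seqlen_q, num_heads, head_dim)

# Flatten (batch, seq) into one packed dimension, as in the hunk above.
q = rearrange(query_states, 'b s ... -> (b s) ...')
print(q.shape)  # torch.Size([15, 2, 4])

# Cumulative boundaries: one entry per sequence start, plus the final end.
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q,
                            step=seqlen_q, dtype=torch.int, device=q.device)
print(cu_seqlens_q)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)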
@@ -150,7 +155,9 @@ def yuan_attention_forward(
         # upcast attention to fp32
-        attn_weights = \
-            torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = \
+            torch.nn.functional.softmax(attn_weights,
+                                        dim=-1,
+                                        dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
         invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
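
Note: the last hunk only re-wraps the upcast-softmax line. For reference, a minimal standalone sketch of the same pattern (compute the softmax in float32 for numerical stability, then cast back to the working dtype), with illustrative shapes:

import torch

bsz, num_heads, q_len, kv_len = 1, 4, 6, 6
attn_weights = torch.randn(bsz, num_heads, q_len, kv_len, dtype=torch.float16)

# Upcast to fp32 for a stable softmax, then cast back to the original dtype.
probs = torch.nn.functional.softmax(attn_weights,
                                    dim=-1,
                                    dtype=torch.float32).to(attn_weights.dtype)
print(probs.dtype, probs.sum(dim=-1)[0, 0, 0].item())  # torch.float16, ~1.0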