Python style fix (#10230)
parent e511bbd8f1
commit eeecd9fc08
1 changed file with 24 additions and 17 deletions
@@ -59,8 +59,8 @@ def yuan_attention_forward(
                 inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
         else:
             hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states),
-                                                                     dim=1))
+            inference_hidden_states_memory = \
+                copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
 
     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
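For readers skimming the change: the wrapped copy.deepcopy line maintains what appears to be Yuan's rolling memory of the last two hidden states, carried alongside the KV cache. A minimal sketch of that decode-step branch, with made-up batch and hidden sizes:

import copy
import torch

# Made-up shapes: batch 1, hidden size 8, one newly decoded token.
before_hidden_states = torch.randn(1, 2, 8)   # memory carried in the cache
hidden_states = torch.randn(1, 1, 8)          # hidden state of the new token

# Keep the most recent remembered token and append the new one, so the
# memory again holds exactly the last two hidden states.
hidden_states_tmp = before_hidden_states[:, -1:, :]
inference_hidden_states_memory = \
    copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))

assert inference_hidden_states_memory.shape == (1, 2, 8)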
@@ -76,12 +76,13 @@ def yuan_attention_forward(
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
     qk_states = torch.cat([query_states, key_states], dim=-1)
-    qk_states = qk_states.view(bsz,q_len,self.num_heads,int(qk_states.shape[-1]//self.num_heads))
+    qk_states = qk_states.view(bsz, q_len,
+                               self.num_heads,
+                               int(qk_states.shape[-1]//self.num_heads))
     (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
     query_states = query_states.transpose(1, 2)
     key_states = key_states.transpose(1, 2)
 
-
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
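The split view call only rewraps the existing reshape. A toy walk-through of the cat, view, chunk shape flow (all dimensions invented), useful for checking that the reformatting is behavior-preserving:

import torch

bsz, q_len, num_heads, head_dim = 2, 4, 3, 5
query_states = torch.randn(bsz, q_len, num_heads * head_dim)
key_states = torch.randn(bsz, q_len, num_heads * head_dim)

qk_states = torch.cat([query_states, key_states], dim=-1)        # (2, 4, 30)
qk_states = qk_states.view(bsz, q_len,
                           num_heads,
                           int(qk_states.shape[-1]//num_heads))  # (2, 4, 3, 10)

# Each per-head slice is halved to recover separate query and key tensors.
(query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
assert query_states.shape == (bsz, num_heads, q_len, head_dim)   # (2, 3, 4, 5)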
@@ -97,7 +98,8 @@ def yuan_attention_forward(
         key_states = torch.cat([past_key_value[0], key_states], dim=2)
         value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-    past_key_value = (key_states, value_states,inference_hidden_states_memory) if use_cache else None
+    past_key_value = \
+        (key_states, value_states, inference_hidden_states_memory) if use_cache else None
 
     if self.use_flash_attention:
         attn_weights = None
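The continuation line stores a three-element cache tuple: accumulated keys, accumulated values, and the hidden-state memory. A toy decode loop (shapes invented) showing how the tuple grows across steps:

import torch

bsz, num_heads, head_dim = 1, 2, 4
past_key_value = None

for step in range(3):
    key_states = torch.randn(bsz, num_heads, 1, head_dim)     # one new token
    value_states = torch.randn(bsz, num_heads, 1, head_dim)
    memory = torch.randn(bsz, 2, num_heads * head_dim)        # stand-in for the hidden-state memory

    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states, memory)

print(past_key_value[0].shape)  # torch.Size([1, 2, 3, 4]); the seq dim grew each step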
@@ -108,9 +110,12 @@ def yuan_attention_forward(
         batch_size, seqlen_q = query_states.shape[0], query_states.shape[1]
         seqlen_k = key_states.shape[1]
 
-        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
+        q, k, v = \
+            [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
 
-        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int,
-                                    device=q.device)
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q,
+                                    step=seqlen_q,
+                                    dtype=torch.int,
+                                    device=q.device)
 
         if self.training:
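cu_seqlens_q is the cumulative-sequence-length vector that flash-attention's variable-length interface expects over the flattened (batch * seq) layout; for equal-length sequences it is just multiples of seqlen_q. A standalone sketch with toy sizes, using reshape in place of einops.rearrange:

import torch

batch_size, seqlen_q = 3, 5
q = torch.randn(batch_size, seqlen_q, 2, 4)   # (batch, seq, heads, head_dim), toy sizes

# Flatten batch and sequence dims, as the rearrange in the diff does.
q = q.reshape(batch_size * seqlen_q, 2, 4)

# Row i of the flattened tensor belongs to sample i // seqlen_q, so the
# per-sample boundaries are 0, seqlen_q, 2*seqlen_q, ...
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q,
                            step=seqlen_q,
                            dtype=torch.int,
                            device=q.device)
print(cu_seqlens_q)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)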
@@ -150,7 +155,9 @@ def yuan_attention_forward(
 
         # upcast attention to fp32
         attn_weights = \
-            torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            torch.nn.functional.softmax(attn_weights,
+                                        dim=-1,
+                                        dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
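The reflowed call keeps the upcast noted in the comment: the softmax is evaluated in float32 and the result cast back to the working dtype, which avoids half-precision rounding in the normalization. A minimal sketch of the same pattern with toy scores:

import torch

# Toy half-precision attention scores.
attn_weights = torch.tensor([[4.0, 1.0, -2.0]], dtype=torch.float16)

# Softmax in float32 for numerical headroom, then back to the working dtype.
attn_weights = torch.nn.functional.softmax(attn_weights,
                                           dim=-1,
                                           dtype=torch.float32).to(torch.float16)
print(attn_weights)  # float16 output, normalized with fp32 precision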