From eeecd9fc085754d3d9581ed14cac410683585858 Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Fri, 23 Feb 2024 17:21:23 +0800
Subject: [PATCH] Python style fix (#10230)

---
 .../src/bigdl/llm/transformers/models/yuan.py | 41 +++++++++++--------
 1 file changed, 24 insertions(+), 17 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/yuan.py b/python/llm/src/bigdl/llm/transformers/models/yuan.py
index b7420d9e..a48015b2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/yuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/yuan.py
@@ -53,14 +53,14 @@ def yuan_attention_forward(
     if use_cache:
         if is_first_step:
             if q_len >= 2:
-                inference_hidden_states_memory = hidden_states[ :, -2:, :]
+                inference_hidden_states_memory = hidden_states[:, -2:, :]
             else:
                 inference_hidden_states_memory[:, :, :] = 0
                 inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
         else:
             hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states),
-                                                                     dim=1))
+            inference_hidden_states_memory = \
+                copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
 
     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -72,16 +72,17 @@ def yuan_attention_forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     else:
-        hidden_states = self.lf_gate(hidden_states,before_hidden_states)
+        hidden_states = self.lf_gate(hidden_states, before_hidden_states)
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         qk_states = torch.cat([query_states, key_states], dim=-1)
-        qk_states = qk_states.view(bsz,q_len,self.num_heads,int(qk_states.shape[-1]//self.num_heads))
-        (query_states,key_states) = torch.chunk(qk_states, 2, dim=-1)
+        qk_states = qk_states.view(bsz, q_len,
+                                   self.num_heads,
+                                   int(qk_states.shape[-1]//self.num_heads))
+        (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
-
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
@@ -97,8 +98,9 @@ def yuan_attention_forward(
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-    past_key_value = (key_states, value_states,inference_hidden_states_memory) if use_cache else None
-    
+    past_key_value = \
+        (key_states, value_states, inference_hidden_states_memory) if use_cache else None
+
     if self.use_flash_attention:
         attn_weights = None
         query_states = query_states.transpose(1, 2)
@@ -108,20 +110,23 @@ def yuan_attention_forward(
         batch_size, seqlen_q = query_states.shape[0], query_states.shape[1]
         seqlen_k = key_states.shape[1]
 
-        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
+        q, k, v = \
+            [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
+
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q,
+                                    step=seqlen_q,
+                                    dtype=torch.int,
+                                    device=q.device)
 
-        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int,
-                                    device=q.device)
-        
         if self.training:
             invalidInputError(seqlen_k == seqlen_q,
-                "`seqlen_k` should be equal to `seqlen_q`, but is not")
+                              "`seqlen_k` should be equal to `seqlen_q`, but is not")
             cu_seqlens_k = cu_seqlens_q
             is_causal = self.causal_mask
         else:
             is_causal = seqlen_q == seqlen_k
-            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, 
-                                        step=seqlen_k, 
+            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k,
+                                        step=seqlen_k,
                                         dtype=torch.int,
                                         device=q.device)
             self.dropout = 0
@@ -150,7 +155,9 @@ def yuan_attention_forward(
 
         # upcast attention to fp32
         attn_weights = \
-            torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            torch.nn.functional.softmax(attn_weights,
+                                        dim=-1,
+                                        dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),