From 426660b88ee7fcdc5b6e6cd453467fd206239b8a Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 21 Dec 2023 17:53:29 +0800
Subject: [PATCH] simplify qwen attention (#9747)

---
 .../src/bigdl/llm/transformers/models/qwen.py | 97 ++++++-------------
 1 file changed, 28 insertions(+), 69 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index a5aada9e..d107ac61 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -74,6 +74,9 @@ def qwen_attention_forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
 ):
+    invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
+                      "flash attn and kv_cache quantization are not supported")
+
     mixed_x_layer = self.c_attn(hidden_states)
 
     query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@@ -119,12 +122,10 @@ def qwen_attention_forward(
     bsz, _, n_heads, head_dim = key.size()
 
     if layer_past is not None:
-        kv_seq_len += layer_past[0].shape[1]
-        # past_key, past_value = layer_past[0], layer_past[1]
-        # key = torch.cat((past_key, key), dim=1)
-        # value = torch.cat((past_value, value), dim=1)
-        cache_k = layer_past[0].transpose(1, 2)
-        cache_v = layer_past[1].transpose(1, 2)
+        cache_k, cache_v = layer_past[0], layer_past[1]
+        cache_k = cache_k.transpose(1, 2)
+        cache_v = cache_v.transpose(1, 2)
+        kv_seq_len += cache_k.shape[2]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
@@ -141,8 +142,8 @@ def qwen_attention_forward(
         key_states, value_states = append_kv_cache(cache_k, cache_v,
                                                    key.transpose(1, 2),
                                                    value.transpose(1, 2))
-        key = key_states.transpose(1, 2)
-        value = value_states.transpose(1, 2)
+        key = key_states
+        value = value_states
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(bsz,
@@ -154,80 +155,38 @@ def qwen_attention_forward(
                                                          device=hidden_states.device)
         new_key_states[:] = key.transpose(1, 2)
         new_value_states[:] = value.transpose(1, 2)
-        key = new_key_states.transpose(1, 2)
-        value = new_value_states.transpose(1, 2)
+        key = new_key_states
+        value = new_value_states
 
-    if use_cache:
-        present = (key, value)
-    else:
-        present = None
-
-    key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+    query_size, key_size = query.size(1), key.size(2)
     if key_size > self.seq_length and self.use_logn_attn and not self.training:
-        if self.use_cache_quantization:
-            seq_start = key[0].size(2) - query.size(1)
-            seq_end = key[0].size(2)
-        else:
-            seq_start = key.size(1) - query.size(1)
-            seq_end = key.size(1)
+        seq_start = key_size - query_size
+        seq_end = key_size
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
-
-    if (
-        self.use_flash_attn
-        and flash_attn_unpadded_func is not None
-        and not self.is_fp32
-        and query.is_cuda
-    ):
-        q, k, v = query, key, value
-        attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
+    if query_size == key_size:
+        causal_mask = torch.tril(
+            torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+        ).view(1, 1, key_size, key_size)
     else:
-        key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
-        if query.size(1) == key_size:
-            causal_mask = torch.tril(
-                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-            ).view(1, 1, key_size, key_size)
-        else:
-            causal_mask = None
-        query = query.permute(0, 2, 1, 3)
-        if not self.use_cache_quantization:
-            key = key.permute(0, 2, 1, 3)
-            value = value.permute(0, 2, 1, 3)
-        if (
-            causal_mask is None
-            and self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-            and not query.is_cuda
-        ):
-            invalidOperationError(False,
-                                  None,
-                                  None,
-                                  Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
+        causal_mask = None
+    query = query.transpose(1, 2)
 
-        attn_output, attn_weight = self._attn(
-            query, key, value, causal_mask, attention_mask, head_mask
-        )
+    attn_output, attn_weight = self._attn(
+        query, key, value, causal_mask, attention_mask, head_mask
+    )
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
 
     attn_output = self.c_proj(context_layer)
 
-    outputs = (attn_output, present)
+    if use_cache:
+        outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
+    else:
+        outputs = (attn_output, None)
     if output_attentions:
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-        ):
-            invalidInputError(False,
-                              f"Cannot output attentions while using flash-attn")
-        elif not self.use_cache_quantization and SUPPORT_TORCH2:
-            invalidInputError(False,
-                              f"Cannot output attentions while using scaled_dot_product_attention")
-        else:
-            outputs += (attn_weight,)
+        outputs += (attn_weight,)
 
     return outputs
 
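
Reviewer note, not part of the patch: the simplified path keeps key/value in the
(bsz, n_heads, seq_len, head_dim) layout once the KV cache has been updated, builds a
full lower-triangular causal mask only when the query and key lengths match (prefill),
skips the mask during cached decode, and only transposes back to (bsz, seq_len,
n_heads, head_dim) when packing the cache into outputs. Below is a minimal sketch of
that convention in plain PyTorch; toy_qwen_attn and its tensor names are hypothetical,
and torch.cat stands in for the pre-allocated init_kv_cache/extend_kv_cache/
append_kv_cache helpers used by the real code.

    import torch

    def toy_qwen_attn(query, key, value, layer_past=None):
        # query/key/value: (bsz, seq_len, n_heads, head_dim), the layout produced
        # right after rotary embedding in qwen_attention_forward.
        key = key.transpose(1, 2)      # -> (bsz, n_heads, seq_len, head_dim)
        value = value.transpose(1, 2)
        if layer_past is not None:
            # layer_past stores tensors as (bsz, seq_len, n_heads, head_dim);
            # torch.cat replaces the pre-allocated append path for illustration only.
            past_k, past_v = layer_past
            key = torch.cat((past_k.transpose(1, 2), key), dim=2)
            value = torch.cat((past_v.transpose(1, 2), value), dim=2)

        query_size, key_size = query.size(1), key.size(2)
        if query_size == key_size:
            # prefill: build the full lower-triangular causal mask
            causal_mask = torch.tril(
                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
            ).view(1, 1, key_size, key_size)
        else:
            # cached decode: the new token may attend to every cached position
            causal_mask = None

        query = query.transpose(1, 2)  # -> (bsz, n_heads, q_len, head_dim)
        attn = torch.matmul(query, key.transpose(-1, -2)) / (query.size(-1) ** 0.5)
        if causal_mask is not None:
            attn = attn.masked_fill(~causal_mask, torch.finfo(attn.dtype).min)
        attn = torch.softmax(attn, dim=-1)
        attn_output = torch.matmul(attn, value)                # (bsz, n_heads, q_len, head_dim)
        present = (key.transpose(1, 2), value.transpose(1, 2)) # back to (bsz, seq_len, n_heads, head_dim)
        return attn_output, present

    # prefill, then one decode step reusing the returned cache
    q = torch.randn(1, 5, 2, 8)
    out, past = toy_qwen_attn(q, torch.randn(1, 5, 2, 8), torch.randn(1, 5, 2, 8))
    q1 = torch.randn(1, 1, 2, 8)
    out1, _ = toy_qwen_attn(q1, torch.randn(1, 1, 2, 8), torch.randn(1, 1, 2, 8), layer_past=past)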