From b8aee7bb1b60852cb45871caa9d79cca49a296bb Mon Sep 17 00:00:00 2001
From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com>
Date: Thu, 12 Oct 2023 15:49:42 +0800
Subject: [PATCH] LLM: Fix Qwen kv_cache optimization (#9148)

* first commit

* ut pass

* accelerate rotate half by using common util function

* fix style
---
 .../src/bigdl/llm/transformers/models/qwen.py | 126 ++++++++----
 1 file changed, 59 insertions(+), 67 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index ed2c3e51..3a50f9c6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py
+# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/faf3ff60438d724a7eb78ebed7e2f7c7330c6bd8/modeling_qwen.py
 #
 # Copyright (c) Alibaba Cloud.
 #
@@ -37,6 +37,7 @@ except ImportError:
     rearrange = None
 
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half
 from bigdl.llm.utils.common import invalidInputError
 
 apply_rotary_emb_func = None
@@ -48,34 +49,22 @@ logger = logging.get_logger(__name__)
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
-def _rotate_half(x):
-    from einops import rearrange
-
-    x = rearrange(x, "... (j d) -> ... j d", j=2)
-    x1, x2 = x.unbind(dim=-2)
-    return torch.cat((-x2, x1), dim=-1)
-
-
 def apply_rotary_pos_emb(t, freqs):
-    if apply_rotary_emb_func is not None:
-        t_ = t.float()
-        freqs = freqs.squeeze(0).squeeze(1)
-        cos = freqs[:, : freqs.shape[-1] // 2].cos()
-        sin = freqs[:, : freqs.shape[-1] // 2].sin()
-        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
-        return output
-    else:
-        rot_dim = freqs.shape[-1]
-        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
-        t_ = t_.float()
-        t_pass_ = t_pass_.float()
-        t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
-        return torch.cat((t_, t_pass_), dim=-1).type_as(t)
+    cos, sin = freqs
+    rot_dim = freqs[0].shape[-1]
+    cos, sin = freqs
+    t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
+    t_ = t_.float()
+    t_pass_ = t_pass_.float()
+    t_ = (t_ * cos) + (rotate_half(t_) * sin)
+    return torch.cat((t_, t_pass_), dim=-1).type_as(t)
 
 
 def qwen_attention_forward(
     self,
     hidden_states: Optional[Tuple[torch.FloatTensor]],
+    rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+    registered_causal_mask: Optional[torch.Tensor] = None,
     layer_past: Optional[Tuple[torch.Tensor]] = None,
     attention_mask: Optional[torch.FloatTensor] = None,
     head_mask: Optional[torch.FloatTensor] = None,
@@ -93,42 +82,33 @@ def qwen_attention_forward(
 
     kv_seq_len = hidden_states.size()[1]
 
-    if layer_past:
-        # layer past[0] shape: bs * seq_len * head_num * dim
-        kv_seq_len += layer_past[0].shape[1]
-    if (
-        self.use_dynamic_ntk
-        and kv_seq_len == hidden_states.size()[1]
-        and not self.training
-    ):
-        context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
-        ntk_alpha = 2 ** math.ceil(context_value) - 1
-        ntk_alpha = max(ntk_alpha, 1)
-        self._ntk_cached = ntk_alpha
-    else:
-        ntk_alpha = self._ntk_cached
-    rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
-        hidden_states.device
-    )
-
-    if rotary_pos_emb is not None:
-        if isinstance(rotary_pos_emb, tuple):
-            rotary_pos_emb = rotary_pos_emb
-        else:
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-
-    if rotary_pos_emb is not None:
-        q_pos_emb, k_pos_emb = rotary_pos_emb
-        # Slice the pos emb for current inference
+    if rotary_pos_emb_list is not None:
         cur_len = query.shape[1]
-        q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
-        k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
-        query = apply_rotary_pos_emb(query, q_pos_emb)
-        key = apply_rotary_pos_emb(key, k_pos_emb)
+        if len(rotary_pos_emb_list) == 1:
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            q_pos_emb, k_pos_emb = rotary_pos_emb
+            # Slice the pos emb for current inference
+            query = apply_rotary_pos_emb(query, q_pos_emb)
+            key = apply_rotary_pos_emb(key, k_pos_emb)
+        else:
+            query_list = []
+            key_list = []
+            for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
+                key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+            query = torch.cat(query_list, dim=0)
+            key = torch.cat(key_list, dim=0)
 
     bsz, _, n_heads, head_dim = key.size()
 
     if layer_past is not None:
+        kv_seq_len += layer_past[0].shape[1]
         # past_key, past_value = layer_past[0], layer_past[1]
         # key = torch.cat((past_key, key), dim=1)
         # value = torch.cat((past_value, value), dim=1)
@@ -137,7 +117,7 @@ def qwen_attention_forward(
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                       self.num_heads,  # Support GQA
+                                                       self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
                                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
@@ -172,10 +152,12 @@ def qwen_attention_forward(
         present = None
 
     if self.use_logn_attn and not self.training:
-        if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
-            self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
-        seq_start = key.size(1) - query.size(1)
-        seq_end = key.size(1)
+        if self.use_cache_quantization:
+            seq_start = key[0].size(2) - query.size(1)
+            seq_end = key[0].size(2)
+        else:
+            seq_start = key.size(1) - query.size(1)
+            seq_end = key.size(1)
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
         query = query * logn_tensor.expand_as(query)
 
@@ -186,23 +168,33 @@ def qwen_attention_forward(
         and query.is_cuda
     ):
         q, k, v = query, key, value
-        context_layer = self.core_attention_flash(q, k, v)
+        context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
+
+        # b s h d -> b s (h d)
+        context_layer = context_layer.flatten(2, 3).contiguous()
 
-        context_layer = rearrange(
-            context_layer, "b s h d -> b s (h d)"
-        ).contiguous()
     else:
         query = query.permute(0, 2, 1, 3)
-        key = key.permute(0, 2, 1, 3)
-        value = value.permute(0, 2, 1, 3)
+        if not self.use_cache_quantization:
+            key = key.permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 1, 3)
+        if (
+            registered_causal_mask is None
+            and self.use_flash_attn
+            and flash_attn_unpadded_func is not None
+            and not self.is_fp32
+            and not query.is_cuda
+        ):
+            invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
         attn_output, attn_weight = self._attn(
-            query, key, value, attention_mask, head_mask
+            query, key, value, registered_causal_mask, attention_mask, head_mask
         )
         context_layer = self._merge_heads(
             attn_output, self.num_heads, self.head_dim
         )
 
     attn_output = self.c_proj(context_layer)
+
     outputs = (attn_output, present)
     if output_attentions:
         if (
@@ -210,7 +202,7 @@ def qwen_attention_forward(
             and flash_attn_unpadded_func is not None
             and not self.is_fp32
         ):
-            invalidInputError("Cannot output attentions while using flash-attn")
+            invalidInputError(False, "Cannot output attentions while using flash-attn")
         else:
             outputs += (attn_weight,)
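
Note for reviewers: the refactor above makes apply_rotary_pos_emb take a
precomputed (cos, sin) pair and reuse the shared rotate_half util instead of
the local einops-based _rotate_half. Below is a minimal standalone sketch of
that path; the rotate_half body, the toy rotary frequencies, and the tensor
shapes are illustrative assumptions (the standard half-rotation layout), not
the exact BigDL helpers or Qwen's NTK-aware rotary computation.

    import torch

    def rotate_half(x):
        # standard half-rotation: (x1, x2) -> (-x2, x1) along the last dim
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(t, freqs):
        # freqs is a (cos, sin) pair, each broadcastable against t
        cos, sin = freqs
        rot_dim = freqs[0].shape[-1]
        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
        t_ = t_.float()
        t_pass_ = t_pass_.float()
        t_ = (t_ * cos) + (rotate_half(t_) * sin)
        return torch.cat((t_, t_pass_), dim=-1).type_as(t)

    # toy shapes: [batch, seq, heads, head_dim], rotation over the full head_dim
    bsz, seq_len, n_heads, head_dim = 1, 4, 2, 8
    query = torch.randn(bsz, seq_len, n_heads, head_dim)

    # illustrative position frequencies, broadcastable over batch and heads
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)   # [seq, head_dim/2]
    emb = torch.cat((angles, angles), dim=-1)                       # [seq, head_dim]
    cos = emb.cos()[None, :, None, :]                               # [1, seq, 1, head_dim]
    sin = emb.sin()[None, :, None, :]

    # mirror the patched forward: slice the last cur_len positions, then rotate
    cur_len = query.shape[1]
    q_pos_emb = (cos[:, -cur_len:, :, :], sin[:, -cur_len:, :, :])
    query = apply_rotary_pos_emb(query, q_pos_emb)
    print(query.shape)   # torch.Size([1, 4, 2, 8])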