From 30795bdfbca49fff85aa2206d14e3206a4424bbc Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Fri, 23 Feb 2024 10:07:24 +0800
Subject: [PATCH] Gemma optimization: rms_norm, kv_cache, fused_rope, fused_rope+qkv (#10212)

* gemma optimization

* update

* update

* fix style

* meet code review
---
 .../llm/src/bigdl/llm/transformers/convert.py |  12 +
 .../bigdl/llm/transformers/models/gemma.py    | 250 ++++++++++++++++++
 .../bigdl/llm/transformers/models/utils.py    |   4 +
 3 files changed, 266 insertions(+)
 create mode 100644 python/llm/src/bigdl/llm/transformers/models/gemma.py

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index f36cae17..623f409e 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -1062,6 +1062,18 @@ def _optimize_post(model, lightweight_bmm=False):
                 convert_forward(model,
                                 module.MistralMLP,
                                 llama_mlp_forward)
+    elif model.config.model_type == "gemma":
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.gemma import gemma_attention_forward
+        from bigdl.llm.transformers.models.gemma import gemma_rms_norm_forward
+        convert_forward(model,
+                        module.GemmaAttention,
+                        gemma_attention_forward,
+                        )
+        convert_forward(model,
+                        module.GemmaRMSNorm,
+                        gemma_rms_norm_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/gemma.py b/python/llm/src/bigdl/llm/transformers/models/gemma.py
new file mode 100644
index 00000000..1fcb3cd5
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/gemma.py
@@ -0,0 +1,250 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
+    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
+    to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
+                                                           n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def should_use_fuse_rope(self, hidden_states, position_ids):
+    use_fuse_rope = hidden_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and position_ids is not None
+    return use_fuse_rope
+
+
+def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
+    return q_type in [SYM_INT4, FP8E5] and \
+        use_fuse_rope and enough_kv_room and bs == 1
+
+
+def gemma_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        import linear_q4_0
+        result = linear_q4_0.fused_rms_norm(hidden_states,
+                                            [self.weight.size(0)],
+                                            self.weight + 1,
+                                            None,
+                                            self.eps)
+        # If nelement() == 0, the fused norm failed; fall back to the Python implementation below.
+        if result.nelement() != 0:
+            # Copy the result to work around an issue (of unknown cause) seen on Arc GPUs.
+            result = result.clone()
+            return result
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+    return (1 + self.weight) * hidden_states.to(input_dtype)
+
+
+def gemma_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor]=None,
+    position_ids: Optional[torch.LongTensor]=None,
+    past_key_value: Optional[Tuple[torch.Tensor]]=None,
+    output_attentions: bool=False,
+    use_cache: bool=False,
+    cache_position: Optional[torch.Tensor]=None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, hidden_size = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len)
+
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+
+        cache_k = past_key_value.key_cache[self.layer_idx]
+        cache_v = past_key_value.value_cache[self.layer_idx]
+
+        kv_seq_len = cache_k.shape[-2]
+
+        import linear_q4_0
+        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+                                                                         self.q_proj.weight,
+                                                                         self.k_proj.weight,
+                                                                         self.v_proj.weight,
+                                                                         position_ids,
+                                                                         cache_k, cache_v,
+                                                                         self.q_proj.weight.qtype,
+                                                                         kv_seq_len,
+                                                                         self.head_dim)
+        kv_seq_len += 1
+
+        # update past_key_value's seen_tokens and kv caches.
+        if self.layer_idx == 0:
+            past_key_value.seen_tokens = kv_seq_len
+        past_key_value.key_cache[self.layer_idx] = key_states
+        past_key_value.value_cache[self.layer_idx] = value_states
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                invalidInputError(False,
+                                  "The cache structure has changed since version v4.36. "
+                                  f"If you are using {self.__class__.__name__} for "
+                                  "auto-regressive decoding with k/v caching, please make sure "
+                                  "to initialize the attention class with a layer index.")
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        if use_fuse_rope:
+            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+            query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
+                                                                           sin, cos, "gemma")
+        else:
+            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, None)
+
+        if past_key_value is not None:
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value.seen_tokens += key_states.shape[-2]
+
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
+
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
+
+                key_states, value_states = append_kv_cache(cache_k, cache_v,
+                                                           key_states, value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    attn_weights = torch.matmul(query_states,
+                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+    if attention_mask is not None:  # no matter the length, we just slice it
+        if cache_position is not None:
+            causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+        else:
+            causal_mask = attention_mask
+        attn_weights = attn_weights + causal_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                         dtype=torch.float32).to(query_states.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                         training=self.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        invalidInputError(
+            False,
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.view(bsz, q_len, -1)
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output.to(original_dtype), attn_weights, past_key_value
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index f14473f9..cbd4caa6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -207,6 +207,10 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_i
         cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
         sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
+    elif model_family in ["gemma"]:
+        cos = cos.unsqueeze(1)
+        sin = sin.unsqueeze(1)
+        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
     else:
         invalidInputError(False,
                           f"{model_family} is not supported.")
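
For reference, a minimal usage sketch (not part of the patch itself): loading a Gemma checkpoint through bigdl-llm's 4-bit AutoModelForCausalLM goes through _optimize_post in convert.py above, which swaps in gemma_attention_forward and gemma_rms_norm_forward. The checkpoint path and prompt below are placeholders, and an Intel GPU (XPU) environment with bigdl-llm's XPU dependencies installed is assumed so the fused rms_norm / rope / qkv paths are exercised; on other devices the code falls back to the stock implementations.

# Usage sketch. Assumptions: bigdl-llm with XPU support is installed and a
# local Gemma checkpoint is available at the placeholder path below.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "./gemma-2b"  # placeholder checkpoint path

# load_in_4bit=True quantizes Linear layers to sym_int4, one of the qtypes
# accepted by use_decoding_fast_path() in gemma.py above.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")  # move to the Intel GPU so the fused kernels are used

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")

with torch.inference_mode():
    # batch size 1 decoding hits the forward_qkv fast path after the first token
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))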