Gemma optimization: rms_norm, kv_cache, fused_rope, fused_rope+qkv (#10212)
* gemma optimization * update * update * fix style * meet code review
This commit is contained in:
		
							parent
							
								
									63681af97e
								
							
						
					
					
						commit
						30795bdfbc
					
				
					 3 changed files with 266 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -1062,6 +1062,18 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
                convert_forward(model,
 | 
			
		||||
                                module.MistralMLP,
 | 
			
		||||
                                llama_mlp_forward)
 | 
			
		||||
    elif model.config.model_type == "gemma":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from bigdl.llm.transformers.models.gemma import gemma_attention_forward
 | 
			
		||||
        from bigdl.llm.transformers.models.gemma import gemma_rms_norm_forward
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.GemmaAttention,
 | 
			
		||||
                        gemma_attention_forward,
 | 
			
		||||
                        )
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.GemmaRMSNorm,
 | 
			
		||||
                        gemma_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "Yi":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										250
									
								
								python/llm/src/bigdl/llm/transformers/models/gemma.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										250
									
								
								python/llm/src/bigdl/llm/transformers/models/gemma.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,250 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import math
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch import nn
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 | 
			
		||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
 | 
			
		||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 | 
			
		||||
    cos = cos.unsqueeze(unsqueeze_dim)
 | 
			
		||||
    sin = sin.unsqueeze(unsqueeze_dim)
 | 
			
		||||
    q_embed = (q * cos) + (rotate_half(q) * sin)
 | 
			
		||||
    k_embed = (k * cos) + (rotate_half(k) * sin)
 | 
			
		||||
    return q_embed, k_embed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
 | 
			
		||||
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
 | 
			
		||||
    to (batch, num_attention_heads, seqlen, head_dim)
 | 
			
		||||
    """
 | 
			
		||||
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
 | 
			
		||||
    if n_rep == 1:
 | 
			
		||||
        return hidden_states
 | 
			
		||||
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
 | 
			
		||||
                                                           n_rep, slen, head_dim)
 | 
			
		||||
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_use_fuse_rope(self, hidden_states, position_ids):
 | 
			
		||||
    use_fuse_rope = hidden_states.device.type == "xpu"
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and position_ids is not None
 | 
			
		||||
    return use_fuse_rope
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
 | 
			
		||||
    return q_type in [SYM_INT4, FP8E5] and \
 | 
			
		||||
        use_fuse_rope and enough_kv_room and bs == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gemma_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        result = linear_q4_0.fused_rms_norm(hidden_states,
 | 
			
		||||
                                            [self.weight.size(0)],
 | 
			
		||||
                                            self.weight + 1,
 | 
			
		||||
                                            None,
 | 
			
		||||
                                            self.eps)
 | 
			
		||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
			
		||||
        if result.nelement != 0:
 | 
			
		||||
            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
 | 
			
		||||
            result = result.clone()
 | 
			
		||||
            return result
 | 
			
		||||
    input_dtype = hidden_states.dtype
 | 
			
		||||
    hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
			
		||||
    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 | 
			
		||||
    return (1 + self.weight) * hidden_states.to(input_dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gemma_attention_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor]=None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor]=None,
 | 
			
		||||
    past_key_value: Optional[Tuple[torch.Tensor]]=None,
 | 
			
		||||
    output_attentions: bool=False,
 | 
			
		||||
    use_cache: bool=False,
 | 
			
		||||
    cache_position: Optional[torch.Tensor]=None,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    bsz, q_len, hidden_size = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
			
		||||
    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
 | 
			
		||||
                                                use_fuse_rope,
 | 
			
		||||
                                                enough_kv_room,
 | 
			
		||||
                                                bsz * q_len)
 | 
			
		||||
 | 
			
		||||
    if decoding_fast_path:
 | 
			
		||||
        hidden_states = hidden_states.view(1, -1)
 | 
			
		||||
 | 
			
		||||
        cache_k = past_key_value.key_cache[self.layer_idx]
 | 
			
		||||
        cache_v = past_key_value.value_cache[self.layer_idx]
 | 
			
		||||
 | 
			
		||||
        kv_seq_len = cache_k.shape[-2]
 | 
			
		||||
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
 | 
			
		||||
                                                                         self.q_proj.weight,
 | 
			
		||||
                                                                         self.k_proj.weight,
 | 
			
		||||
                                                                         self.v_proj.weight,
 | 
			
		||||
                                                                         position_ids,
 | 
			
		||||
                                                                         cache_k, cache_v,
 | 
			
		||||
                                                                         self.q_proj.weight.qtype,
 | 
			
		||||
                                                                         kv_seq_len,
 | 
			
		||||
                                                                         self.head_dim)
 | 
			
		||||
        kv_seq_len += 1
 | 
			
		||||
 | 
			
		||||
        # update past_key_value's seem_tokens and kv caches.
 | 
			
		||||
        if self.layer_idx == 0:
 | 
			
		||||
            past_key_value.seen_tokens = kv_seq_len
 | 
			
		||||
        past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
        past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        query_states = self.q_proj(hidden_states)
 | 
			
		||||
        key_states = self.k_proj(hidden_states)
 | 
			
		||||
        value_states = self.v_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
        key_states = key_states.view(bsz, q_len,
 | 
			
		||||
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
        value_states = value_states.view(bsz, q_len,
 | 
			
		||||
                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
        kv_seq_len = key_states.shape[-2]
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            if self.layer_idx is None:
 | 
			
		||||
                invalidInputError(False,
 | 
			
		||||
                                  "The cache structure has changed since version v4.36. "
 | 
			
		||||
                                  f"If you are using {self.__class__.__name__} for "
 | 
			
		||||
                                  "auto-regressive decodingwith k/v caching, please make sure "
 | 
			
		||||
                                  "to initialize the attention class with a layer index.")
 | 
			
		||||
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
        if use_fuse_rope:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
 | 
			
		||||
                                                                           sin, cos, "gemma")
 | 
			
		||||
        else:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                            cos, sin, None)
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            # update the number of seen tokens
 | 
			
		||||
            if self.layer_idx == 0:
 | 
			
		||||
                past_key_value.seen_tokens += key_states.shape[-2]
 | 
			
		||||
 | 
			
		||||
            # reuse k, v, self_attention
 | 
			
		||||
            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
 | 
			
		||||
            if len(past_key_value.key_cache) <= self.layer_idx:
 | 
			
		||||
                past_key_value.key_cache.append(key_states)
 | 
			
		||||
                past_key_value.value_cache.append(value_states)
 | 
			
		||||
            else:
 | 
			
		||||
                cache_k = past_key_value.key_cache[self.layer_idx]
 | 
			
		||||
                cache_v = past_key_value.value_cache[self.layer_idx]
 | 
			
		||||
 | 
			
		||||
                if not enough_kv_room:
 | 
			
		||||
                    # allocate new
 | 
			
		||||
                    new_c_k, new_c_v = extend_kv_cache(bsz,
 | 
			
		||||
                                                       self.num_key_value_heads,  # Support GQA
 | 
			
		||||
                                                       self.head_dim,
 | 
			
		||||
                                                       cache_k.size(2),
 | 
			
		||||
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
			
		||||
                                                       dtype=cache_k.dtype,
 | 
			
		||||
                                                       device=device)
 | 
			
		||||
 | 
			
		||||
                    new_c_k[:] = cache_k
 | 
			
		||||
                    new_c_v[:] = cache_v
 | 
			
		||||
                    cache_k = new_c_k
 | 
			
		||||
                    cache_v = new_c_v
 | 
			
		||||
 | 
			
		||||
                key_states, value_states = append_kv_cache(cache_k, cache_v,
 | 
			
		||||
                                                           key_states, value_states)
 | 
			
		||||
 | 
			
		||||
                # update past_key_value
 | 
			
		||||
                past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
                past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
 | 
			
		||||
    # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
    value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
    if attention_mask is not None:  # no matter the length, we just slice it
 | 
			
		||||
        if cache_position is not None:
 | 
			
		||||
            causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
 | 
			
		||||
        else:
 | 
			
		||||
            causal_mask = attention_mask
 | 
			
		||||
        attn_weights = attn_weights + causal_mask
 | 
			
		||||
 | 
			
		||||
    # upcast attention to fp32
 | 
			
		||||
    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
 | 
			
		||||
                                         dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
 | 
			
		||||
                                         training=self.training)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
			
		||||
        invalidInputError(
 | 
			
		||||
            False,
 | 
			
		||||
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
 | 
			
		||||
            f" {attn_output.size()}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.view(bsz, q_len, -1)
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output.to(original_dtype), attn_weights, past_key_value
 | 
			
		||||
| 
						 | 
				
			
			@ -207,6 +207,10 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_i
 | 
			
		|||
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
			
		||||
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
 | 
			
		||||
    elif model_family in ["gemma"]:
 | 
			
		||||
        cos = cos.unsqueeze(1)
 | 
			
		||||
        sin = sin.unsqueeze(1)
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"{model_family} is not supported.")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue