fusing qkv project and rope (#9612)

* Try fusing qkv project and rope

* add fused mlp

* fuse append cache

* fix style and clean up code

* clean up
This commit is contained in:
Yang Wang 2023-12-18 16:45:00 -08:00 committed by GitHub
parent 4c112ee70c
commit f4fb58d99c
2 changed files with 138 additions and 77 deletions

View file

@ -374,6 +374,7 @@ def _optimize_post(model, lightweight_bmm=False):
from packaging import version from packaging import version
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
from bigdl.llm.transformers.models.llama import llama_mlp_forward
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
# All huggingface format models are inherited from `PreTrainedModel` # All huggingface format models are inherited from `PreTrainedModel`
@ -392,6 +393,9 @@ def _optimize_post(model, lightweight_bmm=False):
model, model,
transformers.models.llama.modeling_llama.LlamaRMSNorm, transformers.models.llama.modeling_llama.LlamaRMSNorm,
llama_rms_norm_forward,) llama_rms_norm_forward,)
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaMLP,
llama_mlp_forward)
else: else:
# todo implement 4.28.0 ~ 4.30.2 # todo implement 4.28.0 ~ 4.30.2
pass pass

View file

@ -41,6 +41,7 @@ from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -91,6 +92,36 @@ def llama_rms_norm_forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype) return self.weight * hidden_states.to(input_dtype)
def llama_mlp_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
and not (self.training and x.requires_grad):
import linear_q4_0
x_2d = x.view(-1, x.shape[-1])
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()
return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
))
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def is_enough_kv_cache_room(past_key_value):
return past_key_value is not None and \
past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
def should_use_fuse_rope(self, query_states, position_ids):
use_fuse_rope = query_states.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
use_fuse_rope = use_fuse_rope and position_ids is not None
return use_fuse_rope
def llama_attention_forward_4_31( def llama_attention_forward_4_31(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -115,8 +146,38 @@ def llama_attention_forward_4_31(
else: else:
attention_dtype = original_dtype attention_dtype = original_dtype
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room(past_key_value)
is_q4_0 = self.q_proj.qtype == SYM_INT4
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
# single batch decoding fast path
# forward_qkv takes will perform QKV projection, rotary position embedding
# and save the key/value states to cache, then return query states and the
# extended key/value cache
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
kv_seq_len = past_key_value[0].shape[-2]
cache_k = past_key_value[0]
cache_v = past_key_value[1]
import linear_q4_0
query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
self.q_proj.weight,
self.k_proj.weight,
self.v_proj.weight,
position_ids,
cache_k, cache_v,
self.q_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1
else:
if self.config.pretraining_tp > 1: if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
self.config.pretraining_tp)
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
// self.config.pretraining_tp, dim=0) // self.config.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
@ -150,10 +211,6 @@ def llama_attention_forward_4_31(
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
use_fuse_rope = query_states.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
if use_fuse_rope: if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states, key_states,
@ -168,7 +225,7 @@ def llama_attention_forward_4_31(
# reuse k, v, self_attention # reuse k, v, self_attention
cache_k = past_key_value[0] cache_k = past_key_value[0]
cache_v = past_key_value[1] cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3): if not enough_kv_room:
# allocate new # allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz, new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA self.num_key_value_heads, # Support GQA