fusing qkv project and rope (#9612)

* Try fusing qkv project and rope

* add fused mlp

* fuse append cache

* fix style and clean up code

* clean up
This commit is contained in:
Yang Wang 2023-12-18 16:45:00 -08:00 committed by GitHub
parent 4c112ee70c
commit f4fb58d99c
2 changed files with 138 additions and 77 deletions

View file

@ -374,6 +374,7 @@ def _optimize_post(model, lightweight_bmm=False):
from packaging import version from packaging import version
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
from bigdl.llm.transformers.models.llama import llama_mlp_forward
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
# All huggingface format models are inherited from `PreTrainedModel` # All huggingface format models are inherited from `PreTrainedModel`
@ -392,6 +393,9 @@ def _optimize_post(model, lightweight_bmm=False):
model, model,
transformers.models.llama.modeling_llama.LlamaRMSNorm, transformers.models.llama.modeling_llama.LlamaRMSNorm,
llama_rms_norm_forward,) llama_rms_norm_forward,)
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaMLP,
llama_mlp_forward)
else: else:
# todo implement 4.28.0 ~ 4.30.2 # todo implement 4.28.0 ~ 4.30.2
pass pass

View file

@ -41,6 +41,7 @@ from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -91,6 +92,36 @@ def llama_rms_norm_forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype) return self.weight * hidden_states.to(input_dtype)
def llama_mlp_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
and not (self.training and x.requires_grad):
import linear_q4_0
x_2d = x.view(-1, x.shape[-1])
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()
return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
))
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def is_enough_kv_cache_room(past_key_value):
return past_key_value is not None and \
past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
def should_use_fuse_rope(self, query_states, position_ids):
use_fuse_rope = query_states.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
use_fuse_rope = use_fuse_rope and position_ids is not None
return use_fuse_rope
def llama_attention_forward_4_31( def llama_attention_forward_4_31(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -115,88 +146,114 @@ def llama_attention_forward_4_31(
else: else:
attention_dtype = original_dtype attention_dtype = original_dtype
if self.config.pretraining_tp > 1: use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp enough_kv_room = is_enough_kv_cache_room(past_key_value)
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) is_q4_0 = self.q_proj.qtype == SYM_INT4
// self.config.pretraining_tp, dim=0) no_tp = not self.config.pretraining_tp > 1
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) enough_kv_room and bsz * q_len == 1)
query_states = [F.linear(hidden_states, query_slices[i]) # single batch decoding fast path
for i in range(self.config.pretraining_tp)] # forward_qkv takes will perform QKV projection, rotary position embedding
query_states = torch.cat(query_states, dim=-1) # and save the key/value states to cache, then return query states and the
# extended key/value cache
key_states = [F.linear(hidden_states, key_slices[i]) if decoding_fast_path:
for i in range(self.config.pretraining_tp)] hidden_states = hidden_states.view(1, -1)
key_states = torch.cat(key_states, dim=-1) kv_seq_len = past_key_value[0].shape[-2]
value_states = [F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
use_fuse_rope = query_states.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
if past_key_value is not None:
# reuse k, v, self_attention
cache_k = past_key_value[0] cache_k = past_key_value[0]
cache_v = past_key_value[1] cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3): import linear_q4_0
# allocate new query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
new_cache_k, new_cache_v = extend_kv_cache(bsz, self.q_proj.weight,
self.num_key_value_heads, # Support GQA self.k_proj.weight,
self.head_dim, self.v_proj.weight,
cache_k.size(2), position_ids,
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, cache_k, cache_v,
dtype=cache_k.dtype, self.q_proj.weight.qtype,
device=device) kv_seq_len,
new_cache_k[:] = cache_k self.head_dim)
new_cache_v[:] = cache_v kv_seq_len += 1
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) else:
if self.config.pretraining_tp > 1:
key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
self.config.pretraining_tp)
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
// self.config.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
elif use_cache: query_states = [F.linear(hidden_states, query_slices[i])
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH for i in range(self.config.pretraining_tp)]
new_key_states, new_value_states = init_kv_cache(bsz, query_states = torch.cat(query_states, dim=-1)
self.num_key_value_heads,
self.head_dim, key_states = [F.linear(hidden_states, key_slices[i])
kv_seq_len, for i in range(self.config.pretraining_tp)]
max_cache_length, key_states = torch.cat(key_states, dim=-1)
dtype=key_states.dtype,
device=device) value_states = [F.linear(hidden_states, value_slices[i])
new_key_states[:] = key_states for i in range(self.config.pretraining_tp)]
new_value_states[:] = value_states value_states = torch.cat(value_states, dim=-1)
key_states = new_key_states
value_states = new_value_states else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
if past_key_value is not None:
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if not enough_kv_room:
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = init_kv_cache(bsz,
self.num_key_value_heads,
self.head_dim,
kv_seq_len,
max_cache_length,
dtype=key_states.dtype,
device=device)
new_key_states[:] = key_states
new_value_states[:] = value_states
key_states = new_key_states
value_states = new_value_states
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None