LLM: add mlp & qkv fusion for FP16 Llama-7B (#9932)

* add mlp fusion for llama

* add mlp fusion

* fix style

* update

* add mm_qkv_out

* fix style

* update

* meet code review

* meet code review
Ruonan Wang 2024-01-26 11:50:38 +08:00 committed by GitHub
parent 98ea3459e5
commit a00efa0564
4 changed files with 154 additions and 10 deletions


@@ -567,6 +567,7 @@ def _optimize_post(model, lightweight_bmm=False):
from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
from bigdl.llm.transformers.models.llama import llama_mlp_forward
from bigdl.llm.transformers.models.llama import llama_decoder_forward
from transformers.modeling_utils import PreTrainedModel
# All huggingface format models are inherited from `PreTrainedModel`
@@ -588,6 +589,9 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaMLP,
llama_mlp_forward)
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaDecoderLayer,
llama_decoder_forward)
if version.parse(trans_version) >= version.parse("4.36.0"):
# transformers version >= 4.36.0
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_36
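
For context, convert_forward (used above to register llama_mlp_forward and the new llama_decoder_forward) rebinds the forward method on every matching submodule. A minimal sketch of that patching pattern, assuming a plain module-tree walk rather than the exact BigDL implementation:

import torch.nn as nn

def patch_forward(model: nn.Module, target_cls, new_forward):
    # Rebind `forward` on every submodule that is an instance of target_cls.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, module.__class__)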


@@ -44,7 +44,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from bigdl.llm.transformers.models.utils import mlp_fusion_check
from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
from transformers.modeling_outputs import BaseModelOutputWithPast
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -104,19 +104,47 @@ def llama_rms_norm_forward(self, hidden_states):
def llama_mlp_forward(
self,
x: torch.Tensor,
residual=None
) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1])
bsz, hidden_size = x_2d.shape
qtype = getattr(self.gate_proj, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training):
import linear_q4_0
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()
return self.down_proj(linear_q4_0.mlp_forward_xpu(
out = self.down_proj(linear_q4_0.mlp_forward_xpu(
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
qtype
))
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
if residual is not None:
return out + residual
else:
return out
elif fp16_fusion_check(self.gate_proj, x, self.training) and \
hidden_size == 4096 and bsz == 1:
hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.gate_proj.weight)
hidden_states = torch.ops.torch_ipex.mm_resmul(
x, self.up_proj.weight, hidden_states1
)
if residual is None:
hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
else:
attn_output = torch.addmm(
residual.flatten(0, -2),
hidden_states.flatten(0, -2),
self.down_proj.weight,
beta=1,
)
hidden_states = attn_output.view(x.shape)
return hidden_states
else:
out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
if residual is not None:
return out + residual
else:
return out
def should_use_fuse_rope(self, query_states, position_ids):
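
The FP16 branch above leans on three IPEX XPU kernels (mm_silu, mm_resmul and addmm). As a reference point only, the plain-torch math they replace is sketched below, assuming the FP16 weights are stored input-major so that x @ w yields the projection, as the fused matmuls above imply; this is not the quantized linear_q4_0 path:

import torch
import torch.nn.functional as F

def llama_mlp_reference(x, w_gate, w_up, w_down, residual=None):
    # mm_silu(x, w_gate)          ~= silu(x @ w_gate)
    # mm_resmul(x, w_up, h)       ~= (x @ w_up) * h
    # addmm(residual, h, w_down)  ~= residual + h @ w_down
    h = F.silu(x @ w_gate) * (x @ w_up)
    out = h @ w_down
    return out if residual is None else out + residual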
@@ -136,6 +164,56 @@ def should_use_fuse_rope(self, query_states, position_ids):
return use_fuse_rope
def llama_decoder_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37."
"Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
# add residual into mlp
hidden_states = self.mlp(hidden_states, residual)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def llama_attention_forward_4_31(
self,
hidden_states: torch.Tensor,
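
llama_decoder_forward above follows the stock LlamaDecoderLayer except for the MLP tail: instead of adding the residual after the MLP returns, it hands the residual to llama_mlp_forward so the add can be folded into the final addmm / down_proj step. Schematically (illustrative only, not library code):

def mlp_tail_stock(mlp, hidden_states, residual):
    # upstream transformers: residual added outside the MLP
    return residual + mlp(hidden_states)

def mlp_tail_fused(mlp, hidden_states, residual):
    # patched layer above: residual passed in, added (or fused) inside the MLP
    return mlp(hidden_states, residual)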
@@ -147,7 +225,7 @@ def llama_attention_forward_4_31(
padding_mask: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
bsz, q_len, hidden_size = hidden_states.size()
device = hidden_states.device
# for flash attention
original_dtype = hidden_states.dtype
@@ -202,9 +280,31 @@ def llama_attention_forward_4_31(
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
hidden_size == 4096:
# only use mm_qkv_out on pvc for llama-7b
if not hasattr(self, "qkv_proj_weight"):
self.qkv_proj_weight = torch.stack([self.q_proj.weight,
self.k_proj.weight,
self.v_proj.weight]).contiguous()
self.q_proj.weight.data = self.qkv_proj_weight[0, :, :]
self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
torch.xpu.empty_cache()
query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
device=hidden_states.device)
key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
device=hidden_states.device)
value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops.torch_ipex.mm_qkv_out(
hidden_states, self.qkv_proj_weight, None,
query_states, key_states, value_states
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
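
mm_qkv_out computes the three projections in a single pass over the stacked [3, hidden, hidden] weight and writes the results into the pre-allocated output tensors. A plain-torch sketch of the values it produces, again assuming input-major FP16 weights as the stacking above implies:

import torch

def qkv_reference(hidden_states, qkv_weight):
    # qkv_weight: [3, hidden_size, hidden_size], slices 0/1/2 = q/k/v projection weights
    q = hidden_states @ qkv_weight[0]
    k = hidden_states @ qkv_weight[1]
    v = hidden_states @ qkv_weight[2]
    return q, k, v

Re-pointing q_proj/k_proj/v_proj weight.data at views of the stacked tensor presumably keeps a single copy of the weights on device, which is why torch.xpu.empty_cache() is called right after the stack.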
@@ -598,9 +698,31 @@ def llama_attention_forward_4_36(
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
hidden_size == 4096:
# only use mm_qkv_out on pvc for llama-7b
if not hasattr(self, "qkv_proj_weight"):
self.qkv_proj_weight = torch.stack([self.q_proj.weight,
self.k_proj.weight,
self.v_proj.weight]).contiguous()
self.q_proj.weight.data = self.qkv_proj_weight[0, :, :]
self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
torch.xpu.empty_cache()
query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
device=hidden_states.device)
key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
device=hidden_states.device)
value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops.torch_ipex.mm_qkv_out(
hidden_states, self.qkv_proj_weight, None,
query_states, key_states, value_states
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)


@@ -312,3 +312,19 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
or x.numel() // x.size(-1) == 1
)
)
def fp16_fusion_check(proj, x, training):
# only use fp16 fusion on PVC inference
if proj.qtype != ggml_tensor_qtype["fp16"]:
return False
if proj.weight_type != 2:
return False
if training:
return False
if x.requires_grad:
return False
device_type = get_xpu_device_type(x)
if device_type != "pvc":
return False
return True
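
fp16_fusion_check is the gate for all of the torch_ipex paths above: FP16 qtype, transposed weight layout (weight_type == 2), inference only, no autograd, and a PVC device. A hypothetical caller-side helper showing the condition the attention code above applies (name and wrapper are illustrative, not BigDL API):

def can_use_fused_fp16_qkv(attn, hidden_states, hidden_size):
    # Mirror of the condition guarding the mm_qkv_out branch in the attention forward.
    return (fp16_fusion_check(attn.q_proj, hidden_states, attn.training)
            and hidden_size == 4096)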


@@ -166,6 +166,8 @@ def get_ipex_version():
def get_xpu_device_type(x):
if x.device.type != "xpu":
return x.device.type
name = torch.xpu.get_device_name(x.device.index)
if name.startswith("Intel(R) Arc(TM) A"):
return "arc"