LLM: add mlp & qkv fusion for FP16 Llama-7B (#9932)
* add mlp fusion for llama
* add mlp fusion
* fix style
* update
* add mm_qkv_out
* fix style
* update
* meet code review
* meet code review
parent 98ea3459e5
commit a00efa0564
4 changed files with 154 additions and 10 deletions
@@ -567,6 +567,7 @@ def _optimize_post(model, lightweight_bmm=False):
     from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from bigdl.llm.transformers.models.llama import llama_mlp_forward
+    from bigdl.llm.transformers.models.llama import llama_decoder_forward
     from transformers.modeling_utils import PreTrainedModel
 
     # All huggingface format models are inherited from `PreTrainedModel`
@@ -588,6 +589,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         transformers.models.llama.modeling_llama.LlamaMLP,
                         llama_mlp_forward)
+        convert_forward(model,
+                        transformers.models.llama.modeling_llama.LlamaDecoderLayer,
+                        llama_decoder_forward)
         if version.parse(trans_version) >= version.parse("4.36.0"):
             # transformers version >= 4.36.0
             from bigdl.llm.transformers.models.llama import llama_attention_forward_4_36
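
The added `convert_forward` call registers `llama_decoder_forward` in place of the stock `LlamaDecoderLayer.forward`, alongside the existing `LlamaMLP` replacement. For illustration, a minimal sketch of this kind of instance-level forward patching, using a hypothetical `MyMLP` module (this is not BigDL's actual `convert_forward` implementation):

import torch
import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    # Rebind `forward` on every instance of target_cls found in the model.
    # Binding per instance leaves other module classes untouched.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, target_cls)


class MyMLP(nn.Module):  # hypothetical stand-in for LlamaMLP / LlamaDecoderLayer
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


def fused_forward(self, x):  # hypothetical replacement forward
    return self.proj(x) * 2


model = nn.Sequential(MyMLP())
convert_forward_sketch(model, MyMLP, fused_forward)
out = model(torch.randn(1, 8))  # now dispatched to fused_forward
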
@@ -44,7 +44,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
-from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -104,19 +104,47 @@ def llama_rms_norm_forward(self, hidden_states):
 def llama_mlp_forward(
     self,
     x: torch.Tensor,
+    residual=None
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
+    bsz, hidden_size = x_2d.shape
     qtype = getattr(self.gate_proj, "qtype", None)
     if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.down_proj(linear_q4_0.mlp_forward_xpu(
+        out = self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
             qtype
         ))
-    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
+    elif fp16_fusion_check(self.gate_proj, x, self.training) and \
+            hidden_size == 4096 and bsz == 1:
+        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.gate_proj.weight)
+        hidden_states = torch.ops.torch_ipex.mm_resmul(
+            x, self.up_proj.weight, hidden_states1
+        )
+        if residual is None:
+            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
+        else:
+            attn_output = torch.addmm(
+                residual.flatten(0, -2),
+                hidden_states.flatten(0, -2),
+                self.down_proj.weight,
+                beta=1,
+            )
+            hidden_states = attn_output.view(x.shape)
+        return hidden_states
+    else:
+        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
 
 
 def should_use_fuse_rope(self, query_states, position_ids):
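
The new FP16 branch in `llama_mlp_forward` replaces the three separate GEMMs of the SwiGLU MLP with the IPEX fused ops `mm_silu` and `mm_resmul`, and folds the optional residual add into a single `addmm` against `down_proj.weight`. A plain-PyTorch sketch of what that path is expected to compute, assuming weights stored as `(in_features, out_features)` so they can be multiplied directly, as the fused path does (an illustration, not BigDL's kernel):

import torch
import torch.nn.functional as F


def llama_mlp_reference(x, gate_w, up_w, down_w, residual=None):
    # Unfused SwiGLU MLP: down( silu(x @ gate) * (x @ up) ), plus optional residual.
    gate = F.silu(x @ gate_w)      # fused path: torch.ops.torch_ipex.mm_silu
    inter = (x @ up_w) * gate      # fused path: torch.ops.torch_ipex.mm_resmul
    out = inter @ down_w           # fused path: matmul, or addmm when residual given
    return out if residual is None else out + residual


# Llama-7B shapes: hidden_size=4096, intermediate_size=11008, one decoded token.
x = torch.randn(1, 1, 4096)
gate_w = torch.randn(4096, 11008)
up_w = torch.randn(4096, 11008)
down_w = torch.randn(11008, 4096)
residual = torch.randn(1, 1, 4096)
y = llama_mlp_reference(x, gate_w, up_w, down_w, residual)   # (1, 1, 4096)
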
@@ -136,6 +164,56 @@ def should_use_fuse_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
+def llama_decoder_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
+            "Please make sure use `attention_mask` instead.`"
+        )
+
+    residual = hidden_states
+
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        **kwargs,
+    )
+    hidden_states = residual + hidden_states
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+
+    # add residual into mlp
+    hidden_states = self.mlp(hidden_states, residual)
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs
+
+
 def llama_attention_forward_4_31(
     self,
     hidden_states: torch.Tensor,
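
Compared with the stock HuggingFace decoder layer, which computes `hidden_states = residual + self.mlp(hidden_states)`, `llama_decoder_forward` hands `residual` to the MLP so the final add can ride on the fused `addmm` inside `llama_mlp_forward`. A small equivalence sketch of the two residual patterns (the `mlp` here is a generic stand-in, not the gated Llama MLP):

import torch
import torch.nn as nn

hidden = torch.randn(1, 1, 4096)
residual = hidden.clone()
mlp = nn.Sequential(nn.Linear(4096, 11008), nn.SiLU(), nn.Linear(11008, 4096))

# Stock HF pattern: run the MLP, then a separate elementwise residual add.
out_stock = residual + mlp(hidden)

# Patched pattern: the MLP receives the residual and returns output + residual,
# which lets the FP16 path fuse the add into its final addmm.
out_patched = mlp(hidden) + residual

assert torch.equal(out_stock, out_patched)  # addition commutes, results match
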
@@ -147,7 +225,7 @@ def llama_attention_forward_4_31(
     padding_mask: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
@@ -202,9 +280,31 @@ def llama_attention_forward_4_31(
                         for i in range(self.config.pretraining_tp)]
         value_states = torch.cat(value_states, dim=-1)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+                hidden_size == 4096:
+            # only use mm_qkv_out on pvc for llama-7b
+            if not hasattr(self, "qkv_proj_weight"):
+                self.qkv_proj_weight = torch.stack([self.q_proj.weight,
+                                                    self.k_proj.weight,
+                                                    self.v_proj.weight]).contiguous()
+                self.q_proj.weight.data = self.qkv_proj_weight[0, :, :]
+                self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
+                self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
+                torch.xpu.empty_cache()
+            query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                       device=hidden_states.device)
+            key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+            value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                       device=hidden_states.device)
+            torch.ops.torch_ipex.mm_qkv_out(
+                hidden_states, self.qkv_proj_weight, None,
+                query_states, key_states, value_states
+            )
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len,
                                      self.num_heads, self.head_dim).transpose(1, 2)
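
The FP16 branch stacks the q/k/v weights once into a single `(3, hidden_size, hidden_size)` tensor, re-points `q_proj.weight.data` and friends at its slices so memory is not doubled, and then fills three preallocated output buffers with one `torch_ipex.mm_qkv_out` call (gated to `hidden_size == 4096`, i.e. Llama-7B, on PVC). A plain-PyTorch sketch of the numerical equivalence this relies on, assuming weights in the `(in_features, out_features)` layout the fused op multiplies directly (illustrative, not the IPEX kernel):

import torch

hidden = torch.randn(1, 7, 4096)        # (bsz, q_len, hidden_size)
wq = torch.randn(4096, 4096)            # assumed (in_features, out_features) layout
wk = torch.randn(4096, 4096)
wv = torch.randn(4096, 4096)

# Unfused fallback: three separate projections.
q_ref, k_ref, v_ref = hidden @ wq, hidden @ wk, hidden @ wv

# Stacked weights + one broadcasted matmul: the equivalence a fused qkv kernel
# exploits while writing q/k/v into preallocated output buffers in a single pass.
qkv_weight = torch.stack([wq, wk, wv]).contiguous()     # (3, 4096, 4096)
qkv = hidden.unsqueeze(0) @ qkv_weight.unsqueeze(1)     # (3, 1, 7, 4096)
q, k, v = qkv[0], qkv[1], qkv[2]

for got, ref in ((q, q_ref), (k, k_ref), (v, v_ref)):
    assert torch.allclose(got, ref, rtol=1e-3, atol=1e-3)
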
@@ -598,9 +698,31 @@ def llama_attention_forward_4_36(
                         for i in range(self.config.pretraining_tp)]
         value_states = torch.cat(value_states, dim=-1)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+                hidden_size == 4096:
+            # only use mm_qkv_out on pvc for llama-7b
+            if not hasattr(self, "qkv_proj_weight"):
+                self.qkv_proj_weight = torch.stack([self.q_proj.weight,
+                                                    self.k_proj.weight,
+                                                    self.v_proj.weight]).contiguous()
+                self.q_proj.weight.data = self.qkv_proj_weight[0, :, :]
+                self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
+                self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
+                torch.xpu.empty_cache()
+            query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                       device=hidden_states.device)
+            key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+            value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                       device=hidden_states.device)
+            torch.ops.torch_ipex.mm_qkv_out(
+                hidden_states, self.qkv_proj_weight, None,
+                query_states, key_states, value_states
+            )
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len,
                                      self.num_heads, self.head_dim).transpose(1, 2)
@@ -312,3 +312,19 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
             or x.numel() // x.size(-1) == 1
         )
     )
+
+
+def fp16_fusion_check(proj, x, training):
+    # only use fp16 fusion on PVC inference
+    if proj.qtype != ggml_tensor_qtype["fp16"]:
+        return False
+    if proj.weight_type != 2:
+        return False
+    if training:
+        return False
+    if x.requires_grad:
+        return False
+    device_type = get_xpu_device_type(x)
+    if device_type != "pvc":
+        return False
+    return True
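
`fp16_fusion_check` is the common gate for both new paths: the projection must hold FP16 (not quantized) weights with `weight_type == 2`, the call must be inference (no training, no autograd), and the tensor must live on a PVC device. A small dispatch sketch with a mocked projection object (the `SimpleNamespace` stand-in and the placeholder qtype are hypothetical; the real `proj` is BigDL's FP16/low-bit linear, and the real check also queries the XPU device name):

import torch
from types import SimpleNamespace

FP16_QTYPE = object()   # placeholder for ggml_tensor_qtype["fp16"]


def fp16_fusion_check_sketch(proj, x, training):
    # Mirrors the gate above, minus the PVC device-name check.
    return (proj.qtype is FP16_QTYPE       # weights kept in fp16, not quantized
            and proj.weight_type == 2      # layout flag expected by the fused ops
            and not training               # inference only
            and not x.requires_grad)


gate_proj = SimpleNamespace(qtype=FP16_QTYPE, weight_type=2)   # mocked projection
x = torch.randn(1, 4096)
print(fp16_fusion_check_sketch(gate_proj, x, training=False))  # True  -> fused path
print(fp16_fusion_check_sketch(gate_proj, x, training=True))   # False -> fallback
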
@@ -166,6 +166,8 @@ def get_ipex_version():
 
 
 def get_xpu_device_type(x):
+    if x.device.type != "xpu":
+        return x.device.type
     name = torch.xpu.get_device_name(x.device.index)
     if name.startswith("Intel(R) Arc(TM) A"):
         return "arc"
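
The early return makes `get_xpu_device_type` safe on non-XPU tensors: instead of calling `torch.xpu.get_device_name` it simply echoes the device type, so callers such as `fp16_fusion_check` see a non-"pvc" string and fall back to the unfused path. A usage sketch (the final "unknown" fallback is illustrative; the real helper goes on to detect other Intel GPUs such as PVC):

import torch


def get_xpu_device_type_sketch(x: torch.Tensor) -> str:
    # Non-XPU tensors simply report their device type ("cpu", "cuda", ...),
    # so torch.xpu is never touched off-device.
    if x.device.type != "xpu":
        return x.device.type
    name = torch.xpu.get_device_name(x.device.index)
    if name.startswith("Intel(R) Arc(TM) A"):
        return "arc"
    return "unknown"   # illustrative fallback only

print(get_xpu_device_type_sketch(torch.randn(2, 2)))   # -> "cpu"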