Enable Gemma fused mlp + Gelu (#10276)
* update llama mlp forward
* add all
* fix style check
* split
* update
* update
* update
* fix style

parent 2d930bdca8
commit 232273a1b5
8 changed files with 45 additions and 9 deletions
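For reference, Gemma's MLP is a GeGLU block, down_proj(act_fn(gate_proj(x)) * up_proj(x)) with a GELU-family activation; this patch routes that computation through the fused linear_q4_0.mlp_forward_xpu kernel (selected by a new GELU constant) instead of three separate low-bit matmuls. Below is a minimal PyTorch sketch of the unfused reference computation the fused path replaces; the layer sizes and the exact GELU variant are illustrative, not taken from this commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGLUMLPReference(nn.Module):
    """Unfused GeGLU MLP of the kind the fused XPU kernel accelerates."""
    def __init__(self, hidden_size: int = 2048, intermediate_size: int = 16384):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # gelu(gate_proj(x)) * up_proj(x), then project back down; the fused
        # kernel performs the gate/up matmuls and the activation in one call.
        return self.down_proj(F.gelu(self.gate_proj(x)) * self.up_proj(x))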
@@ -1108,6 +1108,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.gemma import gemma_attention_forward
         from bigdl.llm.transformers.models.gemma import gemma_rms_norm_forward
+        from bigdl.llm.transformers.models.gemma import gemma_mlp_forward
         convert_forward(model,
                         module.GemmaAttention,
                         gemma_attention_forward,
@@ -1115,6 +1116,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.GemmaRMSNorm,
                         gemma_rms_norm_forward)
+        convert_forward(model,
+                        module.GemmaMLP,
+                        gemma_mlp_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
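The registrations above run inside _optimize_post, which bigdl-llm applies during low-bit model loading, so the fused Gemma MLP is picked up automatically. A usage sketch under that assumption follows; the checkpoint id, device setup, and generation parameters are illustrative and not part of this commit.

import torch
import intel_extension_for_pytorch as ipex  # enables the "xpu" device
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "google/gemma-7b-it"  # hypothetical checkpoint id/path

# load_in_4bit triggers the low-bit conversion pass, which calls
# _optimize_post and swaps GemmaMLP.forward for gemma_mlp_forward.
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
model = model.to("xpu")

tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer("What does a fused MLP kernel do?", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))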
@@ -28,7 +28,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv
     restore_fp8_kv_cache, use_quantize_kv_cache
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.utils import logging
@@ -80,7 +80,7 @@ def baichuan_mlp_forward(
         return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            qtype
+            SILU, qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
@@ -38,6 +38,7 @@ from torch import nn
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from bigdl.llm.transformers.models.utils import mlp_fusion_check, GELU
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5

@@ -98,6 +99,31 @@ def gemma_rms_norm_forward(self, hidden_states):
     return (1 + self.weight) * hidden_states.to(input_dtype)


+def gemma_mlp_forward(
+    self,
+    x: torch.Tensor,
+    residual=None
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    bsz, hidden_size = x_2d.shape
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+        import linear_q4_0
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        out = self.down_proj(linear_q4_0.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            GELU, qtype
+        ))
+    else:
+        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    if residual is not None:
+        return out + residual
+    else:
+        return out
+
+
 def gemma_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -136,6 +162,7 @@ def gemma_attention_forward(
             position_ids,
             cache_k, cache_v,
             self.q_proj.weight.qtype,
+            self.v_proj.weight.qtype,
             kv_seq_len,
             self.head_dim)
         kv_seq_len += 1
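Since gemma_mlp_forward keeps the eager GeGLU computation as its fallback branch, the fused path can be sanity-checked against it. A hedged sketch (not part of the commit), assuming a converted low-bit Gemma model on an Intel XPU where mlp is one of its patched GemmaMLP modules:

import torch

def fused_matches_eager(mlp, hidden_size, atol=1e-2):
    # A single decode-shaped token is the case mlp_fusion_check is meant to
    # route to the fused linear_q4_0.mlp_forward_xpu path.
    x = torch.randn(1, 1, hidden_size, dtype=torch.float16, device="xpu")
    with torch.inference_mode():
        fused = mlp(x)  # patched gemma_mlp_forward
        eager = mlp.down_proj(mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x))
    return torch.allclose(fused, eager, atol=atol)

# e.g. fused_matches_eager(model.model.layers[0].mlp, model.config.hidden_size)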
@@ -40,6 +40,7 @@ import math
 import os
 import torch.nn.functional as F
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import SILU
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
@@ -118,7 +119,7 @@ def llama_mlp_forward(
         out = self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            qtype
+            SILU, qtype
         ))
         if residual is not None:
             return out + residual
@@ -50,7 +50,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
-from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import mlp_fusion_check, SILU
 from bigdl.llm.transformers.low_bit_linear import IQ2_XXS


@@ -371,7 +371,7 @@ def mixtral_mlp_forward(
         return self.w2(linear_q4_0.mlp_forward_xpu(
             x, self.w1.weight.data, self.w3.weight.data,
             x.shape[0], x.shape[1], self.w1.out_len,
-            qtype,
+            SILU, qtype,
         )) * routing_weights
     else:
         current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
@@ -39,7 +39,7 @@ except ImportError:
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half
+from bigdl.llm.transformers.models.utils import rotate_half, SILU
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
@@ -292,6 +292,6 @@ def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.c_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.w2.weight.data, self.w1.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
-            qtype
+            SILU, qtype
         ))
     return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
@@ -26,6 +26,10 @@ SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
 FP8E5 = ggml_tensor_qtype["fp8_e5m2"]

+# used in fused mlp forward
+SILU = 0
+GELU = 1
+

 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
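The two constants added above act as the activation selector passed to linear_q4_0.mlp_forward_xpu: the SwiGLU-style forwards (llama, baichuan, qwen, mixtral, yuan) pass SILU, while the new gemma_mlp_forward passes GELU. A hypothetical helper (not part of the patch) illustrating that mapping from an HF-style activation name:

from bigdl.llm.transformers.models.utils import SILU, GELU  # SILU = 0, GELU = 1

def fused_act_id(hidden_act: str) -> int:
    # Illustrative only: the real call sites hard-code the constant per model.
    if hidden_act in ("silu", "swish"):
        return SILU   # llama / baichuan / qwen / mixtral / yuan mlp forwards
    if hidden_act in ("gelu", "gelu_pytorch_tanh"):
        return GELU   # gemma_mlp_forward
    raise ValueError(f"no fused mlp kernel for activation: {hidden_act}")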
@@ -34,7 +34,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
+from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SILU
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -107,7 +107,7 @@ def yuan_mlp_forward(
         out = self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
-            qtype
+            SILU, qtype
         ))
         if residual is not None:
             return out + residual