Enable Gemma fused mlp + Gelu (#10276)

* update llama mlp forward

* add all

* fix style check

* split

* update

* update

* update

* fix style
Xin Qiu 2024-02-29 16:53:24 +08:00 committed by GitHub
parent 2d930bdca8
commit 232273a1b5
8 changed files with 45 additions and 9 deletions
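
Note: this commit routes Gemma's MLP through the fused XPU kernel already used for Llama-style models, passing GeLU instead of SiLU via the new activation codes in utils.py, so one linear_q4_0.mlp_forward_xpu call replaces two projections plus an activation. A minimal pure-PyTorch sketch of the computation that fused call implements, mirroring the unfused fallback branch in the diff (the function and weight names below are illustrative, not code from this commit):

import torch.nn.functional as F

def gated_mlp_reference(x, gate_w, up_w, down_w, activation="gelu"):
    # Equivalent of down_proj(act_fn(gate_proj(x)) * up_proj(x)):
    # GeLU for Gemma, SiLU for Llama/Baichuan/Mixtral/Qwen/Yuan.
    gate = F.linear(x, gate_w)           # gate_w: [intermediate, hidden]
    up = F.linear(x, up_w)               # up_w:   [intermediate, hidden]
    act = F.gelu(gate) if activation == "gelu" else F.silu(gate)
    return F.linear(act * up, down_w)    # down_w: [hidden, intermediate]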


@@ -1108,6 +1108,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.gemma import gemma_attention_forward
         from bigdl.llm.transformers.models.gemma import gemma_rms_norm_forward
+        from bigdl.llm.transformers.models.gemma import gemma_mlp_forward
         convert_forward(model,
                         module.GemmaAttention,
                         gemma_attention_forward,
@@ -1115,6 +1116,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.GemmaRMSNorm,
                         gemma_rms_norm_forward)
+        convert_forward(model,
+                        module.GemmaMLP,
+                        gemma_mlp_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
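
For context, the GemmaMLP replacement wired up above is applied automatically when a Gemma checkpoint is loaded through bigdl-llm's transformers API. A rough usage sketch, with the model id and device handling chosen for illustration (an XPU build of PyTorch with intel_extension_for_pytorch is assumed):

from bigdl.llm.transformers import AutoModelForCausalLM

# Loading in 4-bit runs _optimize_post(), which swaps in gemma_mlp_forward,
# gemma_attention_forward and gemma_rms_norm_forward for the Gemma modules.
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it",
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")  # the fused mlp_forward_xpu path only runs on XPU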


@@ -28,7 +28,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv
     restore_fp8_kv_cache, use_quantize_kv_cache
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.utils import logging
@@ -80,7 +80,7 @@ def baichuan_mlp_forward(
         return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            qtype
+            SILU, qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


@@ -38,6 +38,7 @@ from torch import nn
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from bigdl.llm.transformers.models.utils import mlp_fusion_check, GELU
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
@@ -98,6 +99,31 @@ def gemma_rms_norm_forward(self, hidden_states):
     return (1 + self.weight) * hidden_states.to(input_dtype)


+def gemma_mlp_forward(
+    self,
+    x: torch.Tensor,
+    residual=None
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    bsz, hidden_size = x_2d.shape
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+        import linear_q4_0
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        out = self.down_proj(linear_q4_0.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            GELU, qtype
+        ))
+    else:
+        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    if residual is not None:
+        return out + residual
+    else:
+        return out
+
+
 def gemma_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -136,6 +162,7 @@ def gemma_attention_forward(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
             kv_seq_len += 1


@@ -40,6 +40,7 @@ import math
 import os
 import torch.nn.functional as F
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import SILU
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
@@ -118,7 +119,7 @@ def llama_mlp_forward(
         out = self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            qtype
+            SILU, qtype
         ))
         if residual is not None:
             return out + residual


@@ -50,7 +50,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
-from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import mlp_fusion_check, SILU
 from bigdl.llm.transformers.low_bit_linear import IQ2_XXS
@@ -371,7 +371,7 @@ def mixtral_mlp_forward(
         return self.w2(linear_q4_0.mlp_forward_xpu(
             x, self.w1.weight.data, self.w3.weight.data,
             x.shape[0], x.shape[1], self.w1.out_len,
-            qtype,
+            SILU, qtype,
         )) * routing_weights
     else:
         current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)


@@ -39,7 +39,7 @@ except ImportError:
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half
+from bigdl.llm.transformers.models.utils import rotate_half, SILU
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
@@ -292,6 +292,6 @@ def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.c_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.w2.weight.data, self.w1.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
-            qtype
+            SILU, qtype
         ))
     return self.c_proj(F.silu(self.w2(x)) * self.w1(x))


@@ -26,6 +26,10 @@ SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
 FP8E5 = ggml_tensor_qtype["fp8_e5m2"]

+# used in fused mlp forward
+SILU = 0
+GELU = 1
+

 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
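
The SILU and GELU integers added above are activation selectors that the model files in this commit pass to the native mlp_forward_xpu kernel. A Python-side sketch of what that selection means, for explanation only (the actual dispatch happens inside the kernel and this mapping is not code from the commit):

import torch.nn.functional as F

ACT_BY_CODE = {0: F.silu, 1: F.gelu}  # matches SILU = 0, GELU = 1

def apply_act(code, gate_out):
    # Documents the meaning of each code; the fused kernel applies the
    # chosen activation to gate_proj(x) before multiplying by up_proj(x).
    return ACT_BY_CODE[code](gate_out)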


@@ -34,7 +34,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
+from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SILU
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -107,7 +107,7 @@ def yuan_mlp_forward(
         out = self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
-            qtype
+            SILU, qtype
         ))
         if residual is not None:
             return out + residual