diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index ff34951b..24864b4b 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -1108,6 +1108,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.gemma import gemma_attention_forward
         from bigdl.llm.transformers.models.gemma import gemma_rms_norm_forward
+        from bigdl.llm.transformers.models.gemma import gemma_mlp_forward
         convert_forward(model,
                         module.GemmaAttention,
                         gemma_attention_forward,
@@ -1115,6 +1116,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.GemmaRMSNorm,
                         gemma_rms_norm_forward)
+        convert_forward(model,
+                        module.GemmaMLP,
+                        gemma_mlp_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 61eb7038..983caaa6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -28,7 +28,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv
     restore_fp8_kv_cache, use_quantize_kv_cache
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.utils import logging
@@ -80,7 +80,7 @@ def baichuan_mlp_forward(
         return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            qtype
+            SILU, qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
diff --git a/python/llm/src/bigdl/llm/transformers/models/gemma.py b/python/llm/src/bigdl/llm/transformers/models/gemma.py
index 1fcb3cd5..410d9c26 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gemma.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gemma.py
@@ -38,6 +38,7 @@ from torch import nn
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from bigdl.llm.transformers.models.utils import mlp_fusion_check, GELU
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 
@@ -98,6 +99,31 @@ def gemma_rms_norm_forward(self, hidden_states):
     return (1 + self.weight) * hidden_states.to(input_dtype)
 
 
+def gemma_mlp_forward(
+    self,
+    x: torch.Tensor,
+    residual=None
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    bsz, hidden_size = x_2d.shape
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+        import linear_q4_0
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        out = self.down_proj(linear_q4_0.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            GELU, qtype
+        ))
+    else:
+        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    if residual is not None:
+        return out + residual
+    else:
+        return out
+
+
 def gemma_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -136,6 +162,7 @@ def gemma_attention_forward(
                                                position_ids,
                                                cache_k, cache_v,
                                                self.q_proj.weight.qtype,
+                                               self.v_proj.weight.qtype,
                                                kv_seq_len,
                                                self.head_dim)
             kv_seq_len += 1
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index b58c23eb..bc9f19d9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -40,6 +40,7 @@ import math
 import os
 import torch.nn.functional as F
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import SILU
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
@@ -118,7 +119,7 @@ def llama_mlp_forward(
         out = self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            qtype
+            SILU, qtype
         ))
         if residual is not None:
             return out + residual
diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
index 53b8e114..0bffab9d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
@@ -50,7 +50,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
-from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import mlp_fusion_check, SILU
 from bigdl.llm.transformers.low_bit_linear import IQ2_XXS
 
 
@@ -371,7 +371,7 @@ def mixtral_mlp_forward(
         return self.w2(linear_q4_0.mlp_forward_xpu(
             x, self.w1.weight.data, self.w3.weight.data,
             x.shape[0], x.shape[1], self.w1.out_len,
-            qtype,
+            SILU, qtype,
         )) * routing_weights
     else:
         current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 1d5a7bec..92f37ac1 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -39,7 +39,7 @@ except ImportError:
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half
+from bigdl.llm.transformers.models.utils import rotate_half, SILU
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
@@ -292,6 +292,6 @@ def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.c_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.w2.weight.data, self.w1.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
-            qtype
+            SILU, qtype
         ))
     return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index de320e9b..bc79942e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -26,6 +26,10 @@ SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
 FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
 
+# used in fused mlp forward
+SILU = 0
+GELU = 1
+
 
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
diff --git a/python/llm/src/bigdl/llm/transformers/models/yuan.py b/python/llm/src/bigdl/llm/transformers/models/yuan.py
index abd09442..e0ba1f6a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/yuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/yuan.py
@@ -34,7 +34,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
+from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SILU
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -107,7 +107,7 @@ def yuan_mlp_forward(
         out = self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
-            qtype
+            SILU, qtype
         ))
         if residual is not None:
             return out + residual
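
Reviewer note: the patch defines two activation selectors in models/utils.py (SILU = 0, GELU = 1) and threads one of them through every linear_q4_0.mlp_forward_xpu call, so the same fused MLP kernel can serve both the SiLU-gated models (LLaMA, Baichuan2, Qwen, Mixtral, Yuan) and Gemma's GELU-gated MLP. The sketch below is only an eager-mode illustration of what each selector value is expected to mean, inferred from the non-fused fallback branches kept in this diff; the helper name mlp_forward_reference and the plain nn.Module arguments are illustrative, not part of the patch, and the real kernel consumes quantized weight buffers plus the qtype.

    import torch
    import torch.nn.functional as F

    # Selector values mirror bigdl/llm/transformers/models/utils.py in this patch.
    SILU = 0
    GELU = 1


    def mlp_forward_reference(x_2d: torch.Tensor,
                              gate_proj: torch.nn.Module,
                              up_proj: torch.nn.Module,
                              down_proj: torch.nn.Module,
                              act: int) -> torch.Tensor:
        # Eager-mode equivalent of down(act(gate(x)) * up(x)), i.e. the fallback
        # branch each model keeps when mlp_fusion_check() rejects the fused path.
        gate = gate_proj(x_2d)
        up = up_proj(x_2d)
        if act == SILU:      # LLaMA / Baichuan2 / Qwen / Mixtral / Yuan
            hidden = F.silu(gate) * up
        elif act == GELU:    # Gemma (its fallback uses self.act_fn, a GELU variant)
            hidden = F.gelu(gate) * up
        else:
            raise ValueError(f"unsupported activation selector: {act}")
        return down_proj(hidden)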