use fused mlp in baichuan2 (#9620)
This commit is contained in:
parent deee65785c
commit 7319f2c227
2 changed files with 25 additions and 0 deletions
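For context, the fused MLP path added here only comes into play once a Baichuan2 checkpoint has been loaded through BigDL-LLM's 4-bit loader and moved to an Intel XPU. A minimal, hedged usage sketch follows; the model id, prompt, and generation settings are placeholders and are not part of this commit.

# Hedged usage sketch: assumes bigdl-llm with XPU support is installed.
# The checkpoint id and prompt are placeholders, not taken from this commit.
import torch
import intel_extension_for_pytorch as ipex  # needed for the 'xpu' device
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "baichuan-inc/Baichuan2-7B-Chat"  # placeholder checkpoint id
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to('xpu')  # the fused q4_0 MLP kernel is only used on XPU
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to('xpu')
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))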
@@ -485,6 +485,7 @@ def _optimize_post(model, lightweight_bmm=False):
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_7b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
             convert_forward(model,
                             module.Attention,
                             baichuan_attention_forward_7b
@@ -492,12 +493,16 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.RMSNorm,
                             llama_rms_norm_forward)
+            convert_forward(model,
+                            module.MLP,
+                            baichuan_mlp_forward)
         elif model.config.hidden_size == 5120:
             # baichuan2-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
             from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
@@ -506,6 +511,9 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.RMSNorm,
                             baichuan_13b_rms_norm_forward)
+            convert_forward(model,
+                            module.MLP,
+                            baichuan_mlp_forward)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
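The new convert_forward calls above are what route Baichuan2's MLP modules to baichuan_mlp_forward. The helper itself is not shown in this diff; as a rough mental model only (an assumption, not the BigDL source), it rebinds forward on every instance of the target class:

# Hedged sketch of what convert_forward is assumed to do; the real helper
# lives elsewhere in bigdl.llm.transformers and may differ in details.
import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    for sub_module in model.modules():
        # Rebind `forward` on each instance of the target class so that
        # later calls go through the optimized implementation.
        if isinstance(sub_module, target_cls):
            sub_module.forward = types.MethodType(new_forward, sub_module)

Because the replacement happens per instance, the patched forward still reaches the module's own gate_proj, up_proj, and down_proj weights through self.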
@@ -70,6 +70,23 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
     return self.weight * hidden_states.to(input_dtype)
 
 
+def baichuan_mlp_forward(
+    self,
+    x: torch.Tensor,
+) -> torch.Tensor:
+    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and not (self.training and x.requires_grad):
+        import linear_q4_0
+        x_2d = x.view(-1, x.shape[-1])
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+        ))
+    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
 def baichuan_attention_forward_7b(
     self,
     hidden_states: torch.Tensor,
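The fused path only fires for single-token decode (x.shape[1] == 1) in fp32 on XPU, where the gate/up projections and activation run in one q4_0 kernel call; every other case falls through to the last line, the usual SwiGLU-style MLP. A plain-PyTorch reference of that fallback is sketched below for comparison; the layer sizes and the SiLU activation are assumptions for illustration, not taken from this commit.

# Hedged reference of the unfused fallback path only; the fused
# linear_q4_0.mlp_forward_q4_0_xpu kernel computes the same gate/up/activation
# product in a single call on quantized weights. Sizes and SiLU are assumptions.
import torch
import torch.nn as nn

class ReferenceMLP(nn.Module):
    def __init__(self, hidden_size: int = 4096, intermediate_size: int = 11008):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same expression as the fallback branch in baichuan_mlp_forward.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

# During decoding the input is a single token: shape (batch, 1, hidden_size).
mlp = ReferenceMLP()
out = mlp(torch.randn(1, 1, 4096))
print(out.shape)  # torch.Size([1, 1, 4096])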