LLM: fix RMSNorm optimization of Baichuan2-13B/Baichuan-13B (#9204)

* fix rmsnorm of baichuan2-13B

* update baichuan1-13B too

* fix style
Ruonan Wang 2023-10-17 18:40:34 +08:00 committed by GitHub
parent efcda3892f
commit 09815f7064
2 changed files with 29 additions and 8 deletions

bigdl/llm/transformers/convert.py

@@ -275,19 +275,23 @@ def optimize(model):
                             module.Attention,
                             baichuan_attention_forward_7b
                             )
+            convert_forward(model,
+                            module.RMSNorm,
+                            llama_rms_norm_forward)
         elif model.config.hidden_size == 5120:
             # baichuan2-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
                             )
-            convert_forward(model,
-                            module.RMSNorm,
-                            llama_rms_norm_forward)
+            # baichuan2-13B's RMSNorm is a little different
+            convert_forward(model,
+                            module.RMSNorm,
+                            baichuan_13b_rms_norm_forward)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
@@ -299,19 +303,23 @@ def optimize(model):
                             module.Attention,
                             baichuan_attention_forward_7b
                             )
+            convert_forward(model,
+                            module.RMSNorm,
+                            llama_rms_norm_forward)
         elif model.config.hidden_size == 5120:
             # baichuan-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_13b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
                             )
-            convert_forward(model,
-                            module.RMSNorm,
-                            llama_rms_norm_forward)
+            # baichuan-13B's RMSNorm is a little different
+            convert_forward(model,
+                            module.RMSNorm,
+                            baichuan_13b_rms_norm_forward)
     elif model.config.model_type == "gpt_neox":
         from bigdl.llm.transformers.models.gptneox import gptneox_attention_forward
         convert_forward(model,
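
The Baichuan(1)-13B branch above applies the same fix, reusing baichuan_13b_rms_norm_forward from the baichuan2 module. These conversions are triggered when a model is loaded through bigdl-llm; a typical entry point looks roughly like the following (the checkpoint id and keyword arguments are illustrative, not taken from this commit):

from bigdl.llm.transformers import AutoModelForCausalLM

# Loading through bigdl-llm runs its model optimizations (including the
# forward replacements in optimize() above) as part of from_pretrained().
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-13B-Chat",  # illustrative checkpoint id
    load_in_4bit=True,                  # bigdl-llm low-bit path
    trust_remote_code=True,             # Baichuan ships custom modeling code
)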

bigdl/llm/transformers/models/baichuan2.py

@@ -45,6 +45,19 @@ except ImportError:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
+def baichuan_13b_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def baichuan_attention_forward_7b(
     self,
     hidden_states: torch.Tensor,
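
The new forward reads the layer's epsilon from self.epsilon, whereas LLaMA's RMSNorm names that attribute variance_epsilon, which is presumably why llama_rms_norm_forward could not be reused here; the normalization math itself is unchanged: x is scaled by rsqrt(mean(x^2) + eps) and multiplied by the weight. A quick check of the non-XPU fallback branch against a stand-alone reference follows; ToyRMSNorm is an illustrative stand-in for Baichuan-13B's RMSNorm, not the model's actual module:

import torch
import torch.nn as nn

from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward


class ToyRMSNorm(nn.Module):
    # Stand-in exposing only what the patched forward relies on:
    # a `weight` parameter and an `epsilon` attribute.
    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.epsilon = epsilon  # note: `epsilon`, not LLaMA's `variance_epsilon`

    def forward(self, hidden_states):
        # Reference RMSNorm math, identical to the patched forward's CPU branch.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
        return self.weight * hidden_states.to(input_dtype)


norm = ToyRMSNorm(5120)                        # Baichuan-13B hidden size
x = torch.randn(1, 8, 5120, dtype=torch.float16)
out = baichuan_13b_rms_norm_forward(norm, x)   # CPU tensor -> fallback branch
assert torch.allclose(out, norm(x))            # matches the reference math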