LLM: fix RMSNorm optimization of Baichuan2-13B/Baichuan-13B (#9204)

* fix rmsnorm of baichuan2-13B

* update baichuan1-13B too

* fix style
Ruonan Wang 2023-10-17 18:40:34 +08:00 committed by GitHub
parent efcda3892f
commit 09815f7064
2 changed files with 29 additions and 8 deletions
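
For context on the diff below: this branch of optimize(model) is reached when a Baichuan(2)-13B checkpoint (hidden_size == 5120) is loaded through bigdl-llm's transformers-style API, which patches the attention and RMSNorm forwards. A hedged loading sketch; the model id and flags are illustrative assumptions, not part of this commit:

# Hedged usage sketch (assumptions: model id, load_in_4bit flag, XPU target).
# Loading through bigdl-llm runs its model optimization, which is where the
# convert_forward(...) calls in the diff below take effect.
import intel_extension_for_pytorch as ipex  # registers the "xpu" device
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-13B-Chat",  # hidden_size == 5120 -> the 13B branch
    load_in_4bit=True,                  # runs bigdl-llm's model optimization
    trust_remote_code=True,
)
model = model.to("xpu")                 # the patched RMSNorm has a fast XPU path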


@@ -275,19 +275,23 @@ def optimize(model):
                             module.Attention,
                             baichuan_attention_forward_7b
                             )
             convert_forward(model,
                             module.RMSNorm,
                             llama_rms_norm_forward)
         elif model.config.hidden_size == 5120:
             # baichuan2-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
                             )
-            convert_forward(model,
-                            module.RMSNorm,
-                            llama_rms_norm_forward)
+            # baichuan2-13B's RMSNorm is a little different
+            convert_forward(model,
+                            module.RMSNorm,
+                            baichuan_13b_rms_norm_forward)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
@@ -299,19 +303,23 @@ def optimize(model):
                             module.Attention,
                             baichuan_attention_forward_7b
                             )
             convert_forward(model,
                             module.RMSNorm,
                             llama_rms_norm_forward)
         elif model.config.hidden_size == 5120:
             # baichuan-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_13b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
                             )
-            convert_forward(model,
-                            module.RMSNorm,
-                            llama_rms_norm_forward)
+            # baichuan-13B's RMSNorm is a little different
+            convert_forward(model,
+                            module.RMSNorm,
+                            baichuan_13b_rms_norm_forward)
     elif model.config.model_type == "gpt_neox":
         from bigdl.llm.transformers.models.gptneox import gptneox_attention_forward
         convert_forward(model,
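
The convert_forward calls above swap out the forward method of every instance of a target class across the model. A minimal sketch of the idea, assuming a plain walk over sub-modules (an illustration, not the actual bigdl.llm.transformers.convert implementation):

import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    # Rebind `forward` on every sub-module of the target class, e.g. replace
    # Baichuan-13B's RMSNorm.forward with baichuan_13b_rms_norm_forward.
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            # bind the plain function as a method of this particular instance
            sub_module.forward = new_forward.__get__(sub_module, type(sub_module))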


@@ -45,6 +45,19 @@ except ImportError:
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+
+def baichuan_13b_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
 
 def baichuan_attention_forward_7b(
     self,
     hidden_states: torch.Tensor,
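
On non-XPU devices the new forward falls into the else branch, which is the standard RMSNorm computed in float32: scale the input by the reciprocal root-mean-square over the hidden dimension (plus self.epsilon), then multiply by the weight; on an XPU tensor it instead dispatches to torch.ops.torch_ipex.rms_norm. A small self-check sketch against that definition; DummyRMSNorm is a hypothetical stand-in introduced only here, exposing the weight and epsilon attributes the patched forward reads:

# Hypothetical sanity check: run the new forward on CPU (so it takes the
# `else` branch) and compare with a straightforward float32 RMSNorm reference.
import torch
from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward

class DummyRMSNorm(torch.nn.Module):
    # stand-in with the two attributes the patched forward uses:
    # `weight` and `epsilon` (Baichuan-13B stores its eps as `epsilon`)
    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.epsilon = epsilon

norm = DummyRMSNorm(5120)                       # 5120 = Baichuan-13B hidden size
x = torch.randn(2, 8, 5120, dtype=torch.float16)

out = baichuan_13b_rms_norm_forward(norm, x)    # CPU -> else branch

# reference: upcast, divide by RMS over the hidden dimension, rescale
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + norm.epsilon)
ref = norm.weight * ref.to(x.dtype)
torch.testing.assert_close(out, ref)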