LLM: fix RMSNorm optimization of Baichuan2-13B/Baichuan-13B (#9204)
* fix rmsnorm of baichuan2-13B
* update baichuan1-13B too
* fix style
parent efcda3892f
commit 09815f7064
2 changed files with 29 additions and 8 deletions
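Why a dedicated forward is needed: the new baichuan_13b_rms_norm_forward added below reads the normalization eps from self.epsilon, the attribute name used by Baichuan-13B's RMSNorm, while the generic llama_rms_norm_forward presumably looks up self.variance_epsilon as Llama-style RMSNorm modules do, so reusing it on Baichuan-13B would fail. A minimal sketch of that mismatch, with Baichuan13BRMSNormLike as a hypothetical stand-in module:

import torch

class Baichuan13BRMSNormLike(torch.nn.Module):
    # Hypothetical stand-in: stores its eps as "epsilon", like Baichuan-13B's RMSNorm.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.epsilon = eps

norm = Baichuan13BRMSNormLike(5120)
try:
    norm.variance_epsilon  # the attribute a Llama-style forward would look up
except AttributeError as err:
    print("cannot reuse llama_rms_norm_forward:", err)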
@@ -275,19 +275,23 @@ def optimize(model):
                             module.Attention,
                             baichuan_attention_forward_7b
                             )
             convert_forward(model,
                             module.RMSNorm,
                             llama_rms_norm_forward)
         elif model.config.hidden_size == 5120:
             # baichuan2-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
                             )
+            # baichuan2-13B's RMSNorm is a little different
             convert_forward(model,
                             module.RMSNorm,
-                            llama_rms_norm_forward)
+                            baichuan_13b_rms_norm_forward)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
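convert_forward itself is not shown in this diff; a minimal sketch, assuming it simply rebinds the forward method on every submodule of the target class (the real helper lives in bigdl.llm.transformers.convert and may differ in detail):

import types

import torch

def convert_forward_sketch(model: torch.nn.Module, target_class, new_forward):
    # Walk the module tree and rebind "forward" on every instance of target_class,
    # so later calls dispatch to the optimized implementation.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)

# e.g. convert_forward_sketch(model, module.RMSNorm, baichuan_13b_rms_norm_forward)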
@@ -299,19 +303,23 @@ def optimize(model):
                             module.Attention,
                             baichuan_attention_forward_7b
                             )
             convert_forward(model,
                             module.RMSNorm,
                             llama_rms_norm_forward)
         elif model.config.hidden_size == 5120:
             # baichuan-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_13b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
                             )
+            # baichuan-13B's RMSNorm is a little different
             convert_forward(model,
                             module.RMSNorm,
-                            llama_rms_norm_forward)
+                            baichuan_13b_rms_norm_forward)
     elif model.config.model_type == "gpt_neox":
         from bigdl.llm.transformers.models.gptneox import gptneox_attention_forward
         convert_forward(model,
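Note that the Baichuan-13B (v1) branch above imports baichuan_13b_rms_norm_forward from the baichuan2 module as well, since both 13B generations compute RMSNorm the same way. These conversions run inside optimize(model); a hedged sketch of how they are typically triggered through the BigDL-LLM loading API (the model id and device move below are illustrative):

import intel_extension_for_pytorch as ipex  # noqa: F401  (needed for the xpu device)
from bigdl.llm.transformers import AutoModelForCausalLM

# Loading with a low-bit option applies the module conversions above,
# including the Baichuan-13B RMSNorm replacement.
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-13B-Chat",  # illustrative model id
    load_in_4bit=True,
    trust_remote_code=True,
)
model = model.to("xpu")  # the fused rms_norm path in the new forward targets xpu

The second changed file adds the new forward itself: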
@@ -45,6 +45,19 @@ except ImportError:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256


+def baichuan_13b_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def baichuan_attention_forward_7b(
     self,
     hidden_states: torch.Tensor,
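The new forward takes the fused torch.ops.torch_ipex.rms_norm path for XPU inference (skipped only when training with gradients enabled) and otherwise falls back to an eager float32 computation, reading the eps from self.epsilon rather than Llama's variance_epsilon. A small self-contained sketch of that fallback math, using a hypothetical rms_norm_reference helper to show that an all-ones weight yields unit RMS per token:

import torch

def rms_norm_reference(hidden_states, weight, epsilon=1e-6):
    # Mirrors the eager fallback branch of baichuan_13b_rms_norm_forward.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
    return weight * hidden_states.to(input_dtype)

x = torch.randn(2, 8, 5120)
out = rms_norm_reference(x, torch.ones(5120))
print(out.pow(2).mean(-1).sqrt())  # ~1.0 for every token position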