diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index b5fd1c5a..08e379c0 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -275,19 +275,23 @@ def optimize(model):
                             module.Attention,
                             baichuan_attention_forward_7b
                             )
+            convert_forward(model,
+                            module.RMSNorm,
+                            llama_rms_norm_forward)
         elif model.config.hidden_size == 5120:
             # baichuan2-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
                             )
-            convert_forward(model,
-                            module.RMSNorm,
-                            llama_rms_norm_forward)
-
+            # baichuan2-13B's RMSNorm is a little different
+            convert_forward(model,
+                            module.RMSNorm,
+                            baichuan_13b_rms_norm_forward)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
@@ -299,19 +303,23 @@ def optimize(model):
                             module.Attention,
                             baichuan_attention_forward_7b
                             )
+            convert_forward(model,
+                            module.RMSNorm,
+                            llama_rms_norm_forward)
         elif model.config.hidden_size == 5120:
             # baichuan-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_13b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
                             )
-            convert_forward(model,
-                            module.RMSNorm,
-                            llama_rms_norm_forward)
-
+            # baichuan-13B's RMSNorm is a little different
+            convert_forward(model,
+                            module.RMSNorm,
+                            baichuan_13b_rms_norm_forward)
     elif model.config.model_type == "gpt_neox":
         from bigdl.llm.transformers.models.gptneox import gptneox_attention_forward
         convert_forward(model,
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 4576fb6b..cdc49f5d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -45,6 +45,19 @@ except ImportError:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
+def baichuan_13b_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def baichuan_attention_forward_7b(
     self,
     hidden_states: torch.Tensor,
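
Note on the change: Baichuan-13B's RMSNorm stores its eps as self.epsilon rather than the variance_epsilon attribute that llama_rms_norm_forward expects, so a dedicated forward is added instead of reusing the LLaMA one. A quick sanity check is to call the new forward directly against a minimal RMSNorm stub, without loading a full model. This is only a sketch: the RMSNorm class below is an illustrative stand-in for Baichuan-13B's layer, not the actual modeling code, and it assumes bigdl-llm with this patch installed.

    import torch
    import torch.nn as nn

    from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward


    class RMSNorm(nn.Module):
        # Illustrative stand-in: like Baichuan-13B, it keeps its eps as self.epsilon.
        def __init__(self, hidden_size, epsilon=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.epsilon = epsilon

        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
            return self.weight * hidden_states.to(input_dtype)


    norm = RMSNorm(5120)
    x = torch.randn(1, 16, 5120)
    ref = norm(x)                                 # stock eager implementation
    opt = baichuan_13b_rms_norm_forward(norm, x)  # patched forward (takes the eager branch on CPU)
    assert torch.allclose(ref, opt, atol=1e-5)

On CPU this only exercises the eager fallback branch; for a tensor on an XPU device (outside of training), the patched forward dispatches to torch.ops.torch_ipex.rms_norm instead.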