diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 86fd7914..f416f161 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -135,6 +135,7 @@ def convert_forward(m, target_m, new_forward):
 def optimize(model):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
+    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from transformers.modeling_utils import PreTrainedModel
 
     # All huggingface format models are inherited from `PreTrainedModel`
@@ -149,6 +150,10 @@ def optimize(model):
             model,
             transformers.models.llama.modeling_llama.LlamaAttention,
             llama_attention_forward_4_31,)
+        convert_forward(
+            model,
+            transformers.models.llama.modeling_llama.LlamaRMSNorm,
+            llama_rms_norm_forward,)
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 51ddb2ee..8cbe1b0e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -57,6 +57,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
+def llama_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu":
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def llama_attention_forward_4_31(
     self,
     hidden_states: torch.Tensor,
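
For context, here is a minimal, self-contained sketch of the forward-replacement pattern this diff relies on. `SimpleRMSNorm` is a hypothetical stand-in for `LlamaRMSNorm`, and rebinding `forward` with `types.MethodType` only approximates what BigDL's `convert_forward` does for every matching submodule of a model. Only the eager fallback branch runs here, since the XPU branch needs Intel Extension for PyTorch (`torch.ops.torch_ipex.rms_norm`).

```python
# Sketch only: SimpleRMSNorm is a hypothetical stand-in for LlamaRMSNorm,
# and the manual forward rebinding approximates convert_forward's behavior.
import types
import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """Hypothetical module mirroring LlamaRMSNorm's attributes."""
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Reference eager implementation (same math as the fallback branch below).
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


def llama_rms_norm_forward(self, hidden_states):
    # Same body as the patch: fused IPEX kernel on XPU, eager float32 math elsewhere.
    if hidden_states.device.type == "xpu":
        hidden_states, _ = torch.ops.torch_ipex.rms_norm(
            hidden_states, [self.weight.size(0)], self.weight)
    else:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
    return hidden_states


norm = SimpleRMSNorm(hidden_size=64)
x = torch.randn(2, 8, 64)
reference = norm(x)

# Rebind the instance's forward, approximating what convert_forward does
# for each LlamaRMSNorm submodule when optimize(model) is called.
norm.forward = types.MethodType(llama_rms_norm_forward, norm)
patched = norm(x)

assert torch.allclose(reference, patched)
```

The design point: on XPU the whole normalization collapses into a single fused IPEX kernel, while on other devices the original float32 eager math is kept, so numerics off-XPU are unchanged.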