Use ipex fused rms norm for llama (#9081)

* also apply rmsnorm

* fix cpu
This commit is contained in:
Yang Wang 2023-10-05 12:04:55 +08:00 committed by GitHub
parent fb883100e7
commit 0cd8f1c79c
2 changed files with 18 additions and 0 deletions

View file

@ -135,6 +135,7 @@ def convert_forward(m, target_m, new_forward):
def optimize(model): def optimize(model):
from packaging import version from packaging import version
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
# All huggingface format models are inherited from `PreTrainedModel` # All huggingface format models are inherited from `PreTrainedModel`
@ -149,6 +150,10 @@ def optimize(model):
model, model,
transformers.models.llama.modeling_llama.LlamaAttention, transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_31,) llama_attention_forward_4_31,)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaRMSNorm,
llama_rms_norm_forward,)
else: else:
# todo implement 4.28.0 ~ 4.30.2 # todo implement 4.28.0 ~ 4.30.2
pass pass

View file

@ -57,6 +57,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
KV_CACHE_ALLOC_BLOCK_LENGTH = 256 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def llama_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu":
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
[self.weight.size(0)], self.weight)
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
return hidden_states
def llama_attention_forward_4_31( def llama_attention_forward_4_31(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,