Use ipex fused rms norm for llama (#9081)
* also apply rmsnorm
* fix cpu
parent fb883100e7
commit 0cd8f1c79c
2 changed files with 18 additions and 0 deletions
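For context, LlamaRMSNorm scales each hidden vector by the reciprocal root mean square of its own elements (no mean subtraction, unlike LayerNorm). The XPU branch added below folds that whole op sequence into a single fused IPEX kernel call, while the CPU branch keeps the reference math. A minimal, self-contained sketch of that reference computation for comparison (the function and variable names here are illustrative, not part of the patch):

import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: cast to fp32, normalize by the root mean square over
    # the last dimension, then rescale by the learned weight in the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

# Example: normalize a (batch, seq, hidden) activation tensor.
x = torch.randn(2, 4, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
y = rms_norm_reference(x, w)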
@@ -135,6 +135,7 @@ def convert_forward(m, target_m, new_forward):
 def optimize(model):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
+    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from transformers.modeling_utils import PreTrainedModel
 
     # All huggingface format models are inherited from `PreTrainedModel`
@@ -149,6 +150,10 @@ def optimize(model):
             model,
             transformers.models.llama.modeling_llama.LlamaAttention,
             llama_attention_forward_4_31,)
+        convert_forward(
+            model,
+            transformers.models.llama.modeling_llama.LlamaRMSNorm,
+            llama_rms_norm_forward,)
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
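The two hunks above patch the optimize() entry point of the conversion module: besides the existing LlamaAttention replacement, every LlamaRMSNorm module now gets llama_rms_norm_forward bound as its forward. convert_forward itself (defined near line 135 of the same file) is not shown in this diff; the sketch below is only a guess at what such a helper does, assuming it walks the module tree and rebinds the new forward on every matching submodule:

import types
import torch.nn as nn

def convert_forward_sketch(m: nn.Module, target_m: type, new_forward):
    # Recursively visit submodules and rebind `new_forward` as the bound
    # `forward` method of every instance of `target_m`.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            sub_m.forward = types.MethodType(new_forward, sub_m)
        convert_forward_sketch(sub_m, target_m, new_forward)

This is an approximation, not the real convert_forward; the actual helper may handle additional cases.

The second file touched by the commit is the models/llama module (the target of the new import above), which gains the replacement forward itself: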
@@ -57,6 +57,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
+def llama_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu":
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def llama_attention_forward_4_31(
     self,
     hidden_states: torch.Tensor,
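Since torch.ops.torch_ipex.rms_norm is only reachable on an XPU device with intel_extension_for_pytorch loaded, the CPU branch can be sanity-checked on its own by binding the patched forward onto a stock transformers LlamaRMSNorm and comparing outputs. A small check along those lines, assuming transformers >= 4.31 and this bigdl-llm revision are installed (the tensor sizes and tolerance are illustrative):

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward

norm = LlamaRMSNorm(hidden_size=64, eps=1e-6)
x = torch.randn(1, 5, 64)

expected = norm(x)                           # stock transformers forward
patched = llama_rms_norm_forward(norm, x)    # patched forward, CPU branch

# Both paths implement the same math on CPU, so results should match closely.
assert torch.allclose(expected, patched, atol=1e-6)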