Use ipex fused rms norm for llama (#9081)
* also apply rmsnorm
* fix cpu
parent fb883100e7
commit 0cd8f1c79c
2 changed files with 18 additions and 0 deletions
@@ -135,6 +135,7 @@ def convert_forward(m, target_m, new_forward):
 def optimize(model):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
+    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from transformers.modeling_utils import PreTrainedModel
 
     # All huggingface format models are inherited from `PreTrainedModel`
@@ -149,6 +150,10 @@ def optimize(model):
             model,
             transformers.models.llama.modeling_llama.LlamaAttention,
             llama_attention_forward_4_31,)
+        convert_forward(
+            model,
+            transformers.models.llama.modeling_llama.LlamaRMSNorm,
+            llama_rms_norm_forward,)
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
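The new call goes through convert_forward, whose signature is visible in the first hunk header. As a rough sketch of what such a helper is assumed to do (illustrative only, not the repository's actual implementation), it walks the model and rebinds forward on every instance of the target module class:

import types
import torch.nn as nn

def convert_forward(m: nn.Module, target_m: type, new_forward) -> None:
    # Recursively visit submodules; every instance of target_m gets its
    # forward replaced by new_forward bound to that instance.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            sub_m.forward = types.MethodType(new_forward, sub_m)
        convert_forward(sub_m, target_m, new_forward)

With a helper of that shape, the added lines swap LlamaRMSNorm's forward for llama_rms_norm_forward on every decoder layer, mirroring what the existing call already does for LlamaAttention.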
@@ -57,6 +57,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
+def llama_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu":
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def llama_attention_forward_4_31(
     self,
     hidden_states: torch.Tensor,
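A quick way to sanity-check the CPU fallback is to compare it against the stock RMSNorm math it reimplements. The snippet below is a hypothetical check, not part of the commit; RefRMSNorm is a stand-in for transformers' LlamaRMSNorm.

import torch
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward

class RefRMSNorm(torch.nn.Module):
    # Reference RMSNorm: normalize by the root mean square, then scale by weight.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

norm = RefRMSNorm(64)
x = torch.randn(2, 8, 64)
# llama_rms_norm_forward takes self explicitly, so it can be called as a plain function.
assert torch.allclose(llama_rms_norm_forward(norm, x), norm(x))

On an Intel GPU, hidden_states.device.type is "xpu" and the same function dispatches to the fused torch.ops.torch_ipex.rms_norm kernel instead, which is where the speedup comes from.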