use fused rms norm (#9572)

* use fused rms norm

* meet code review
This commit is contained in:
Xin Qiu 2023-11-30 21:47:41 +08:00 committed by GitHub
parent b785376f5c
commit 69c49d21f5

View file

@ -74,22 +74,26 @@ def get_ipex_version():
def llama_rms_norm_forward(self, hidden_states):
optimized_rms_norm = False
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
if get_ipex_version() <= "2.0.110+xpu":
if self.variance_epsilon == 1e-6:
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
[self.weight.size(0)],
self.weight)
optimized_rms_norm = True
else:
import linear_q4_0
hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight,
None,
self.variance_epsilon)
else:
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight,
None,
self.variance_epsilon)
optimized_rms_norm = True
if not optimized_rms_norm:
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)