parent
b785376f5c
commit
69c49d21f5
1 changed files with 8 additions and 4 deletions
|
|
@ -74,22 +74,26 @@ def get_ipex_version():
|
||||||
|
|
||||||
|
|
||||||
def llama_rms_norm_forward(self, hidden_states):
|
def llama_rms_norm_forward(self, hidden_states):
|
||||||
optimized_rms_norm = False
|
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
if get_ipex_version() <= "2.0.110+xpu":
|
if get_ipex_version() <= "2.0.110+xpu":
|
||||||
if self.variance_epsilon == 1e-6:
|
if self.variance_epsilon == 1e-6:
|
||||||
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
|
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
|
||||||
[self.weight.size(0)],
|
[self.weight.size(0)],
|
||||||
self.weight)
|
self.weight)
|
||||||
optimized_rms_norm = True
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
|
||||||
|
[self.weight.size(0)],
|
||||||
|
self.weight,
|
||||||
|
None,
|
||||||
|
self.variance_epsilon)
|
||||||
else:
|
else:
|
||||||
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
|
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
|
||||||
[self.weight.size(0)],
|
[self.weight.size(0)],
|
||||||
self.weight,
|
self.weight,
|
||||||
None,
|
None,
|
||||||
self.variance_epsilon)
|
self.variance_epsilon)
|
||||||
optimized_rms_norm = True
|
else:
|
||||||
if not optimized_rms_norm:
|
|
||||||
input_dtype = hidden_states.dtype
|
input_dtype = hidden_states.dtype
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue