llm: quick fix of fast_rms_norm (#9480)
parent d5263e6681
commit c0ef70df02

2 changed files with 10 additions and 10 deletions
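Both hunks below make the same one-line change: the result of torch.ops.torch_ipex.fast_rms_norm is assigned to hidden_states directly instead of being unpacked as a (tensor, _) tuple.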
@@ -53,11 +53,11 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
                                                              [self.weight.size(0)],
                                                              self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                                  [self.weight.size(0)],
-                                                                  self.weight,
-                                                                  None,
-                                                                  self.epsilon)
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,
+                                                               self.epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
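For reference, the eager fallback visible in this hunk's trailing context (upcast to float32, then normalize) corresponds to the standard RMSNorm computation that the fused op replaces. A minimal pure-PyTorch sketch; the function name rms_norm_reference and the eps default are chosen here for illustration and are not part of the patch:

import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Upcast to float32 for a numerically stable variance, as the
    # fallback branch in the diff does.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    # RMSNorm: scale by the reciprocal root-mean-square over the last dim.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)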
@@ -79,11 +79,11 @@ def llama_rms_norm_forward(self, hidden_states):
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)], self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                                  [self.weight.size(0)],
-                                                                  self.weight,
-                                                                  None,
-                                                                  self.variance_epsilon)
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,
+                                                               self.variance_epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
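A sketch of a call site after this fix, with the op's signature taken verbatim from the hunks above. The isinstance guard is hypothetical and not in the patch; it merely tolerates the tuple return that the removed `hidden_states, _ = ...` unpacking assumed, and it presumes intel_extension_for_pytorch is installed so the torch_ipex ops are registered:

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers torch.ops.torch_ipex)

def fused_rms_norm(hidden_states: torch.Tensor,
                   weight: torch.Tensor,
                   eps: float) -> torch.Tensor:
    # Same argument list as the fast_rms_norm calls in the diff.
    out = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                             [weight.size(0)],
                                             weight,
                                             None,
                                             eps)
    # Hypothetical compatibility guard: unwrap a (tensor, _) tuple if an
    # older build returns one, otherwise pass the tensor through directly.
    return out[0] if isinstance(out, tuple) else out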