llm: quick fix of fast_rms_norm (#9480)

Ruonan Wang 2023-11-16 14:42:16 +08:00 committed by GitHub
parent d5263e6681
commit c0ef70df02
2 changed files with 10 additions and 10 deletions

@@ -53,7 +53,7 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
                                                              [self.weight.size(0)],
                                                              self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                   [self.weight.size(0)],
                                                                   self.weight,
                                                                   None,

@@ -79,7 +79,7 @@ def llama_rms_norm_forward(self, hidden_states):
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)], self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                   [self.weight.size(0)],
                                                                   self.weight,
                                                                   None,
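
Both hunks make the same fix: as the commit implies, torch.ops.torch_ipex.fast_rms_norm returns the normalized tensor directly, so unpacking its result as a tuple (hidden_states, _ = ...) is incorrect, while the older rms_norm op keeps the tuple form. For context, below is a minimal plain-PyTorch sketch of the computation these fused kernels replace. It is illustrative only: the function name, the default eps, and the float32 accumulation are assumptions made for this sketch, not code from the repository.

import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # RMS normalization over the last dimension followed by a learned
    # per-channel scale; this is the computation the fused
    # rms_norm / fast_rms_norm kernels provide on XPU.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * normed.to(hidden_states.dtype)

For example, rms_norm_reference(torch.randn(1, 8, 4096), torch.ones(4096)) returns a tensor with the same shape as its input.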