LLM: fix abnormal Mistral GPU accuracy by updating rms_norm (#9529)
parent 3d24823cda
commit 914a5a5a27

3 changed files with 24 additions and 10 deletions
@@ -47,18 +47,22 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256

 def baichuan_13b_rms_norm_forward(self, hidden_states):
+    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)],
-                                                             self.weight)
+            if self.epsilon == 1e-6:
+                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                                 [self.weight.size(0)],
+                                                                 self.weight)
+                optimized_rms_norm = True
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.epsilon)
-    else:
+            optimized_rms_norm = True
+    if not optimized_rms_norm:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
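The change has the same shape in all three files: the three-argument torch_ipex.rms_norm op takes no epsilon, so the fused path is now taken only when the layer's epsilon matches the 1e-6 the kernel evidently assumes, and anything else (e.g. Mistral's 1e-5) falls through to the float32 reference path. Below is a minimal sketch of that fallback as a free function, assuming the truncated tail of the branch follows the standard rsqrt-and-rescale RMSNorm; the name reference_rms_norm is made up here, and the commit keeps this logic inline:

import torch

def reference_rms_norm(hidden_states, weight, eps):
    # Mirror of the fp32 fallback branch above: compute the norm in
    # float32 for numerical stability, then cast back to input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)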
@@ -77,10 +77,14 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:

 def chatglm_rms_norm_forward(self, hidden_states):
+    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)], self.weight)
+            if self.eps == 1e-6:
+                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                                 [self.weight.size(0)],
+                                                                 self.weight)
+                optimized_rms_norm = True
         else:
             # for ipex >= 2.1
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
@@ -88,7 +92,8 @@ def chatglm_rms_norm_forward(self, hidden_states):
                                                                self.weight,
                                                                None,  # bias
                                                                self.eps)
-    else:
+            optimized_rms_norm = True
+    if not optimized_rms_norm:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
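The same guard is spelled out three times, differing only in the epsilon attribute name (self.epsilon, self.eps, self.variance_epsilon). Distilled into one helper, the control flow is easier to see; this is a hypothetical refactor for illustration, not code from the commit, and it omits the training/requires_grad check for brevity:

import torch

def try_ipex_rms_norm(hidden_states, weight, eps, ipex_version):
    # Returns (output, True) if a fused IPEX kernel handled the norm,
    # or (hidden_states, False) so the caller runs the fp32 fallback.
    if hidden_states.device.type != "xpu":
        return hidden_states, False
    if ipex_version <= "2.0.110+xpu":
        # The old 3-arg op takes no epsilon; only safe when the
        # layer's eps matches the kernel's apparent default of 1e-6.
        if eps != 1e-6:
            return hidden_states, False
        out, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                               [weight.size(0)], weight)
        return out, True
    # ipex >= 2.1 exposes fast_rms_norm, which accepts eps explicitly.
    out = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                             [weight.size(0)],
                                             weight, None, eps)
    return out, True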
@@ -74,17 +74,22 @@ def get_ipex_version():

 def llama_rms_norm_forward(self, hidden_states):
+    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)], self.weight)
+            if self.variance_epsilon == 1e-6:
+                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                                 [self.weight.size(0)],
+                                                                 self.weight)
+                optimized_rms_norm = True
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.variance_epsilon)
-    else:
+            optimized_rms_norm = True
+    if not optimized_rms_norm:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
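Why the commit title mentions Mistral: Mistral-7B configures RMSNorm with rms_norm_eps = 1e-5, so before this fix the ipex <= 2.0.110 path silently normalized with the kernel's 1e-6 instead. A quick CPU check with the reference formula shows how far apart the two epsilons land once activations are small enough for the epsilon term to dominate the variance:

import torch

def rms(x, eps):
    # Reference RMS normalization, weight omitted.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps)

# With activations around 1e-3, the per-row variance is around 1e-6,
# so the choice of eps visibly rescales the whole output.
x = torch.randn(4, 8) * 1e-3
gap = (rms(x, 1e-5) - rms(x, 1e-6)).abs().max().item()
print(f"max |rms(x, 1e-5) - rms(x, 1e-6)| = {gap:.4f}")  # clearly non-zero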