use fused rms norm in chatglm2 and baichuan (#9613)

* use fused rms norm in chatglm2 and baichuan

* style fix
Xin Qiu, 2023-12-07 09:21:41 +08:00, committed by GitHub
parent 51b668f229
commit 13d47955a8
3 changed files with 23 additions and 30 deletions
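Every patched forward below keeps the same eager fallback for tensors that are not on an XPU device (or that need gradients). Purely for reference, a minimal standalone sketch of that fallback computation, using illustrative names rather than anything defined in this commit:

import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Same math as the eager branch in the diffs below: normalize by the
    # root-mean-square over the last dimension, then apply the learned scale.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

The fused kernels used on XPU are expected to produce the same normalization in a single kernel launch instead of several separate elementwise ops.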


@@ -47,28 +47,27 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 def baichuan_13b_rms_norm_forward(self, hidden_states):
-    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            if self.epsilon == 1e-6:
-                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                                 [self.weight.size(0)],
-                                                                 self.weight)
-                optimized_rms_norm = True
+            import linear_q4_0
+            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+                                                       [self.weight.size(0)],
+                                                       self.weight,
+                                                       None,
+                                                       self.epsilon)
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.epsilon)
-            optimized_rms_norm = True
-    if not optimized_rms_norm:
+        return hidden_states
+    else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
         return self.weight * hidden_states.to(input_dtype)
-    return hidden_states
 
 
 def baichuan_attention_forward_7b(


@@ -77,14 +77,14 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 def chatglm_rms_norm_forward(self, hidden_states):
-    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            if self.eps == 1e-6:
-                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                                 [self.weight.size(0)],
-                                                                 self.weight)
-                optimized_rms_norm = True
+            import linear_q4_0
+            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+                                                       [self.weight.size(0)],
+                                                       self.weight,
+                                                       None,
+                                                       self.eps)
         else:
             # for ipex >= 2.1
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
@@ -92,14 +92,13 @@ def chatglm_rms_norm_forward(self, hidden_states):
                                                                self.weight,
                                                                None,  # bias
                                                                self.eps)
-            optimized_rms_norm = True
-    if not optimized_rms_norm:
+        return hidden_states
+    else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         return self.weight * hidden_states.to(input_dtype)
-    return hidden_states
 
 
 def chatglm2_model_forward(
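After this change, baichuan_13b_rms_norm_forward and chatglm_rms_norm_forward above follow the same dispatch shape as llama_rms_norm_forward below. A consolidated sketch of that shape (the function name and the use_legacy_ipex flag are illustrative; in the real code the flag corresponds to get_ipex_version() <= "2.0.110+xpu", the epsilon attribute is self.epsilon, self.eps or self.variance_epsilon depending on the model, and the XPU check also excludes the training case):

import torch

def rms_norm_dispatch_sketch(hidden_states, weight, eps, use_legacy_ipex):
    if hidden_states.device.type == "xpu" and not hidden_states.requires_grad:
        if use_legacy_ipex:
            # IPEX <= 2.0.110+xpu: use the fused kernel bundled with bigdl-llm.
            import linear_q4_0
            return linear_q4_0.fused_rms_norm(hidden_states, [weight.size(0)],
                                              weight, None, eps)
        # IPEX >= 2.1 exposes a fused RMSNorm op directly.
        return torch.ops.torch_ipex.fast_rms_norm(hidden_states, [weight.size(0)],
                                                  weight, None, eps)
    # Eager fallback: CPU tensors, or when autograd is needed.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return weight * (hidden_states * torch.rsqrt(variance + eps)).to(input_dtype)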


@@ -76,30 +76,25 @@ def get_ipex_version():
 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            if self.variance_epsilon == 1e-6:
-                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                                 [self.weight.size(0)],
-                                                                 self.weight)
-            else:
-                import linear_q4_0
-                hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
-                                                           [self.weight.size(0)],
-                                                           self.weight,
-                                                           None,
-                                                           self.variance_epsilon)
+            import linear_q4_0
+            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+                                                       [self.weight.size(0)],
+                                                       self.weight,
+                                                       None,
+                                                       self.variance_epsilon)
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.variance_epsilon)
+        return hidden_states
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
-    return hidden_states
 
 
 def llama_attention_forward_4_31(
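These functions are written to replace the forward method of the models' RMSNorm modules. Purely to illustrate the binding mechanism (this wiring is not part of this diff; bigdl-llm performs the substitution during model optimization rather than by hand):

import types

def patch_rms_norm_forward(rms_norm_module, patched_forward):
    # Bind the patched function to the existing module instance so that
    # rms_norm_module(hidden_states) runs the fused/eager dispatch above.
    rms_norm_module.forward = types.MethodType(patched_forward, rms_norm_module)
    return rms_norm_module

# usage (illustrative):
#   patch_rms_norm_forward(some_rms_norm_module, chatglm_rms_norm_forward)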