diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py index 3542a7af..2bb03167 100644 --- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py +++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py @@ -50,16 +50,12 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256 def baichuan_13b_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): import linear_q4_0 - result = linear_q4_0.fused_rms_norm(hidden_states, - [self.weight.size(0)], - self.weight, - None, - self.epsilon) - # if nelement == 0, means fused norm failed, go back to python implement. - if result.nelement != 0: - # We should copy this result to avoid by unknown reason on Arc GPUs. - result = result.clone() - return result + x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() + output = linear_q4_0.rms_norm(self.weight, x_2d, self.epsilon) + if 1 < x_2d.size(0) <= 64: # may use XMX, need copy + output = output.clone() + return output.reshape(hidden_states.shape) + input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 7ddd80eb..e5c7bdce 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -97,16 +97,12 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten def chatglm_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): import linear_q4_0 - result = linear_q4_0.fused_rms_norm(hidden_states, - [self.weight.size(0)], - self.weight, - None, - self.eps) - # if nelement == 0, means fused norm failed, go back to python implement. - if result.nelement != 0: - # We should copy this result to avoid by unknown reason on Arc GPUs. - result = result.clone() - return result + x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() + output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps) + if 1 < x_2d.size(0) <= 64: # may use XMX, need copy + output = output.clone() + return output.reshape(hidden_states.shape) + input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) diff --git a/python/llm/src/bigdl/llm/transformers/models/gemma.py b/python/llm/src/bigdl/llm/transformers/models/gemma.py index 410d9c26..9400c034 100644 --- a/python/llm/src/bigdl/llm/transformers/models/gemma.py +++ b/python/llm/src/bigdl/llm/transformers/models/gemma.py @@ -82,16 +82,12 @@ def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs): def gemma_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): import linear_q4_0 - result = linear_q4_0.fused_rms_norm(hidden_states, - [self.weight.size(0)], - self.weight + 1, - None, - self.eps) - # if nelement == 0, means fused norm failed, go back to python implement. - if result.nelement != 0: - # We should copy this result to avoid by unknown reason on Arc GPUs. - result = result.clone() - return result + x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() + output = linear_q4_0.rms_norm(self.weight + 1, x_2d, self.eps) + if 1 < x_2d.size(0) <= 64: # may use XMX, need copy + output = output.clone() + return output.reshape(hidden_states.shape) + input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index e8b5e8e4..f206a77d 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -123,16 +123,12 @@ def llama_model_forward_4_36( def llama_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): import linear_q4_0 - result = linear_q4_0.fused_rms_norm(hidden_states, - [self.weight.size(0)], - self.weight, - None, - self.variance_epsilon) - # if nelement == 0, means fused norm failed, go back to python implement. - if result.nelement != 0: - # We should copy this result to avoid by unknown reason on Arc GPUs. - result = result.clone() - return result + x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() + output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon) + if 1 < x_2d.size(0) <= 64: # may use XMX, need copy + output = output.clone() + return output.reshape(hidden_states.shape) + input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True)