From 08458b4f7444578fc9059c5b9b5f7690d41dc53d Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 19 Apr 2024 13:57:48 +0800 Subject: [PATCH] remove rms norm copy (#10793) --- python/llm/src/ipex_llm/transformers/models/baichuan2.py | 2 -- python/llm/src/ipex_llm/transformers/models/chatglm2.py | 2 -- python/llm/src/ipex_llm/transformers/models/gemma.py | 2 -- python/llm/src/ipex_llm/transformers/models/llama.py | 2 -- 4 files changed, 8 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py index ea0a7718..51c16608 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py @@ -54,8 +54,6 @@ def baichuan_13b_rms_norm_forward(self, hidden_states): import linear_q4_0 x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() output = linear_q4_0.rms_norm(self.weight, x_2d, self.epsilon) - if 1 < x_2d.size(0) <= 64: # may use XMX, need copy - output = output.clone() return output.reshape(hidden_states.shape) input_dtype = hidden_states.dtype diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index ce4216dd..e6d4ae01 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -101,8 +101,6 @@ def chatglm_rms_norm_forward(self, hidden_states): import linear_q4_0 x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps) - if 1 < x_2d.size(0) <= 64: # may use XMX, need copy - output = output.clone() return output.reshape(hidden_states.shape) input_dtype = hidden_states.dtype diff --git a/python/llm/src/ipex_llm/transformers/models/gemma.py b/python/llm/src/ipex_llm/transformers/models/gemma.py index c99c51ec..d6ac66d2 100644 --- a/python/llm/src/ipex_llm/transformers/models/gemma.py +++ b/python/llm/src/ipex_llm/transformers/models/gemma.py @@ -82,8 +82,6 @@ def gemma_rms_norm_forward(self, hidden_states): import linear_q4_0 x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() output = linear_q4_0.rms_norm(self.weight + 1, x_2d, self.eps) - if 1 < x_2d.size(0) <= 64: # may use XMX, need copy - output = output.clone() return output.reshape(hidden_states.shape) input_dtype = hidden_states.dtype diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index 134556f9..8c1be44a 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -138,8 +138,6 @@ def llama_rms_norm_forward(self, hidden_states): import linear_q4_0 x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon) - if 1 < x_2d.size(0) <= 64: # may use XMX, need copy - output = output.clone() return output.reshape(hidden_states.shape) input_dtype = hidden_states.dtype