diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index fc0766ae..37d1fba1 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1517,6 +1517,13 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.phi3 import model_forward_wrapper
         model_forward = model_forward_wrapper(module.Phi3Model.forward)
         convert_forward(model, module.Phi3Model, model_forward)
+        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
+        convert_forward(
+            model,
+            module.Phi3RMSNorm,
+            phi3_rms_norm_forward)
+        # Empty cache after the first attention to run long context.
+        model.model.layers[0].self_attn.register_forward_hook(empty_cache_post)
     elif model.config.model_type == 'yuan':
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index dba3a4e4..d8715571 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -253,3 +253,23 @@ def model_forward_wrapper(origin_model_forward):
             return_dict=return_dict,
         )
     return model_forward
+
+
+def phi3_rms_norm_forward(self, hidden_states):
+    # Fast path: on "xpu" devices when autograd is not needed, delegate to the
+    # rms_norm op of the native linear_q4_0 extension on a flattened 2-D view.
+    if hidden_states.device.type == "xpu" and not (
+        self.training and hidden_states.requires_grad
+    ):
+        import linear_q4_0
+        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
+        output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
+        return output.reshape(hidden_states.shape)
+
+    # Fallback: reference RMSNorm — accumulate the variance in float32 for
+    # numerical stability, then cast the result back to the input dtype.
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+    return self.weight * hidden_states.to(input_dtype)