Add phi3RMS (#10988)

* phi3RMS
Author: Zhao Changmin (committed by GitHub)
Date: 2024-05-14 15:16:27 +08:00
Parent: 170e3d65e0
Commit: b03c859278
2 changed files with 21 additions and 0 deletions

python/llm/src/ipex_llm/transformers/convert.py

@@ -1517,6 +1517,13 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.phi3 import model_forward_wrapper
         model_forward = model_forward_wrapper(module.Phi3Model.forward)
         convert_forward(model, module.Phi3Model, model_forward)
+        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
+        convert_forward(
+            model,
+            module.Phi3RMSNorm,
+            phi3_rms_norm_forward)
+        # Empty cache after the first attention to run long context.
+        model.model.layers[0].self_attn.register_forward_hook(empty_cache_post)
     elif model.config.model_type == 'yuan':
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
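For context, `convert_forward(model, cls, fn)` is ipex-llm's patching helper: it rebinds the `forward` of every module matching `cls` so the optimized function runs in its place. Below is a minimal sketch of that monkey-patching pattern, not the real helper (which walks the module tree); `ToyNorm` and `fast_forward` are hypothetical stand-ins.

import torch

class ToyNorm(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))

    def forward(self, x):
        # Original eager path (placeholder).
        return self.weight * x

def fast_forward(self, x):
    # Drop-in replacement with the same signature as the original forward.
    return self.weight * x * 2.0

# Rebinding the class attribute swaps the implementation for every instance.
ToyNorm.forward = fast_forward
m = ToyNorm(4)
print(m(torch.ones(4)))  # now dispatches to fast_forward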

python/llm/src/ipex_llm/transformers/models/phi3.py

@@ -253,3 +253,17 @@ def model_forward_wrapper(origin_model_forward):
             return_dict=return_dict,
         )
     return model_forward
+
+
+def phi3_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        import linear_q4_0
+        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
+        output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
+        return output.reshape(hidden_states.shape)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
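The else branch is the standard float32 RMSNorm computation (as in Hugging Face's Phi3RMSNorm), while the xpu branch routes the same math through ipex-llm's native `linear_q4_0.rms_norm` kernel on inference paths. A quick self-contained sketch of the eager formula follows, with made-up shapes and epsilon; it is not part of the commit.

import torch

def rms_norm_ref(weight, x, eps):
    # Reference RMSNorm: scale each hidden vector by 1/sqrt(mean(x^2) + eps),
    # computed in float32, then cast back and apply the learned weight.
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)

x = torch.randn(2, 5, 8, dtype=torch.float16)  # hypothetical (batch, seq, hidden)
w = torch.ones(8, dtype=torch.float16)
out = rms_norm_ref(w, x, eps=1e-6)
# With unit weights, each hidden vector ends up with roughly unit mean square.
print(out.float().pow(2).mean(-1))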