parent
170e3d65e0
commit
b03c859278
2 changed files with 21 additions and 0 deletions
|
|
@ -1517,6 +1517,13 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
from ipex_llm.transformers.models.phi3 import model_forward_wrapper
|
||||
model_forward = model_forward_wrapper(module.Phi3Model.forward)
|
||||
convert_forward(model, module.Phi3Model, model_forward)
|
||||
from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
|
||||
convert_forward(
|
||||
model,
|
||||
module.Phi3RMSNorm,
|
||||
phi3_rms_norm_forward)
|
||||
# Empty cache after the first attention to run long context.
|
||||
model.model.layers[0].self_attn.register_forward_hook(empty_cache_post)
|
||||
elif model.config.model_type == 'yuan':
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
|
|||
|
|
@ -253,3 +253,17 @@ def model_forward_wrapper(origin_model_forward):
|
|||
return_dict=return_dict,
|
||||
)
|
||||
return model_forward
|
||||
|
||||
|
||||
def phi3_rms_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
import linear_q4_0
|
||||
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
|
||||
output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
|
||||
return output.reshape(hidden_states.shape)
|
||||
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
|
|
|||
Loading…
Reference in a new issue