parent
							
								
									170e3d65e0
								
							
						
					
					
						commit
						b03c859278
					
				
					 2 changed files with 21 additions and 0 deletions
				
			
		| 
						 | 
					@ -1517,6 +1517,13 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
        from ipex_llm.transformers.models.phi3 import model_forward_wrapper
 | 
					        from ipex_llm.transformers.models.phi3 import model_forward_wrapper
 | 
				
			||||||
        model_forward = model_forward_wrapper(module.Phi3Model.forward)
 | 
					        model_forward = model_forward_wrapper(module.Phi3Model.forward)
 | 
				
			||||||
        convert_forward(model, module.Phi3Model, model_forward)
 | 
					        convert_forward(model, module.Phi3Model, model_forward)
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
 | 
				
			||||||
 | 
					        convert_forward(
 | 
				
			||||||
 | 
					            model,
 | 
				
			||||||
 | 
					            module.Phi3RMSNorm,
 | 
				
			||||||
 | 
					            phi3_rms_norm_forward)
 | 
				
			||||||
 | 
					        # Empty cache after the first attention to run long context.
 | 
				
			||||||
 | 
					        model.model.layers[0].self_attn.register_forward_hook(empty_cache_post)
 | 
				
			||||||
    elif model.config.model_type == 'yuan':
 | 
					    elif model.config.model_type == 'yuan':
 | 
				
			||||||
        modeling_module_name = model.__class__.__module__
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
        module = importlib.import_module(modeling_module_name)
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -253,3 +253,17 @@ def model_forward_wrapper(origin_model_forward):
 | 
				
			||||||
            return_dict=return_dict,
 | 
					            return_dict=return_dict,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    return model_forward
 | 
					    return model_forward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def phi3_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
 | 
					    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
				
			||||||
 | 
					        import linear_q4_0
 | 
				
			||||||
 | 
					        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
				
			||||||
 | 
					        output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
 | 
				
			||||||
 | 
					        return output.reshape(hidden_states.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
 | 
					    hidden_states = hidden_states.to(torch.float32)
 | 
				
			||||||
 | 
					    variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
				
			||||||
 | 
					    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 | 
				
			||||||
 | 
					    return self.weight * hidden_states.to(input_dtype)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue