Use ipex fused rms norm for llama (#9081)
* also apply rmsnorm
* fix cpu
parent fb883100e7
commit 0cd8f1c79c

2 changed files with 18 additions and 0 deletions
@@ -135,6 +135,7 @@ def convert_forward(m, target_m, new_forward):
 def optimize(model):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
+    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from transformers.modeling_utils import PreTrainedModel
 
     # All huggingface format models are inherited from `PreTrainedModel`
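For context, the hunk header above shows that `optimize` sits next to a helper with the signature `convert_forward(m, target_m, new_forward)`. Its body is not part of this diff, so the sketch below is only a plausible reading of how such a patcher typically works (walk the module tree and rebind `forward` on every instance of the target class), not the repository's actual implementation.

import types

def convert_forward(m, target_m, new_forward):
    # Hypothetical sketch: recursively replace the forward method of every
    # sub-module that is an instance of `target_m` with `new_forward`.
    for sub_m in m.children():
        if isinstance(sub_m, target_m):
            # Bind the plain function as a method so `self` resolves to the module.
            sub_m.forward = types.MethodType(new_forward, sub_m)
        convert_forward(sub_m, target_m, new_forward)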
@@ -149,6 +150,10 @@ def optimize(model):
             model,
             transformers.models.llama.modeling_llama.LlamaAttention,
             llama_attention_forward_4_31,)
+        convert_forward(
+            model,
+            transformers.models.llama.modeling_llama.LlamaRMSNorm,
+            llama_rms_norm_forward,)
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
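The new `convert_forward` call for `LlamaRMSNorm` lands in the branch that already registers `llama_attention_forward_4_31`, and the `else: # todo implement 4.28.0 ~ 4.30.2` suggests the branch is gated on the installed transformers version. A minimal sketch of that kind of gate, assuming a `>= 4.31.0` cut-off (the actual check is outside this diff) and assuming `model`, `convert_forward`, `llama_attention_forward_4_31`, and `llama_rms_norm_forward` are in scope as in the hunks above:

import transformers
from packaging import version

trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.31.0"):
    # Patch both the attention and the RMSNorm forwards for transformers >= 4.31.0.
    convert_forward(model,
                    transformers.models.llama.modeling_llama.LlamaAttention,
                    llama_attention_forward_4_31)
    convert_forward(model,
                    transformers.models.llama.modeling_llama.LlamaRMSNorm,
                    llama_rms_norm_forward)
else:
    # todo implement 4.28.0 ~ 4.30.2
    pass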
@@ -57,6 +57,19 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
+def llama_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu":
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def llama_attention_forward_4_31(
     self,
     hidden_states: torch.Tensor,
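The `else` branch above is the plain-PyTorch RMSNorm used off XPU (the "fix cpu" part of the commit message) and mirrors the stock `LlamaRMSNorm.forward` from transformers. A quick sanity check of that fallback path against the reference module, runnable on CPU without IPEX, might look like the following; calling `llama_rms_norm_forward(norm, x)` with the module passed explicitly as `self` is just a test convenience, since in practice `convert_forward` rebinds the method on the module.

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward

# Reference module from transformers; hidden size and epsilon chosen arbitrarily.
norm = LlamaRMSNorm(hidden_size=64, eps=1e-6)
x = torch.randn(2, 8, 64)

# CPU tensor -> the patched forward takes the non-XPU branch above.
expected = norm(x)
patched = llama_rms_norm_forward(norm, x)

assert torch.allclose(expected, patched, atol=1e-6)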