LLM: fix RMSNorm optimization of Baichuan2-13B/Baichuan-13B (#9204)
* fix rmsnorm of baichuan2-13B * update baichuan1-13B too * fix style
This commit is contained in:
		
							parent
							
								
									efcda3892f
								
							
						
					
					
						commit
						09815f7064
					
				
					 2 changed files with 29 additions and 8 deletions
				
			
		| 
						 | 
				
			
			@ -275,19 +275,23 @@ def optimize(model):
 | 
			
		|||
                            module.Attention,
 | 
			
		||||
                            baichuan_attention_forward_7b
 | 
			
		||||
                            )
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.RMSNorm,
 | 
			
		||||
                            llama_rms_norm_forward)
 | 
			
		||||
        elif model.config.hidden_size == 5120:
 | 
			
		||||
            # baichuan2-13B
 | 
			
		||||
            modeling_module_name = model.__class__.__module__
 | 
			
		||||
            module = importlib.import_module(modeling_module_name)
 | 
			
		||||
            from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
 | 
			
		||||
            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.BaichuanAttention,
 | 
			
		||||
                            baichuan_attention_forward_13b
 | 
			
		||||
                            )
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.RMSNorm,
 | 
			
		||||
                        llama_rms_norm_forward)
 | 
			
		||||
 | 
			
		||||
            # baichuan2-13B's RMSNorm is a little different
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.RMSNorm,
 | 
			
		||||
                            baichuan_13b_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "baichuan":
 | 
			
		||||
        # baichuan1
 | 
			
		||||
        if model.config.hidden_size == 4096:
 | 
			
		||||
| 
						 | 
				
			
			@ -299,19 +303,23 @@ def optimize(model):
 | 
			
		|||
                            module.Attention,
 | 
			
		||||
                            baichuan_attention_forward_7b
 | 
			
		||||
                            )
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.RMSNorm,
 | 
			
		||||
                            llama_rms_norm_forward)
 | 
			
		||||
        elif model.config.hidden_size == 5120:
 | 
			
		||||
            # baichuan-13B
 | 
			
		||||
            modeling_module_name = model.__class__.__module__
 | 
			
		||||
            module = importlib.import_module(modeling_module_name)
 | 
			
		||||
            from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_13b
 | 
			
		||||
            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.BaichuanAttention,
 | 
			
		||||
                            baichuan_attention_forward_13b
 | 
			
		||||
                            )
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.RMSNorm,
 | 
			
		||||
                        llama_rms_norm_forward)
 | 
			
		||||
 | 
			
		||||
            # baichuan-13B's RMSNorm is a little different
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.RMSNorm,
 | 
			
		||||
                            baichuan_13b_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "gpt_neox":
 | 
			
		||||
        from bigdl.llm.transformers.models.gptneox import gptneox_attention_forward
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,6 +45,19 @@ except ImportError:
 | 
			
		|||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_13b_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                         [self.weight.size(0)], self.weight)
 | 
			
		||||
    else:
 | 
			
		||||
        input_dtype = hidden_states.dtype
 | 
			
		||||
        hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
			
		||||
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
 | 
			
		||||
        return self.weight * hidden_states.to(input_dtype)
 | 
			
		||||
    return hidden_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_attention_forward_7b(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue