add fused rms norm for Yi and Qwen (#9640)
This commit is contained in:
		
							parent
							
								
									5636b0ba80
								
							
						
					
					
						commit
						0b6f29a7fc
					
				
					 1 changed files with 10 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -580,10 +580,14 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            modeling_module_name = model.__class__.__module__
 | 
			
		||||
            module = importlib.import_module(modeling_module_name)
 | 
			
		||||
            from bigdl.llm.transformers.models.qwen import qwen_attention_forward
 | 
			
		||||
            from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.QWenAttention,
 | 
			
		||||
                            qwen_attention_forward
 | 
			
		||||
                            )
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.RMSNorm,
 | 
			
		||||
                            chatglm_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "aquila":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			@ -606,6 +610,12 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        convert_forward(model,
 | 
			
		||||
                        module.MistralRMSNorm,
 | 
			
		||||
                        llama_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "Yi":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.YiRMSNorm,
 | 
			
		||||
                        llama_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "whisper" and lightweight_bmm:
 | 
			
		||||
        if platform.system().lower() == 'windows':
 | 
			
		||||
            from bigdl.llm.transformers.bmm import SafeBMM
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue