add mlp for gemma2 (#11678)
parent 1da1f1dd0e
commit c02003925b

3 changed files with 27 additions and 2 deletions
@@ -1513,11 +1513,13 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
+        from ipex_llm.transformers.models.gemma2 import gemma2_mlp_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
-        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
+        convert_forward(model, Gemma2MLP, gemma2_mlp_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
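The new convert_forward(model, Gemma2MLP, gemma2_mlp_forward) call follows the same pattern as the existing Gemma2 conversions: every Gemma2MLP submodule has its forward rebound to the optimized implementation. Below is a minimal, illustrative sketch of that patching pattern; the helper name patch_forward is hypothetical, and ipex_llm's real convert_forward may differ in its details.

# Illustrative sketch only, not ipex_llm's actual convert_forward.
# It mirrors the call pattern above: walk the model and rebind `forward`
# on every submodule that is an instance of the target class.
import types
import torch


def patch_forward(model: torch.nn.Module, target_class, new_forward):
    for module in model.modules():
        if isinstance(module, target_class):
            # Bind the plain function as a method so `self` is passed implicitly.
            module.forward = types.MethodType(new_forward, module)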
@@ -41,3 +41,21 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
         ])
         module.qkv_proj = qkv_proj
         del module.q_proj, module.k_proj, module.v_proj
+
+
+def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import mlp_fusion_check
+    x_2d = x.view(-1, x.size(-1))
+    qtype = getattr(module.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, module.training):
+        import xe_linear
+        x_2d = x_2d.contiguous()
+        return module.down_proj(
+            xe_linear.mlp_forward_xpu(
+                x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
+                x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
+                act, qtype
+            )
+        )
+    else:
+        return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
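For reference, the else branch above is the standard gated-MLP computation, down_proj(act_fn(gate_proj(x)) * up_proj(x)); when mlp_fusion_check passes, the same gate/up intermediate is produced by a single xe_linear.mlp_forward_xpu kernel over the low-bit weights before down_proj is applied. A plain-PyTorch sketch of that fallback math follows, using a hypothetical ToyGatedMLP module and assuming a tanh-approximated GELU, which Gemma-2 configs typically use.

# Eager-mode sketch of the math fuse_mlp_base falls back to; the fused XPU
# kernel produces the same act(gate_proj(x)) * up_proj(x) intermediate.
# ToyGatedMLP is a hypothetical stand-in, not an ipex_llm or transformers class.
import torch


class ToyGatedMLP(torch.nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = torch.nn.GELU(approximate="tanh")  # assumed Gemma-2 activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same expression as fuse_mlp_base's non-fused branch.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


x = torch.randn(2, 8, 16)
print(ToyGatedMLP(16, 64)(x).shape)  # torch.Size([2, 8, 16])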
@@ -34,7 +34,8 @@
 import torch
 
 from typing import Optional, Tuple
-from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
+from ipex_llm.transformers.models.utils import GELU
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2Attention
@@ -177,3 +178,7 @@ def gemma2_attention_forward(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def gemma2_mlp_forward(self, x: torch.Tensor):
+    return fuse_mlp_base(self, GELU, x)
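End to end, nothing changes for callers: loading a Gemma-2 checkpoint through ipex_llm runs _optimize_post, which now also routes every Gemma2MLP through gemma2_mlp_forward. A typical usage sketch is below; the model id, prompt, and xpu device are placeholders for your environment.

# Usage sketch: the MLP conversion is applied automatically during loading.
# Model path and device are placeholders; an Intel GPU + IPEX setup is assumed.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "google/gemma-2-9b-it"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.half().to("xpu")

tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))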