parent
ac2dac857c
commit
b3df47486d
1 changed file with 3 additions and 1 deletion
|
|
@@ -1867,10 +1867,12 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
|
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
|
||||||
from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
|
from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
|
||||||
from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
|
from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
|
||||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
|
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention, \
|
||||||
|
Gemma2SdpaAttention
|
||||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
|
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
|
||||||
convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
|
convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
|
||||||
convert_forward(model, Gemma2Attention, gemma2_attention_forward)
|
convert_forward(model, Gemma2Attention, gemma2_attention_forward)
|
||||||
|
convert_forward(model, Gemma2SdpaAttention, gemma2_attention_forward)
|
||||||
convert_forward(model, Gemma2Model, gemma2_model_forward)
|
convert_forward(model, Gemma2Model, gemma2_model_forward)
|
||||||
convert_forward(model, Gemma2MLP, mlp_gelu_forward)
|
convert_forward(model, Gemma2MLP, mlp_gelu_forward)
|
||||||
elif model.config.model_type == "Yi":
|
elif model.config.model_type == "Yi":
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue