diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 2979b1bd..a6a319eb 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -995,8 +995,9 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
     elif model.config.model_type == "llama":
-        from ipex_llm.transformers.models.llama import merge_qkv
+        from ipex_llm.transformers.models.llama import merge_qkv, pre_compute_inv_freq
         model.apply(merge_qkv)
+        model.apply(pre_compute_inv_freq)
     elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 56a73290..ac101a55 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -119,6 +119,13 @@ def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, LlamaAttention)
 
 
+def pre_compute_inv_freq(module: torch.nn.Module):
+    if module.__class__.__name__ == "LlamaLinearScalingRotaryEmbedding":
+        if hasattr(module, "scaling_factor"):
+            module.register_buffer("inv_freq_scaled", None, persistent=False)
+            module.inv_freq_scaled = module.inv_freq / module.scaling_factor
+
+
 def llama_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -147,8 +154,12 @@ def llama_attention_forward(
         import xe_addons
         if hasattr(self, "rotary_emb"):
             # transformers < 4.46
-            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                           query_states, key_states)
+            if hasattr(self.rotary_emb, "inv_freq_scaled"):
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq_scaled, position_ids,
+                                               query_states, key_states)
+            else:
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                               query_states, key_states)
         else:
            # transformers >= 4.46
             cos, sin = position_embeddings
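
Note (not part of the patch): the reason `pre_compute_inv_freq` can simply cache `inv_freq / scaling_factor` is that linear-scaling RoPE divides the position ids by `scaling_factor` before forming the angle table, which is mathematically the same as scaling the inverse frequencies once up front; the fused rotary kernel can then be fed the unmodified position ids. A minimal standalone sketch of that equivalence is below; the dimension, base, and scaling factor are illustrative values, not taken from the patch.

import torch

# Illustrative parameters only.
dim, base, scaling_factor = 128, 10000.0, 2.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
position_ids = torch.arange(0, 16, dtype=torch.float32)

# Reference: linear scaling applied to the positions (as in
# transformers' LlamaLinearScalingRotaryEmbedding).
freqs_ref = torch.outer(position_ids / scaling_factor, inv_freq)

# Pre-computed variant: the scaling is folded into the frequencies once,
# so the rotary kernel receives the original position ids unchanged.
inv_freq_scaled = inv_freq / scaling_factor
freqs_pre = torch.outer(position_ids, inv_freq_scaled)

assert torch.allclose(freqs_ref, freqs_pre)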