Fix deepseek coder with linear rope type support on GPU (#12709)

* Fix deepseek coder with linear rope type

* Style fix

* Move to optimize_pre

* Small fix

* Small fix

* Small fix to not affect other cases

* Style fixes

* Update function name

* Small fix

* Small fix

* Small fix

* Fix for low transformers version first

* Style fix

* Small fix

Commit 9d65dcd7ef (parent 36bf3d8e29)
Author: Yuwen Hu
Date: 2025-01-15 21:12:34 +08:00, committed by GitHub
2 changed files with 15 additions and 3 deletions
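
For context on the change itself (not part of the commit): linear RoPE scaling, the rope type the commit title ties to DeepSeek Coder, stretches the rotary positions by scaling_factor. Dividing the positions by scaling_factor produces exactly the same rotation angles as dividing inv_freq by scaling_factor once up front, which is what this fix precomputes so the existing GPU rotary kernel can keep consuming raw position ids. A minimal PyTorch sketch of that identity, with toy sizes:

import torch

# Toy sizes; the identity holds for any head_dim, base and scaling_factor.
head_dim, base, scaling_factor = 8, 10000.0, 4.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(16).float()

# Linear RoPE scaling as usually written: stretch the positions.
angles_scaled_pos = torch.outer(position_ids / scaling_factor, inv_freq)

# Equivalent form this fix precomputes: scale the frequencies once.
angles_scaled_freq = torch.outer(position_ids, inv_freq / scaling_factor)

assert torch.allclose(angles_scaled_pos, angles_scaled_freq)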


@@ -995,8 +995,9 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
     elif model.config.model_type == "llama":
-        from ipex_llm.transformers.models.llama import merge_qkv
+        from ipex_llm.transformers.models.llama import merge_qkv, pre_compute_inv_freq
         model.apply(merge_qkv)
+        model.apply(pre_compute_inv_freq)
     elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
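
Because pre_compute_inv_freq is wired in through model.apply, it runs once per submodule and only acts on the rotary-embedding class it recognizes, so other model types and non-linear rope setups are left untouched (the "Small fix to not affect other cases" above). An illustrative standalone sketch with a toy module name, not the actual transformers class:

import torch

class ToyLinearScalingRotaryEmbedding(torch.nn.Module):
    def __init__(self, head_dim=8, base=10000.0, scaling_factor=4.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

def toy_pre_compute_inv_freq(module: torch.nn.Module):
    # Only touches the one class it recognizes; every other module is a no-op.
    if module.__class__.__name__ == "ToyLinearScalingRotaryEmbedding" \
            and hasattr(module, "scaling_factor"):
        module.register_buffer("inv_freq_scaled",
                               module.inv_freq / module.scaling_factor,
                               persistent=False)

model = torch.nn.Sequential(ToyLinearScalingRotaryEmbedding(), torch.nn.Linear(4, 4))
model.apply(toy_pre_compute_inv_freq)   # visits every submodule recursively
assert hasattr(model[0], "inv_freq_scaled")
assert not hasattr(model[1], "inv_freq_scaled")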


@@ -119,6 +119,13 @@ def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, LlamaAttention)
 
 
+def pre_compute_inv_freq(module: torch.nn.Module):
+    if module.__class__.__name__ == "LlamaLinearScalingRotaryEmbedding":
+        if hasattr(module, "scaling_factor"):
+            module.register_buffer("inv_freq_scaled", None, persistent=False)
+            module.inv_freq_scaled = module.inv_freq / module.scaling_factor
+
+
 def llama_attention_forward(
     self,
     hidden_states: torch.Tensor,
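
One detail in the hunk above: the new buffer is registered with persistent=False, so it follows the module across device moves but never appears in state_dict, leaving checkpoints unchanged. A tiny standalone check of that PyTorch behavior, using a toy module:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("inv_freq_scaled", torch.ones(4), persistent=False)
        self.register_buffer("regular", torch.ones(4))  # persistent by default

m = Toy()
assert "inv_freq_scaled" not in m.state_dict()
assert "regular" in m.state_dict()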
@@ -147,6 +154,10 @@ def llama_attention_forward(
     import xe_addons
     if hasattr(self, "rotary_emb"):
         # transformers < 4.46
-        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                       query_states, key_states)
+        if hasattr(self.rotary_emb, "inv_freq_scaled"):
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq_scaled, position_ids,
+                                           query_states, key_states)
+        else:
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                           query_states, key_states)
     else:
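
At runtime the forward now only swaps which frequency tensor it hands to the fused kernel: inv_freq_scaled when the pre-pass created it, plain inv_freq otherwise, so models without linear scaling take the unchanged path. Below is a rough pure-PyTorch reference of a "rotary half" application (an assumption about what the fused xe_addons.rotary_half_inplaced computes, not its code), showing that only the inv_freq argument needs to differ between the two branches:

import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def rotary_half_reference(q, k, inv_freq, position_ids):
    # angles: (seq_len, head_dim // 2), duplicated to cover the full head_dim
    angles = torch.outer(position_ids.float(), inv_freq)
    emb = torch.cat((angles, angles), dim=-1)
    cos, sin = emb.cos(), emb.sin()
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

head_dim, seq_len, scaling_factor = 8, 16, 4.0
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
q, k = torch.randn(seq_len, head_dim), torch.randn(seq_len, head_dim)
pos = torch.arange(seq_len)

# Pre-scaled frequencies with raw positions == plain frequencies with scaled positions.
out_scaled_freq = rotary_half_reference(q, k, inv_freq / scaling_factor, pos)
out_scaled_pos = rotary_half_reference(q, k, inv_freq, pos.float() / scaling_factor)
assert all(torch.allclose(a, b) for a, b in zip(out_scaled_freq, out_scaled_pos))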