Fix deepseek coder with linear rope type support on GPU (#12709)
* Fix deepseek coder with linear rope type * Style fix * Move to optimize_pre * Small fix * Small fix * Small fix to not affect other cases * Style fixes * Update function name * Small fix * Small fix * Small fix * Fix for low transformers version first * Style fix * Small fix
This commit is contained in:
parent
36bf3d8e29
commit
9d65dcd7ef
2 changed files with 15 additions and 3 deletions
|
|
@ -995,8 +995,9 @@ def _optimize_pre(model, qtype=None):
|
|||
from ipex_llm.transformers.models.gemma2 import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
elif model.config.model_type == "llama":
|
||||
from ipex_llm.transformers.models.llama import merge_qkv
|
||||
from ipex_llm.transformers.models.llama import merge_qkv, pre_compute_inv_freq
|
||||
model.apply(merge_qkv)
|
||||
model.apply(pre_compute_inv_freq)
|
||||
elif model.config.model_type == "mllama":
|
||||
from ipex_llm.transformers.models.mllama import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
|
|
|
|||
|
|
@ -119,6 +119,13 @@ def merge_qkv(module: torch.nn.Module):
|
|||
merge_qkv_base(module, LlamaAttention)
|
||||
|
||||
|
||||
def pre_compute_inv_freq(module: torch.nn.Module):
|
||||
if module.__class__.__name__ == "LlamaLinearScalingRotaryEmbedding":
|
||||
if hasattr(module, "scaling_factor"):
|
||||
module.register_buffer("inv_freq_scaled", None, persistent=False)
|
||||
module.inv_freq_scaled = module.inv_freq / module.scaling_factor
|
||||
|
||||
|
||||
def llama_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -147,6 +154,10 @@ def llama_attention_forward(
|
|||
import xe_addons
|
||||
if hasattr(self, "rotary_emb"):
|
||||
# transformers < 4.46
|
||||
if hasattr(self.rotary_emb, "inv_freq_scaled"):
|
||||
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq_scaled, position_ids,
|
||||
query_states, key_states)
|
||||
else:
|
||||
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||
query_states, key_states)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue