Fix deepseek coder with linear rope type support on GPU (#12709)
* Fix deepseek coder with linear rope type
* Style fix
* Move to optimize_pre
* Small fix
* Small fix
* Small fix to not affect other cases
* Style fixes
* Update function name
* Small fix
* Small fix
* Small fix
* Fix for low transformers version first
* Style fix
* Small fix
This commit is contained in:
parent 36bf3d8e29
commit 9d65dcd7ef

2 changed files with 15 additions and 3 deletions
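Background for the change: with a rope_scaling config of type "linear" (as DeepSeek Coder uses), transformers builds LlamaLinearScalingRotaryEmbedding, which divides the position index by scaling_factor before taking the outer product with inv_freq. Dividing the positions is mathematically the same as dividing inv_freq itself, so the factor can be folded into a precomputed inverse-frequency tensor once, and the fused GPU rotary kernel can keep receiving a plain inv_freq-shaped input. A minimal sketch of that equivalence (toy sizes, not code from the patch):

import torch

# Linear RoPE scaling divides positions by scaling_factor before building the
# angle table. Since outer(pos / factor, inv_freq) == outer(pos, inv_freq / factor),
# the factor can be folded into the inverse frequencies once, up front.
dim, base, scaling_factor = 128, 10000.0, 4.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
positions = torch.arange(16).float()

angles_scaled_pos = torch.outer(positions / scaling_factor, inv_freq)
angles_scaled_freq = torch.outer(positions, inv_freq / scaling_factor)
assert torch.allclose(angles_scaled_pos, angles_scaled_freq)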
@@ -995,8 +995,9 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
     elif model.config.model_type == "llama":
-        from ipex_llm.transformers.models.llama import merge_qkv
+        from ipex_llm.transformers.models.llama import merge_qkv, pre_compute_inv_freq
         model.apply(merge_qkv)
+        model.apply(pre_compute_inv_freq)
     elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
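The hook is wired in through torch.nn.Module.apply, which visits every submodule recursively, so pre_compute_inv_freq reaches each rotary-embedding instance nested inside the attention layers. A minimal sketch of that traversal with stand-in classes (not the real model code):

import torch

class FakeRotary(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("inv_freq", torch.ones(4))
        self.scaling_factor = 4.0

class FakeAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rotary_emb = FakeRotary()

model = torch.nn.Sequential(FakeAttention(), FakeAttention())

def tag_scaled(module):
    # applied to every submodule by Module.apply, just like the real pre-pass
    if hasattr(module, "scaling_factor") and hasattr(module, "inv_freq"):
        module.inv_freq_scaled = module.inv_freq / module.scaling_factor

model.apply(tag_scaled)
print(model[0].rotary_emb.inv_freq_scaled)  # tensor([0.25, 0.25, 0.25, 0.25])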
@@ -119,6 +119,13 @@ def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, LlamaAttention)
 
 
+def pre_compute_inv_freq(module: torch.nn.Module):
+    if module.__class__.__name__ == "LlamaLinearScalingRotaryEmbedding":
+        if hasattr(module, "scaling_factor"):
+            module.register_buffer("inv_freq_scaled", None, persistent=False)
+            module.inv_freq_scaled = module.inv_freq / module.scaling_factor
+
+
 def llama_attention_forward(
     self,
     hidden_states: torch.Tensor,
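pre_compute_inv_freq registers inv_freq_scaled as a non-persistent buffer, so it follows the module across .to(device) and dtype casts without leaking into saved checkpoints. A hedged sketch of that behaviour with a stand-in rotary class (the real class comes from transformers < 4.46):

import torch

class FakeLinearScalingRotary(torch.nn.Module):
    def __init__(self, dim=8, base=10000.0, scaling_factor=4.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.scaling_factor = scaling_factor

def pre_compute_inv_freq(module):
    # same shape as the patched helper, applied here to the stand-in class
    if hasattr(module, "scaling_factor"):
        module.register_buffer("inv_freq_scaled", None, persistent=False)
        module.inv_freq_scaled = module.inv_freq / module.scaling_factor

rope = FakeLinearScalingRotary()
pre_compute_inv_freq(rope)
print("inv_freq_scaled" in rope.state_dict())  # False: non-persistent buffer
print(rope.half().inv_freq_scaled.dtype)       # torch.float16: follows the module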
@@ -147,8 +154,12 @@ def llama_attention_forward(
         import xe_addons
         if hasattr(self, "rotary_emb"):
             # transformers < 4.46
-            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                           query_states, key_states)
+            if hasattr(self.rotary_emb, "inv_freq_scaled"):
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq_scaled, position_ids,
+                                               query_states, key_states)
+            else:
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                               query_states, key_states)
         else:
             # transformers >= 4.46
             cos, sin = position_embeddings
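On the GPU path for transformers < 4.46, the forward now prefers the precomputed inv_freq_scaled when the pre-pass created it and falls back to the raw inv_freq otherwise, so models without linear scaling are unaffected. For reference, a pure-PyTorch sketch of the rotate-half RoPE that the fused xe_addons.rotary_half_inplaced call is assumed to compute (an illustrative assumption; the real kernel updates the XPU tensors in place):

import torch

def rotate_half(x):
    # standard Llama layout: move the second half into the first half's slot, negated
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_half_reference(inv_freq, position_ids, query_states, key_states):
    # returns new tensors rather than rewriting query_states/key_states in place
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()[:, None, :, :]   # broadcast over the head dimension
    sin = emb.sin()[:, None, :, :]
    q = query_states * cos + rotate_half(query_states) * sin
    k = key_states * cos + rotate_half(key_states) * sin
    return q, k

# toy shapes: [batch, heads, seq, head_dim]
q = torch.randn(1, 2, 5, 8)
k = torch.randn(1, 2, 5, 8)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
pos = torch.arange(5)[None, :]
q_rot, k_rot = rotary_half_reference(inv_freq, pos, q, k)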