diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 27337c94..517572ab 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -17,6 +17,7 @@
 import math
 import torch
 from typing import List
+from ipex_llm.utils.common import invalidInputError
 
 
 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
@@ -303,3 +304,56 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         )
         attn_output = attn_output.to(dtype)    # workaround ipex 2.1's bug
     return attn_output
+
+
+def linear_forward(x: torch.Tensor, weight: torch.Tensor, qtype: int, out_features: int):
+    if weight.device.type == "xpu":
+        new_shape = x.shape[:-1] + (out_features,)
+        x = x.to(weight.device, dtype=torch.float16)
+        x_2d = x.contiguous().view(-1, x.shape[-1])
+        import xe_linear
+        x = xe_linear.forward_new(x_2d, weight, qtype, out_features)
+        x = x.view(new_shape)
+        return x
+    else:
+        invalidInputError(False,
+                          "Unsupported device type: only support weight on xpu device.")
+
+
+def quantize_linear(weight: torch.Tensor, in_features: int, precision: str):
+    from ipex_llm.transformers.low_bit_linear import FP4Params
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+
+    invalidInputError(precision in ggml_tensor_qtype.keys(),
+                      f"{precision} is not supported, "
+                      f"only {ggml_tensor_qtype.keys()} are supported now.")
+    qtype = ggml_tensor_qtype[precision]
+    paramsLowBit = FP4Params(data=weight.data,
+                             requires_grad=False,
+                             quantized=False,
+                             _shape=None,
+                             convert_shape_only=False,
+                             qtype=qtype,
+                             in_features=in_features,
+                             enable_scale_search=False).to("cpu")
+    return paramsLowBit, qtype
+
+
+def moe_group_topk(scores: torch.Tensor, e_score_correction_bias: torch.Tensor,
+                   n_group: int, topk_group: int, top_k: int, norm_topk_prob: float,
+                   routed_scaling_factor: float):
+    import xe_addons
+    topk_idx, topk_weight = xe_addons.moe_group_topk(
+        scores, e_score_correction_bias,
+        n_group, 2, topk_group, top_k,
+        top_k > 1 and norm_topk_prob, 1e-20, routed_scaling_factor
+    )
+    return topk_idx, topk_weight
+
+
+def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
+                                   cos: torch.Tensor, sin: torch.Tensor,
+                                   half_layout: bool):
+    import xe_addons
+    xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
+                                             cos, sin, half_layout)
diff --git a/python/llm/src/ipex_llm/transformers/models/deepseek.py b/python/llm/src/ipex_llm/transformers/models/deepseek.py
index c5edf60e..e4d1a033 100644
--- a/python/llm/src/ipex_llm/transformers/models/deepseek.py
+++ b/python/llm/src/ipex_llm/transformers/models/deepseek.py
@@ -228,11 +228,11 @@ def deepseek_attention_forward(
             [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])], dim=-1
         )
 
-        import xe_addons
         cos, sin = position_embeddings
-        xe_addons.rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
-                                                 key_states[:, :, :, self.qk_nope_head_dim:],
-                                                 cos, sin, True)
+        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+        rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
+                                       key_states[:, :, :, self.qk_nope_head_dim:],
+                                       cos, sin, True)
     else:
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
@@ -279,11 +279,11 @@ def fuse_gate_forward(self, x: torch.Tensor):
         )
         scores = logits.sigmoid()
 
-        import xe_addons
-        topk_idx, topk_weight = xe_addons.moe_group_topk(
+        from ipex_llm.transformers.models.common import moe_group_topk
+        topk_idx, topk_weight = moe_group_topk(
             scores, self.e_score_correction_bias,
-            self.n_group, 2, self.topk_group, self.top_k,
-            self.top_k > 1 and self.norm_topk_prob, 1e-20, self.routed_scaling_factor
+            self.n_group, self.topk_group, self.top_k,
+            self.norm_topk_prob, self.routed_scaling_factor
         )
     else:
         topk_idx, topk_weight = self(x)
diff --git a/python/llm/src/ipex_llm/transformers/models/glm.py b/python/llm/src/ipex_llm/transformers/models/glm.py
index 39326567..72aba50b 100644
--- a/python/llm/src/ipex_llm/transformers/models/glm.py
+++ b/python/llm/src/ipex_llm/transformers/models/glm.py
@@ -98,9 +98,9 @@ def glm_attention_forward(
 
     cos, sin = position_embeddings
     if query_states.device.type == "xpu":
-        import xe_addons
         make_cache_contiguous_inplaced(cos, sin)
-        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
+        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+        rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
     else:
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
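
Illustrative usage sketch (not part of the patch): one way the new quantize_linear and linear_forward helpers in ipex_llm.transformers.models.common could be combined to run a low-bit linear layer on an Intel XPU. The "sym_int4" precision, the tensor shapes, and the .to("xpu") transfer of the low-bit parameter are assumptions made for the example; the xe_linear extension must be available in the runtime.

import torch
from ipex_llm.transformers.models.common import quantize_linear, linear_forward

linear = torch.nn.Linear(4096, 4096, bias=False)

# Quantize the weight into a low-bit ggml tensor; "sym_int4" is one precision
# listed in ggml_tensor_qtype.
params_low_bit, qtype = quantize_linear(linear.weight, linear.in_features, "sym_int4")

# Assumption: moving the low-bit parameter to the XPU packs it for xe_linear.
weight_xpu = params_low_bit.to("xpu")

x = torch.randn(1, 32, linear.in_features, dtype=torch.float16, device="xpu")
# linear_forward dispatches to xe_linear.forward_new and restores the leading
# batch/sequence dimensions of x in the output.
out = linear_forward(x, weight_xpu.data, qtype, linear.out_features)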