Refactor some functions to ipex_llm.transformers.models.common (#13091)

* add quantize_linear & linear_forward
* add moe_group_topk
* rotary_two_with_cache_inplaced
* fix code style
* update related models

This commit is contained in:
parent 73198d5b80
commit 2f78afcd2a

3 changed files with 64 additions and 10 deletions
@@ -17,6 +17,7 @@
 import math
 import torch
 from typing import List
+from ipex_llm.utils.common import invalidInputError


 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
@@ -303,3 +304,56 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         )
     attn_output = attn_output.to(dtype)  # workaround ipex 2.1's bug
     return attn_output
+
+
+def linear_forward(x: torch.Tensor, weight: torch.Tensor, qtype: int, out_features: int):
+    if weight.device.type == "xpu":
+        new_shape = x.shape[:-1] + (out_features,)
+        x = x.to(weight.device, dtype=torch.float16)
+        x_2d = x.contiguous().view(-1, x.shape[-1])
+        import xe_linear
+        x = xe_linear.forward_new(x_2d, weight, qtype, out_features)
+        x = x.view(new_shape)
+        return x
+    else:
+        invalidInputError(False,
+                          "Unsupported device type: only support weight on xpu device.")
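linear_forward flattens all leading dimensions of x into a single 2D matmul on the low-bit weight via xe_linear.forward_new, casting the activations to float16 on the XPU first, then restores the original shape with out_features as the last dimension; weights on any other device are rejected through invalidInputError. A combined usage sketch follows quantize_linear below.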
+
+
+def quantize_linear(weight: torch.Tensor, in_features: int, precision: str):
+    from ipex_llm.transformers.low_bit_linear import FP4Params
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+
+    invalidInputError(precision in ggml_tensor_qtype.keys(),
+                      f"{precision} is not supported, "
+                      f"only {ggml_tensor_qtype.keys()} are supported now.")
+    qtype = ggml_tensor_qtype[precision]
+    paramsLowBit = FP4Params(data=weight.data,
+                             requires_grad=False,
+                             quantized=False,
+                             _shape=None,
+                             convert_shape_only=False,
+                             qtype=qtype,
+                             in_features=in_features,
+                             enable_scale_search=False).to("cpu")
+    return paramsLowBit, qtype
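A minimal usage sketch for the two helpers above (an illustration, not part of this commit), assuming an XPU build of ipex-llm with the xe_linear extension available; the 4096 layer size and the "sym_int4" precision are placeholder choices:

    import torch
    from ipex_llm.transformers.models.common import quantize_linear, linear_forward

    linear = torch.nn.Linear(4096, 4096, bias=False)

    # pack the fp32 weight into the requested low-bit format,
    # then move the packed weight to the XPU for xe_linear
    q_weight, qtype = quantize_linear(linear.weight, linear.in_features, "sym_int4")
    q_weight = q_weight.to("xpu")

    # linear_forward handles the fp16 cast and the 2D reshape internally
    x = torch.randn(1, 16, 4096)
    out = linear_forward(x, q_weight, qtype, linear.out_features)  # shape (1, 16, 4096)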
+
+
+def moe_group_topk(scores: torch.Tensor, e_score_correction_bias: torch.Tensor,
+                   n_group: int, topk_group: int, top_k: int, norm_topk_prob: float,
+                   routed_scaling_factor: float):
+    import xe_addons
+    topk_idx, topk_weight = xe_addons.moe_group_topk(
+        scores, e_score_correction_bias,
+        n_group, 2, topk_group, top_k,
+        top_k > 1 and norm_topk_prob, 1e-20, routed_scaling_factor
+    )
+    return topk_idx, topk_weight
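The helper pins the kernel-specific constants (top-2 scoring per expert group, the 1e-20 normalization epsilon, and the top_k > 1 guard on norm_topk_prob), so call sites such as fuse_gate_forward below only pass the gate's own configuration. As a reading aid, here is a rough pure-PyTorch sketch of the grouped top-k routing this call corresponds to, mirroring the wrapper's signature and following the DeepSeek-V3 reference gate; it is an assumption for clarity, not the xe_addons implementation:

    import torch

    def moe_group_topk_ref(scores, bias, n_group, topk_group, top_k,
                           norm_topk_prob, routed_scaling_factor, eps=1e-20):
        n_tokens, n_experts = scores.shape
        corrected = scores + bias                               # bias-corrected scores drive expert choice
        group_scores = (corrected.view(n_tokens, n_group, -1)
                        .topk(2, dim=-1).values.sum(dim=-1))    # score each group by its top-2 experts
        group_idx = group_scores.topk(topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        expert_mask = group_mask.repeat_interleave(n_experts // n_group, dim=-1)
        masked = corrected.masked_fill(expert_mask == 0, 0.0)   # drop experts outside the kept groups
        topk_idx = masked.topk(top_k, dim=-1).indices
        topk_weight = scores.gather(1, topk_idx)                # routing weights use the original scores
        if top_k > 1 and norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + eps)
        return topk_idx, topk_weight * routed_scaling_factor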
+
+
+def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
+                                   cos: torch.Tensor, sin: torch.Tensor,
+                                   half_layout: bool):
+    import xe_addons
+    xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
+                                             cos, sin, half_layout)
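A rough CPU sketch of the in-place operation, assuming half_layout=True selects the rotate-half RoPE layout (both call sites below pass True) and that cos/sin are already broadcastable against the query/key tensors; an illustration, not the xe_addons kernel:

    import torch

    def rotary_two_ref(q: torch.Tensor, k: torch.Tensor,
                       cos: torch.Tensor, sin: torch.Tensor) -> None:
        # rotate-half RoPE applied to both tensors in place
        def rotate_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)
        q.copy_(q * cos + rotate_half(q) * sin)
        k.copy_(k * cos + rotate_half(k) * sin)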
@@ -228,9 +228,9 @@ def deepseek_attention_forward(
             [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
             dim=-1
         )
-        import xe_addons
         cos, sin = position_embeddings
-        xe_addons.rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
+        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+        rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
                                        key_states[:, :, :, self.qk_nope_head_dim:],
                                        cos, sin, True)
     else:
@@ -279,11 +279,11 @@ def fuse_gate_forward(self, x: torch.Tensor):
         )
         scores = logits.sigmoid()

-        import xe_addons
-        topk_idx, topk_weight = xe_addons.moe_group_topk(
+        from ipex_llm.transformers.models.common import moe_group_topk
+        topk_idx, topk_weight = moe_group_topk(
             scores, self.e_score_correction_bias,
-            self.n_group, 2, self.topk_group, self.top_k,
-            self.top_k > 1 and self.norm_topk_prob, 1e-20, self.routed_scaling_factor
+            self.n_group, self.topk_group, self.top_k,
+            self.norm_topk_prob, self.routed_scaling_factor
         )
     else:
         topk_idx, topk_weight = self(x)
@@ -98,9 +98,9 @@ def glm_attention_forward(

     cos, sin = position_embeddings
     if query_states.device.type == "xpu":
-        import xe_addons
         make_cache_contiguous_inplaced(cos, sin)
-        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
+        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+        rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
     else:
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)