Refactor some functions to ipex_llm.transformers.models.common (#13091)

* add quantize_linear & linear_forward
* add moe_group_topk
* rotary_two_with_cache_inplaced
* fix code style
* update related models
parent 73198d5b80
commit 2f78afcd2a

3 changed files with 64 additions and 10 deletions

@@ -17,6 +17,7 @@
 import math
 import torch
 from typing import List
+from ipex_llm.utils.common import invalidInputError


 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
@@ -303,3 +304,56 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
             )
         attn_output = attn_output.to(dtype)    # workaround ipex 2.1's bug
         return attn_output
+
+
+def linear_forward(x: torch.Tensor, weight: torch.Tensor, qtype: int, out_features: int):
+    if weight.device.type == "xpu":
+        new_shape = x.shape[:-1] + (out_features,)
+        x = x.to(weight.device, dtype=torch.float16)
+        x_2d = x.contiguous().view(-1, x.shape[-1])
+        import xe_linear
+        x = xe_linear.forward_new(x_2d, weight, qtype, out_features)
+        x = x.view(new_shape)
+        return x
+    else:
+        invalidInputError(False,
+                          "Unsupported device type: only support weight on xpu device.")
+
+
+def quantize_linear(weight: torch.Tensor, in_features: int, precision: str):
+    from ipex_llm.transformers.low_bit_linear import FP4Params
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+
+    invalidInputError(precision in ggml_tensor_qtype.keys(),
+                      f"{precision} is not supported, "
+                      f"only {ggml_tensor_qtype.keys()} are supported now.")
+    qtype = ggml_tensor_qtype[precision]
+    paramsLowBit = FP4Params(data=weight.data,
+                             requires_grad=False,
+                             quantized=False,
+                             _shape=None,
+                             convert_shape_only=False,
+                             qtype=qtype,
+                             in_features=in_features,
+                             enable_scale_search=False).to("cpu")
+    return paramsLowBit, qtype
+
+
+def moe_group_topk(scores: torch.Tensor, e_score_correction_bias: torch.Tensor,
+                   n_group: int, topk_group: int, top_k: int, norm_topk_prob: float,
+                   routed_scaling_factor: float):
+    import xe_addons
+    topk_idx, topk_weight = xe_addons.moe_group_topk(
+        scores, e_score_correction_bias,
+        n_group, 2, topk_group, top_k,
+        top_k > 1 and norm_topk_prob, 1e-20, routed_scaling_factor
+    )
+    return topk_idx, topk_weight
+
+
+def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
+                                   cos: torch.Tensor, sin: torch.Tensor,
+                                   half_layout: bool):
+    import xe_addons
+    xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
+                                             cos, sin, half_layout)
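For context (not part of the diff): a minimal sketch of how the two new low-bit helpers might be combined. It assumes an ipex-llm build with Intel XPU support, the "sym_int4" precision string, example layer sizes, and that the returned low-bit parameter is packed when moved to the XPU (as with other FP4Params-based weights in ipex_llm.transformers.low_bit_linear); none of this is specified by the commit itself.

import torch
from ipex_llm.transformers.models.common import quantize_linear, linear_forward

# plain fp32 linear layer whose weight we want to run through the low-bit path
lin = torch.nn.Linear(4096, 11008, bias=False)

# quantize_linear returns an ipex-llm low-bit parameter plus its ggml qtype id;
# "sym_int4" is an example precision from ggml_tensor_qtype
low_bit_weight, qtype = quantize_linear(lin.weight, lin.in_features, "sym_int4")

# assumption: moving the low-bit parameter to the XPU packs/quantizes it;
# real call sites may pass its .data to the kernel instead of the parameter itself
low_bit_weight = low_bit_weight.to("xpu")

# linear_forward casts x to fp16, flattens it to 2-D, calls xe_linear.forward_new
# and reshapes the result back to (..., out_features)
x = torch.randn(1, 16, 4096)
out = linear_forward(x, low_bit_weight, qtype, lin.out_features)   # shape (1, 16, 11008)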
@@ -228,11 +228,11 @@ def deepseek_attention_forward(
             [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
             dim=-1
         )
-        import xe_addons
         cos, sin = position_embeddings
-        xe_addons.rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
-                                                 key_states[:, :, :, self.qk_nope_head_dim:],
-                                                 cos, sin, True)
+        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+        rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
+                                       key_states[:, :, :, self.qk_nope_head_dim:],
+                                       cos, sin, True)
     else:
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
@@ -279,11 +279,11 @@ def fuse_gate_forward(self, x: torch.Tensor):
         )
         scores = logits.sigmoid()

-        import xe_addons
-        topk_idx, topk_weight = xe_addons.moe_group_topk(
+        from ipex_llm.transformers.models.common import moe_group_topk
+        topk_idx, topk_weight = moe_group_topk(
             scores, self.e_score_correction_bias,
-            self.n_group, 2, self.topk_group, self.top_k,
-            self.top_k > 1 and self.norm_topk_prob, 1e-20, self.routed_scaling_factor
+            self.n_group, self.topk_group, self.top_k,
+            self.norm_topk_prob, self.routed_scaling_factor
         )
     else:
         topk_idx, topk_weight = self(x)
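Note on this call-site change: the new common.moe_group_topk wrapper hides the xe_addons-specific arguments, supplying the group-selection factor 2 and the 1e-20 normalisation epsilon internally and folding the top_k > 1 and norm_topk_prob condition into the helper, so the gate code only passes its own hyperparameters.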
@@ -98,9 +98,9 @@ def glm_attention_forward(

     cos, sin = position_embeddings
     if query_states.device.type == "xpu":
-        import xe_addons
         make_cache_contiguous_inplaced(cos, sin)
-        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
+        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
+        rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
     else:
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
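Both the DeepSeek and GLM call sites keep their behaviour: the new rotary_two_with_cache_inplaced in models.common simply forwards to xe_addons.rotary_two_with_cache_inplaced, updating query_states and key_states in place, so only the xe_addons import moves out of the model files.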