support fp8 in xetla (#10555)
* support fp8 in xetla
* change name
* adjust model file
* support convert back to cpu
* factor
* fix bug
* fix style
This commit is contained in:

parent 7c43ac0164
commit 5a1f446d3c

6 changed files with 209 additions and 84 deletions
@@ -75,10 +75,12 @@ IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
 Q2_K = ggml_tensor_qtype["q2_k"]
 IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]


+# For sym_int4
 # The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
 #
 # The returning weight is row major and packs two rows at a stride of 16//2.
-# 16 is the tile_size_y used in mm_int4, so that we can do something like
+# 16 is the tile_size_y used in mm_xetla, so that we can do something like
 # new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
 #
 # A more complex packing strategy is to permute the weight so that the
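For reference, here is a minimal standalone sketch of the nibble unpacking the comment above describes (illustrative only, not part of this commit; the toy byte values are made up):

    import torch

    # Each packed uint8 byte holds two 4-bit weights; splitting it mirrors
    # new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
    packed = torch.tensor([0xAB, 0x3C], dtype=torch.uint8)  # toy packed bytes
    low = packed & 0x0F      # -> tensor([11, 12], dtype=torch.uint8)
    high = packed >> 4       # -> tensor([10,  3], dtype=torch.uint8)
    new_weight_tile = torch.cat([low, high])
    print(new_weight_tile)   # tensor([11, 12, 10,  3], dtype=torch.uint8)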
@@ -87,9 +89,10 @@ IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]
 #
 # Note this format cannot be used directly in IPEX-LLM's mm_int4, which expects
 # row major but packing two consecutive columns.


-def q4_0_xpu_transpose(ggml_weight, weight_shape):
+#
+# For fp8, just remove the scales (which are all ones) and transpose
+def ggml_xpu_to_ipex_llm_xetla(ggml_weight, weight_shape, qtype):
+    if qtype == ggml_tensor_qtype["sym_int4"]:
+        from ipex_llm.transformers.low_bit_linear import get_block_size
+        Q4_0 = get_block_size("sym_int4")

@@ -124,6 +127,59 @@ def q4_0_xpu_transpose(ggml_weight, weight_shape):
         zeros_bytes = zeros.view(torch.uint8).view(-1)

         weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
+    elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
+        n, k = weight_shape
+        weight = ggml_weight[:n*k].view(n, k).transpose(0, 1).contiguous()
+    else:
+        invalidInputError(False, f"Unsupported qtype {qtype}")
     return weight
+
+
+def ipex_llm_xetla_to_ggml_xpu(xetla_weight, weight_shape, qtype):
+    from ipex_llm.transformers.low_bit_linear import get_block_size
+    if qtype == ggml_tensor_qtype["sym_int4"]:
+        Q4_0 = get_block_size("sym_int4")
+        n, k = weight_shape
+        weight_size = n*k//2
+        zeros_size = n*k//Q4_0//2
+        scales_size = n*k//Q4_0 * 2
+        xetla_weight_only = xetla_weight[:weight_size]
+        scales_start = weight_size + zeros_size
+        xetla_scales = xetla_weight[scales_start:scales_start+scales_size]
+
+        qweight = xetla_weight_only.clone()
+        scales = xetla_scales.view(torch.float16).clone()
+
+        qweight_0 = qweight & 0x0F
+        qweight_1 = qweight >> 4
+        qweight_0 = qweight_0.reshape(-1, 8, n)
+        qweight_1 = qweight_1.reshape(-1, 8, n)
+        qweight = torch.cat([qweight_0, qweight_1], dim=1)
+
+        qweight = qweight.reshape(k, n).transpose(0, 1).contiguous().reshape(n, k//Q4_0,
+                                                                             2, Q4_0//2)
+        qweight = qweight.bitwise_left_shift(
+            torch.tensor([0, 4], dtype=torch.uint8,
+                         device=xetla_weight_only.device).reshape(1, 1, 2, 1))
+
+        qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
+        qweight = qweight.reshape(n, k//2)
+
+        scales = scales.reshape(k//Q4_0, n).transpose(0, 1).contiguous()
+
+        qweight_bytes = qweight.view(torch.uint8).view(-1)
+        scales_bytes = scales.view(torch.uint8).view(-1)
+        weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
+    elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
+        Q8_0 = get_block_size("fp8_e5m2")
+        n, k = weight_shape
+        qweight = xetla_weight[:n*k].transpose(0, 1).contiguous()
+        scales = torch.ones([n*k//Q8_0], dtype=torch.float, device=xetla_weight.device)
+        qweight_bytes = qweight.view(torch.uint8).view(-1)
+        scales_bytes = scales.view(torch.uint8).view(-1)
+        weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
+    else:
+        invalidInputError(False, f"Unsupported qtype {qtype}")
+    return weight

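As a rough illustration of the fp8_e5m2 branches above, the following sketch (not from the commit; uint8 bytes stand in for fp8 values, and the n, k, Q8_0 sizes are made up) shows the transpose to the xetla layout and the conversion back, which re-appends all-ones scales:

    import torch

    n, k, Q8_0 = 2, 4, 4                                  # hypothetical sizes
    ggml_weight = torch.arange(n * k, dtype=torch.uint8)  # row-major n x k fp8 payload

    # ggml xpu -> xetla: drop the (all-ones) scales, store the payload transposed
    xetla_weight = ggml_weight[:n * k].view(n, k).transpose(0, 1).contiguous()

    # xetla -> ggml xpu: transpose back and re-append all-ones scales as raw bytes
    qweight = xetla_weight.transpose(0, 1).contiguous()
    scales = torch.ones([n * k // Q8_0], dtype=torch.float)
    restored = torch.concat([qweight.view(-1), scales.view(torch.uint8).view(-1)], dim=0)

    assert torch.equal(restored[:n * k].view(n, k), ggml_weight.view(n, k))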
@@ -373,7 +429,7 @@ class FP4Params(torch.nn.Parameter):
                                                      reduce(mul, self._shape, 1),
                                                      self.qtype)
             if self.enable_xetla:
-                self.data = q4_0_xpu_transpose(self.data, self._shape)
+                self.data = ggml_xpu_to_ipex_llm_xetla(self.data, self._shape, self.qtype)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -397,9 +453,12 @@ class FP4Params(torch.nn.Parameter):
                                   qtype=self.qtype,
                                   enable_xetla=self.enable_xetla)
             if self.enable_xetla:
-                invalidInputError(False,
-                                  "xetla is not supported on CPUs but got enable_xetla=True")
-            new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
+                ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
+                                                      new_param._shape,
+                                                      new_param.qtype)
+            else:
+                ggml_xpu = new_param.data
+            new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
                                                           reduce(mul, new_param._shape, 1),
                                                           new_param.qtype)
             return new_param
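In outline, the CPU path above now routes xetla-formatted weights back through the ggml xpu layout before the existing xpu-to-cpu conversion runs. A condensed, self-contained sketch of that control flow (the two helpers here are stand-in stubs, and the wrapper name is hypothetical; only the branching mirrors the diff):

    from functools import reduce
    from operator import mul

    # Stand-ins for the ipex-llm helpers referenced above (illustration only).
    def ipex_llm_xetla_to_ggml_xpu(data, shape, qtype):
        return data

    def ggml_q_format_convet_xpu2cpu(data, numel, qtype):
        return data

    def xpu_weight_to_cpu(new_param):
        # xetla-formatted data is first converted back to the ggml xpu layout,
        # then the ordinary xpu -> cpu format conversion is applied to the result.
        if new_param.enable_xetla:
            ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
                                                  new_param._shape,
                                                  new_param.qtype)
        else:
            ggml_xpu = new_param.data
        return ggml_q_format_convet_xpu2cpu(ggml_xpu,
                                            reduce(mul, new_param._shape, 1),
                                            new_param.qtype)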
@@ -610,7 +669,7 @@ class LowBitLinear(nn.Linear):
                                                      input_seq_size)
             elif self.enable_xetla:
                 x_2d = x_2d.half()
-                result = linear_q4_0.mm_int4(x_2d, self.weight.data)
+                result = linear_q4_0.mm_xetla(x_2d, self.weight.data, self.qtype)
             else:
                 # inference path
                 # current workaround to reduce first token latency of fp32 input
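For context, the xetla branch above now passes the qtype through to the kernel instead of assuming int4. A hedged sketch of how such a call could be wrapped (illustrative; `linear_q4_0` is the compiled extension the diff imports, so this only runs on a matching ipex-llm XPU build, and shape handling is simplified):

    def xetla_matmul(x_2d, weight_data, qtype):
        # xetla kernels consume fp16 activations; weight_data is assumed to be in
        # the layout produced by ggml_xpu_to_ipex_llm_xetla above.
        import linear_q4_0  # compiled extension referenced throughout this commit
        return linear_q4_0.mm_xetla(x_2d.half(), weight_data, qtype)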
@@ -273,7 +273,8 @@ def llama_decoder_forward(
     return outputs


-def fuse_qkv_weight(q_proj, k_proj, v_proj):
+def fuse_qkv_weight_xetla(q_proj, k_proj, v_proj, qtype):
+    if qtype == SYM_INT4:
         weight_size = q_proj.out_len * q_proj.in_len // 2
         zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64
         zeros_end = weight_size + zeros_size
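The 64 in zeros_size above matches the sym_int4 block size used by the xetla buffer layout, so the buffer components can be sanity-checked with a little arithmetic. A small sketch with made-up sizes (not part of the commit; it assumes the block size Q4_0 is 64, which is what the `// 2 // 64` implies):

    # Byte budget for one sym_int4 weight in the xetla layout, assuming a
    # hypothetical 4096 x 4096 projection and a block size Q4_0 of 64.
    out_len, in_len, Q4_0 = 4096, 4096, 64
    weight_size = out_len * in_len // 2         # two 4-bit weights per byte  -> 8388608 bytes
    zeros_size = out_len * in_len // 2 // Q4_0  # one packed 4-bit zero/block ->  131072 bytes
    scales_size = out_len * in_len // Q4_0 * 2  # one fp16 scale (2 B)/block  ->  524288 bytes
    print(weight_size, zeros_size, scales_size)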
@@ -296,10 +297,22 @@ def fuse_qkv_weight(q_proj, k_proj, v_proj):
         k_proj.weight.data = torch.empty(0)
         v_proj.weight.data = torch.empty(0)
         return torch.cat([qweight, qzeros, qscales], dim=0)
+    elif qtype == FP8E5:
+        result = torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=1).contiguous()
+        q_proj.weight.data = torch.empty(0)
+        k_proj.weight.data = torch.empty(0)
+        v_proj.weight.data = torch.empty(0)
+        return result
+    else:
+        invalidInputError(False, f"Unsupported qtype {qtype}")


-def should_use_mm_int4_qkv(self, device):
-    return device.type == "xpu" and self.q_proj.qtype == SYM_INT4 and self.q_proj.enable_xetla
+def should_use_xetla_mm_qkv(self, device):
+    full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
+    supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
+    supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
+    enable_xetla = self.q_proj.enable_xetla
+    return device.type == "xpu" and enable_xetla and supported_qtype


 def llama_attention_forward_4_31(
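To make the FP8E5 branch above concrete, here is a toy sketch (not from the commit; uint8 stands in for fp8 bytes and the sizes are invented). The xetla fp8 weights are laid out as (in_len, out_len), so concatenating along dim=1 yields one fused weight whose output columns are [q | k | v] and can later be split by each projection's out_len:

    import torch

    in_len, q_out, k_out, v_out = 8, 4, 2, 2   # hypothetical projection sizes
    wq = torch.zeros(in_len, q_out, dtype=torch.uint8)
    wk = torch.ones(in_len, k_out, dtype=torch.uint8)
    wv = torch.full((in_len, v_out), 2, dtype=torch.uint8)

    fused = torch.cat([wq, wk, wv], dim=1).contiguous()
    assert fused.shape == (in_len, q_out + k_out + v_out)
    assert (fused[:, q_out:q_out + k_out] == 1).all()   # the k block follows the q block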
@@ -352,6 +365,7 @@ def llama_attention_forward_4_31_quantized(
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
                           and enough_kv_room and bsz * q_len == 1)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

     # single batch decoding fast path
     # forward_qkv takes will perform QKV projection, rotary position embedding
@@ -553,16 +567,21 @@ def llama_attention_forward_4_31_original(
                     query_states, key_states, value_states
                 )
             else:
-                if should_use_mm_int4_qkv(self, device):
+                if should_use_xetla_mm_qkv(self, device):
                     if not hasattr(self, "qkv_proj_qweight"):
-                        self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
-                                                                self.k_proj,
-                                                                self.v_proj)
+                        self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                                      self.k_proj,
+                                                                      self.v_proj,
+                                                                      self.q_proj.weight.qtype,)
                     import linear_q4_0
-                    qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
-                    query_states = qkv_states[:, :, :hidden_size]
-                    key_states = qkv_states[:, :, hidden_size:2*hidden_size]
-                    value_states = qkv_states[:, :, 2*hidden_size:]
+                    q_out_len = self.q_proj.out_len
+                    k_out_len = self.k_proj.out_len
+                    v_out_len = self.v_proj.out_len
+                    qkv_states = linear_q4_0.mm_xetla(hidden_states, self.qkv_proj_qweight,
+                                                      self.q_proj.weight.qtype)
+                    query_states = qkv_states[:, :, :q_out_len]
+                    key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+                    value_states = qkv_states[:, :, q_out_len + k_out_len:]
                 else:
                     query_states = self.q_proj(hidden_states)
                     key_states = self.k_proj(hidden_states)
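A small sketch of the slicing pattern used above (illustrative; the sizes are made up): after the fused matmul, the last dimension is ordered [q_out_len | k_out_len | v_out_len], which also covers grouped-query models where the k/v projections are narrower than q:

    import torch

    q_out_len, k_out_len, v_out_len = 8, 4, 4   # hypothetical projection widths
    qkv_states = torch.randn(1, 3, q_out_len + k_out_len + v_out_len)  # (bsz, seq, q+k+v)

    query_states = qkv_states[:, :, :q_out_len]
    key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
    value_states = qkv_states[:, :, q_out_len + k_out_len:]

    assert query_states.shape[-1] == q_out_len
    assert value_states.shape[-1] == v_out_len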
@@ -932,6 +951,7 @@ def llama_attention_forward_4_36_quantized(
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
                           and enough_kv_room and bsz * q_len == 1)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1196,16 +1216,22 @@ def llama_attention_forward_4_36_original(
                     query_states, key_states, value_states
                 )
             else:
-                if should_use_mm_int4_qkv(self, device):
+                if should_use_xetla_mm_qkv(self, device):
                     if not hasattr(self, "qkv_proj_qweight"):
-                        self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
-                                                                self.k_proj,
-                                                                self.v_proj)
+                        self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                                      self.k_proj,
+                                                                      self.v_proj,
+                                                                      self.q_proj.weight.qtype,)
                     import linear_q4_0
-                    qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
-                    query_states = qkv_states[:, :, :hidden_size]
-                    key_states = qkv_states[:, :, hidden_size:2*hidden_size]
-                    value_states = qkv_states[:, :, 2*hidden_size:]
+                    q_out_len = self.q_proj.out_len
+                    k_out_len = self.k_proj.out_len
+                    v_out_len = self.v_proj.out_len
+                    qkv_states = linear_q4_0.mm_xetla(hidden_states,
+                                                      self.qkv_proj_qweight,
+                                                      self.q_proj.weight.qtype)
+                    query_states = qkv_states[:, :, :q_out_len]
+                    key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+                    value_states = qkv_states[:, :, q_out_len + k_out_len:]
                 else:
                     query_states = self.q_proj(hidden_states)
                     key_states = self.k_proj(hidden_states)
@@ -54,6 +54,8 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
+from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
+from ipex_llm.transformers.models.llama import fuse_qkv_weight_xetla
 try:
     from transformers.cache_utils import Cache
 except ImportError:
@@ -84,7 +86,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):

 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
     return llama_decoding_fast_path_qtype_check(proj) and \
-        use_fuse_rope and enough_kv_room and bs == 1
+        use_fuse_rope and enough_kv_room and bs == 1 and \
+        not proj.enable_xetla


 def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
@@ -382,7 +385,6 @@ def mistral_attention_forward_original(
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
@@ -402,6 +404,24 @@ def mistral_attention_forward_original(
                                                                          self.head_dim)
         kv_seq_len += 1
     else:

+        if should_use_xetla_mm_qkv(self, device):
+            if not hasattr(self, "qkv_proj_qweight"):
+                self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                              self.k_proj,
+                                                              self.v_proj,
+                                                              self.q_proj.qtype)
+            import linear_q4_0
+            q_out_len = self.q_proj.out_len
+            k_out_len = self.k_proj.out_len
+            v_out_len = self.v_proj.out_len
+            qkv_states = linear_q4_0.mm_xetla(hidden_states,
+                                              self.qkv_proj_qweight,
+                                              self.q_proj.qtype)
+            query_states = qkv_states[:, :, :q_out_len]
+            key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+            value_states = qkv_states[:, :, q_out_len + k_out_len:]
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
@@ -768,6 +788,23 @@ def mistral_attention_forward_4_36_original(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states

     else:
+        if should_use_xetla_mm_qkv(self, device):
+            if not hasattr(self, "qkv_proj_qweight"):
+                self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                              self.k_proj,
+                                                              self.v_proj,
+                                                              self.q_proj.qtype)
+            import linear_q4_0
+            q_out_len = self.q_proj.out_len
+            k_out_len = self.k_proj.out_len
+            v_out_len = self.v_proj.out_len
+            qkv_states = linear_q4_0.mm_xetla(hidden_states,
+                                              self.qkv_proj_qweight,
+                                              self.q_proj.qtype)
+            query_states = qkv_states[:, :, :q_out_len]
+            key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+            value_states = qkv_states[:, :, q_out_len + k_out_len:]
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
@@ -142,6 +142,7 @@ def qwen_attention_forward_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
     qtype_check = decoding_fast_path_qtype_check(self.q_proj)
     decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k, cache_v = layer_past[0], layer_past[1]
@@ -91,6 +91,7 @@ def qwen_attention_forward_vl(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
     qtype_check = decoding_fast_path_qtype_check(self.q_proj)
     decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k, cache_v = layer_past[0], layer_past[1]
@@ -43,7 +43,8 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256

 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
     return decoding_fast_path_qtype_check(proj) and \
-        use_fuse_rope and enough_kv_room and bs == 1
+        use_fuse_rope and enough_kv_room and bs == 1 \
+        and not proj.enable_xetla


 def should_use_fuse_rope(self, hidden_states, position_ids):
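Finally, the `not proj.enable_xetla` clauses added across the model files all implement the same gating: a projection that uses the xetla path opts out of the single-token decoding fast path. A tiny self-contained sketch of that predicate (stand-in objects and a stubbed qtype check, not ipex-llm code):

    from types import SimpleNamespace

    def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
        qtype_ok = True  # stand-in for the decoding_fast_path_qtype_check(proj) call
        return (qtype_ok and use_fuse_rope and enough_kv_room and bs == 1
                and not proj.enable_xetla)

    assert use_decoding_fast_path(SimpleNamespace(enable_xetla=False), True, True, 1) is True
    assert use_decoding_fast_path(SimpleNamespace(enable_xetla=True), True, True, 1) is False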