support fp8 in xetla (#10555)

* support fp8 in xetla

* change name

* adjust model file

* support convert back to cpu

* factor

* fix bug

* fix style
Yang Wang authored on 2024-04-08 13:22:09 -07:00 · committed by GitHub
parent 7c43ac0164
commit 5a1f446d3c
6 changed files with 209 additions and 84 deletions


@@ -75,10 +75,12 @@ IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
 Q2_K = ggml_tensor_qtype["q2_k"]
 IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]
+# For sym_int4
 # The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
 #
 # The returning weight is row major and packs two rows at a stride of 16//2.
-# 16 is the tile_size_y used in mm_int4, so that we can do something like
+# 16 is the tile_size_y used in mm_xetla, so that we can do something like
 # new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
 #
 # A more complex packing strategy is to permute the weight so that the
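
A quick pure-torch sketch of the nibble split the comment above refers to (illustrative bytes only, not the xetla kernel):

import torch

# Each uint8 packs two 4-bit weights; a mask and a shift recover them.
packed = torch.tensor([0x21, 0x43], dtype=torch.uint8)   # hypothetical packed bytes
low = packed & 0x0F      # -> tensor([1, 3])
high = packed >> 4       # -> tensor([2, 4])
new_weight_tile = torch.cat([low, high])   # concat(weight_tile & 0x0F, weight_tile >> 4)
print(new_weight_tile.tolist())            # [1, 3, 2, 4]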
@@ -87,43 +89,97 @@ IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]
 #
 # Note this format cannot be used directly in IPEX-LLM's mm_int4, which expects
 # row major but packing two consecutive columns.
+#
+# For fp8, just remove the scales (which are all ones) and transpose
+def ggml_xpu_to_ipex_llm_xetla(ggml_weight, weight_shape, qtype):
+    if qtype == ggml_tensor_qtype["sym_int4"]:
+        from ipex_llm.transformers.low_bit_linear import get_block_size
+        Q4_0 = get_block_size("sym_int4")
+        n, k = weight_shape
+        ggml_weight_only = ggml_weight[:n*k//2]
+        ggml_scales = ggml_weight[n*k//2:]
+        qweight = ggml_weight_only.clone()
+        scales = ggml_scales.view(torch.float16).clone()
+
+        qweight_0 = qweight & 0x0F
+        qweight_1 = qweight >> 4
+        qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
+        qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
+        qweight = torch.cat([qweight_0, qweight_1], dim=-1)
+        qweight = qweight.reshape(n, k//16, 2, 8)
+        qweight = qweight.bitwise_left_shift(
+            torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
+        qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
+        qweight = qweight.reshape(n, k//2)
+        qweight = qweight.transpose(0, 1).contiguous()
+        scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
+        # 119 is the value of 0x77
+        zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
+
+        qweight_bytes = qweight.view(torch.uint8).view(-1)
+        scales_bytes = scales.view(torch.uint8).view(-1)
+        zeros_bytes = zeros.view(torch.uint8).view(-1)
+        weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
+    elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
+        n, k = weight_shape
+        weight = ggml_weight[:n*k].view(n, k).transpose(0, 1).contiguous()
+    else:
+        invalidInputError(False, f"Unsupported qtype {qtype}")
+    return weight


-def q4_0_xpu_transpose(ggml_weight, weight_shape):
-    from ipex_llm.transformers.low_bit_linear import get_block_size
-    Q4_0 = get_block_size("sym_int4")
-    n, k = weight_shape
-    ggml_weight_only = ggml_weight[:n*k//2]
-    ggml_scales = ggml_weight[n*k//2:]
-    qweight = ggml_weight_only.clone()
-    scales = ggml_scales.view(torch.float16).clone()
-    qweight_0 = qweight & 0x0F
-    qweight_1 = qweight >> 4
-    qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
-    qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
-    qweight = torch.cat([qweight_0, qweight_1], dim=-1)
-    qweight = qweight.reshape(n, k//16, 2, 8)
-    qweight = qweight.bitwise_left_shift(
-        torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
-    qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
-    qweight = qweight.reshape(n, k//2)
-    qweight = qweight.transpose(0, 1).contiguous()
-    scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
-    # 119 is the value of 0x77
-    zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
-    qweight_bytes = qweight.view(torch.uint8).view(-1)
-    scales_bytes = scales.view(torch.uint8).view(-1)
-    zeros_bytes = zeros.view(torch.uint8).view(-1)
-    weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
-    return weight
+def ipex_llm_xetla_to_ggml_xpu(xetla_weight, weight_shape, qtype):
+    from ipex_llm.transformers.low_bit_linear import get_block_size
+    if qtype == ggml_tensor_qtype["sym_int4"]:
+        Q4_0 = get_block_size("sym_int4")
+        n, k = weight_shape
+        weight_size = n*k//2
+        zeros_size = n*k//Q4_0//2
+        scales_size = n*k//Q4_0 * 2
+        xetla_weight_only = xetla_weight[:weight_size]
+        scales_start = weight_size + zeros_size
+        xetla_scales = xetla_weight[scales_start:scales_start+scales_size]
+
+        qweight = xetla_weight_only.clone()
+        scales = xetla_scales.view(torch.float16).clone()
+
+        qweight_0 = qweight & 0x0F
+        qweight_1 = qweight >> 4
+        qweight_0 = qweight_0.reshape(-1, 8, n)
+        qweight_1 = qweight_1.reshape(-1, 8, n)
+        qweight = torch.cat([qweight_0, qweight_1], dim=1)
+        qweight = qweight.reshape(k, n).transpose(0, 1).contiguous().reshape(n, k//Q4_0,
+                                                                             2, Q4_0//2)
+        qweight = qweight.bitwise_left_shift(
+            torch.tensor([0, 4], dtype=torch.uint8,
+                         device=xetla_weight_only.device).reshape(1, 1, 2, 1))
+        qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
+        qweight = qweight.reshape(n, k//2)
+
+        scales = scales.reshape(k//Q4_0, n).transpose(0, 1).contiguous()
+
+        qweight_bytes = qweight.view(torch.uint8).view(-1)
+        scales_bytes = scales.view(torch.uint8).view(-1)
+        weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
+    elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
+        Q8_0 = get_block_size("fp8_e5m2")
+        n, k = weight_shape
+        qweight = xetla_weight[:n*k].transpose(0, 1).contiguous()
+        scales = torch.ones([n*k//Q8_0], dtype=torch.float, device=xetla_weight.device)
+        qweight_bytes = qweight.view(torch.uint8).view(-1)
+        scales_bytes = scales.view(torch.uint8).view(-1)
+        weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
+    else:
+        invalidInputError(False, f"Unsupported qtype {qtype}")
+    return weight
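
The fp8_e5m2 branch above is purely a layout change; a minimal pure-torch sketch of the round trip, with uint8 standing in for the fp8 bytes and made-up sizes:

import torch

n, k = 4, 8                                           # hypothetical output/input dims
ggml_weight = torch.arange(n * k, dtype=torch.uint8)  # n*k weight bytes, row major (ggml layout)
# ggml -> xetla: keep only the n*k weight bytes and transpose to k x n
xetla_weight = ggml_weight.view(n, k).transpose(0, 1).contiguous()
# xetla -> ggml: transpose back and flatten; the real code then re-appends all-ones scales
restored = xetla_weight.transpose(0, 1).contiguous().view(-1)
assert torch.equal(restored, ggml_weight)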
@@ -373,7 +429,7 @@ class FP4Params(torch.nn.Parameter):
                                            reduce(mul, self._shape, 1),
                                            self.qtype)
            if self.enable_xetla:
-               self.data = q4_0_xpu_transpose(self.data, self._shape)
+               self.data = ggml_xpu_to_ipex_llm_xetla(self.data, self._shape, self.qtype)
            new_param = FP4Params(super().to(device=device,
                                             dtype=dtype,
                                             non_blocking=non_blocking),
@@ -397,9 +453,12 @@ class FP4Params(torch.nn.Parameter):
                                  qtype=self.qtype,
                                  enable_xetla=self.enable_xetla)
            if self.enable_xetla:
-               invalidInputError(False,
-                                 "xetla is not supported on CPUs but got enable_xetla=True")
-           new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
+               ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
+                                                     new_param._shape,
+                                                     new_param.qtype)
+           else:
+               ggml_xpu = new_param.data
+           new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
                                                          reduce(mul, new_param._shape, 1),
                                                          new_param.qtype)
            return new_param
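
For sym_int4, the xetla buffer that ipex_llm_xetla_to_ggml_xpu unpacks for the CPU conversion above is laid out as [qweight | zeros | scales]; a quick sanity check of the offsets, assuming block size Q4_0 = 64 and a made-up shape:

n, k = 4096, 4096                  # hypothetical weight shape
Q4_0 = 64                          # sym_int4 block size assumed here
weight_size = n * k // 2           # two 4-bit weights per byte
zeros_size = n * k // Q4_0 // 2    # one 4-bit zero point per block (filled with 0x77 above)
scales_size = n * k // Q4_0 * 2    # one fp16 scale (2 bytes) per block
scales_start = weight_size + zeros_size
print(weight_size, scales_start, scales_start + scales_size)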
@@ -610,7 +669,7 @@ class LowBitLinear(nn.Linear):
                                              input_seq_size)
            elif self.enable_xetla:
                x_2d = x_2d.half()
-               result = linear_q4_0.mm_int4(x_2d, self.weight.data)
+               result = linear_q4_0.mm_xetla(x_2d, self.weight.data, self.qtype)
            else:
                # inference path
                # current workaround to reduce first token latency of fp32 input
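
The new call assumes an Intel GPU (xpu) build of ipex-llm whose linear_q4_0 extension exposes mm_xetla(x, weight, qtype) as shown in this diff; a hedged sketch of the dispatch, not a standalone implementation:

import torch

def xetla_linear(x_2d: torch.Tensor, weight: torch.Tensor, qtype: int) -> torch.Tensor:
    # Sketch only: linear_q4_0 ships with the ipex-llm XPU wheel and expects tensors on "xpu".
    import linear_q4_0
    x_2d = x_2d.half()  # the xetla kernels consume fp16 activations
    return linear_q4_0.mm_xetla(x_2d, weight, qtype)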


@@ -273,33 +273,46 @@ def llama_decoder_forward(
     return outputs


-def fuse_qkv_weight(q_proj, k_proj, v_proj):
-    weight_size = q_proj.out_len * q_proj.in_len // 2
-    zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64
-    zeros_end = weight_size + zeros_size
-    weight_byte_shape = (q_proj.in_len//2, q_proj.out_len)
-    zeros_byte_shape = (q_proj.in_len//64, q_proj.out_len//2)
-    scales_byte_shape = (q_proj.in_len//64, q_proj.out_len*2)
-    qweight = torch.concat([q_proj.weight.data[:weight_size].reshape(weight_byte_shape),
-                            k_proj.weight.data[:weight_size].reshape(weight_byte_shape),
-                            v_proj.weight.data[:weight_size].reshape(weight_byte_shape),
-                            ], dim=-1).reshape(-1)
-    qzeros = torch.concat([q_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
-                           k_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
-                           v_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
-                           ], dim=-1).reshape(-1)
-    qscales = torch.concat([q_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
-                            k_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
-                            v_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
-                            ], dim=-1).reshape(-1)
-    q_proj.weight.data = torch.empty(0)
-    k_proj.weight.data = torch.empty(0)
-    v_proj.weight.data = torch.empty(0)
-    return torch.cat([qweight, qzeros, qscales], dim=0)
+def fuse_qkv_weight_xetla(q_proj, k_proj, v_proj, qtype):
+    if qtype == SYM_INT4:
+        weight_size = q_proj.out_len * q_proj.in_len // 2
+        zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64
+        zeros_end = weight_size + zeros_size
+        weight_byte_shape = (q_proj.in_len//2, q_proj.out_len)
+        zeros_byte_shape = (q_proj.in_len//64, q_proj.out_len//2)
+        scales_byte_shape = (q_proj.in_len//64, q_proj.out_len*2)
+        qweight = torch.concat([q_proj.weight.data[:weight_size].reshape(weight_byte_shape),
+                                k_proj.weight.data[:weight_size].reshape(weight_byte_shape),
+                                v_proj.weight.data[:weight_size].reshape(weight_byte_shape),
+                                ], dim=-1).reshape(-1)
+        qzeros = torch.concat([q_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
+                               k_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
+                               v_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
+                               ], dim=-1).reshape(-1)
+        qscales = torch.concat([q_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
+                                k_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
+                                v_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
+                                ], dim=-1).reshape(-1)
+        q_proj.weight.data = torch.empty(0)
+        k_proj.weight.data = torch.empty(0)
+        v_proj.weight.data = torch.empty(0)
+        return torch.cat([qweight, qzeros, qscales], dim=0)
+    elif qtype == FP8E5:
+        result = torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=1).contiguous()
+        q_proj.weight.data = torch.empty(0)
+        k_proj.weight.data = torch.empty(0)
+        v_proj.weight.data = torch.empty(0)
+        return result
+    else:
+        invalidInputError(False, f"Unsupported qtype {qtype}")


-def should_use_mm_int4_qkv(self, device):
-    return device.type == "xpu" and self.q_proj.qtype == SYM_INT4 and self.q_proj.enable_xetla
+def should_use_xetla_mm_qkv(self, device):
+    full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
+    supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
+    supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
+    enable_xetla = self.q_proj.enable_xetla
+    return device.type == "xpu" and enable_xetla and supported_qtype


 def llama_attention_forward_4_31(
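
For the FP8E5 branch the fusion is a plain concatenation along the output dimension, and the attention forwards below split the fused result back by each projection's out_len; a pure-torch sketch with made-up sizes and no quantization:

import torch

k, q_out, kv_out = 64, 64, 16            # hypothetical hidden size and q/kv widths (GQA-style)
wq, wk, wv = torch.randn(k, q_out), torch.randn(k, kv_out), torch.randn(k, kv_out)
w_qkv = torch.cat([wq, wk, wv], dim=1)   # mirrors the FP8E5 branch (weights are already k x out_len)
x = torch.randn(1, 3, k)                 # (bsz, seq_len, hidden)
qkv = x @ w_qkv                          # one fused matmul instead of three
q = qkv[:, :, :q_out]
k_states = qkv[:, :, q_out:q_out + kv_out]
v = qkv[:, :, q_out + kv_out:]
assert torch.allclose(q, x @ wq) and torch.allclose(v, x @ wv)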
@@ -352,6 +365,7 @@ def llama_attention_forward_4_31_quantized(
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
                           and enough_kv_room and bsz * q_len == 1)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

     # single batch decoding fast path
     # forward_qkv takes will perform QKV projection, rotary position embedding
@@ -553,16 +567,21 @@ def llama_attention_forward_4_31_original(
                query_states, key_states, value_states
            )
        else:
-           if should_use_mm_int4_qkv(self, device):
+           if should_use_xetla_mm_qkv(self, device):
                if not hasattr(self, "qkv_proj_qweight"):
-                   self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
-                                                           self.k_proj,
-                                                           self.v_proj)
+                   self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                                 self.k_proj,
+                                                                 self.v_proj,
+                                                                 self.q_proj.weight.qtype,)
                import linear_q4_0
-               qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
-               query_states = qkv_states[:, :, :hidden_size]
-               key_states = qkv_states[:, :, hidden_size:2*hidden_size]
-               value_states = qkv_states[:, :, 2*hidden_size:]
+               q_out_len = self.q_proj.out_len
+               k_out_len = self.k_proj.out_len
+               v_out_len = self.v_proj.out_len
+               qkv_states = linear_q4_0.mm_xetla(hidden_states, self.qkv_proj_qweight,
+                                                 self.q_proj.weight.qtype)
+               query_states = qkv_states[:, :, :q_out_len]
+               key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+               value_states = qkv_states[:, :, q_out_len + k_out_len:]
            else:
                query_states = self.q_proj(hidden_states)
                key_states = self.k_proj(hidden_states)
@@ -932,6 +951,7 @@ def llama_attention_forward_4_36_quantized(
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
                           and enough_kv_room and bsz * q_len == 1)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1196,16 +1216,22 @@ def llama_attention_forward_4_36_original(
                query_states, key_states, value_states
            )
        else:
-           if should_use_mm_int4_qkv(self, device):
+           if should_use_xetla_mm_qkv(self, device):
                if not hasattr(self, "qkv_proj_qweight"):
-                   self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
-                                                           self.k_proj,
-                                                           self.v_proj)
+                   self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                                 self.k_proj,
+                                                                 self.v_proj,
+                                                                 self.q_proj.weight.qtype,)
                import linear_q4_0
-               qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
-               query_states = qkv_states[:, :, :hidden_size]
-               key_states = qkv_states[:, :, hidden_size:2*hidden_size]
-               value_states = qkv_states[:, :, 2*hidden_size:]
+               q_out_len = self.q_proj.out_len
+               k_out_len = self.k_proj.out_len
+               v_out_len = self.v_proj.out_len
+               qkv_states = linear_q4_0.mm_xetla(hidden_states,
+                                                 self.qkv_proj_qweight,
+                                                 self.q_proj.weight.qtype)
+               query_states = qkv_states[:, :, :q_out_len]
+               key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+               value_states = qkv_states[:, :, q_out_len + k_out_len:]
            else:
                query_states = self.q_proj(hidden_states)
                key_states = self.k_proj(hidden_states)


@@ -54,6 +54,8 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
+from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
+from ipex_llm.transformers.models.llama import fuse_qkv_weight_xetla

 try:
     from transformers.cache_utils import Cache
 except ImportError:
@@ -84,7 +86,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):

 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
     return llama_decoding_fast_path_qtype_check(proj) and \
-        use_fuse_rope and enough_kv_room and bs == 1
+        use_fuse_rope and enough_kv_room and bs == 1 and \
+        not proj.enable_xetla


 def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
@@ -382,7 +385,6 @@ def mistral_attention_forward_original(
                                                use_fuse_rope,
                                                enough_kv_room,
                                                bsz * q_len)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
@@ -402,9 +404,27 @@ def mistral_attention_forward_original(
                                             self.head_dim)
        kv_seq_len += 1
    else:
-       query_states = self.q_proj(hidden_states)
-       key_states = self.k_proj(hidden_states)
-       value_states = self.v_proj(hidden_states)
+       if should_use_xetla_mm_qkv(self, device):
+           if not hasattr(self, "qkv_proj_qweight"):
+               self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                             self.k_proj,
+                                                             self.v_proj,
+                                                             self.q_proj.qtype)
+           import linear_q4_0
+           q_out_len = self.q_proj.out_len
+           k_out_len = self.k_proj.out_len
+           v_out_len = self.v_proj.out_len
+           qkv_states = linear_q4_0.mm_xetla(hidden_states,
+                                             self.qkv_proj_qweight,
+                                             self.q_proj.qtype)
+           query_states = qkv_states[:, :, :q_out_len]
+           key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+           value_states = qkv_states[:, :, q_out_len + k_out_len:]
+       else:
+           query_states = self.q_proj(hidden_states)
+           key_states = self.k_proj(hidden_states)
+           value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len,
@@ -769,9 +789,26 @@ def mistral_attention_forward_4_36_original(
            past_key_value.value_cache[self.layer_idx] = value_states
    else:
-       query_states = self.q_proj(hidden_states)
-       key_states = self.k_proj(hidden_states)
-       value_states = self.v_proj(hidden_states)
+       if should_use_xetla_mm_qkv(self, device):
+           if not hasattr(self, "qkv_proj_qweight"):
+               self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                             self.k_proj,
+                                                             self.v_proj,
+                                                             self.q_proj.qtype)
+           import linear_q4_0
+           q_out_len = self.q_proj.out_len
+           k_out_len = self.k_proj.out_len
+           v_out_len = self.v_proj.out_len
+           qkv_states = linear_q4_0.mm_xetla(hidden_states,
+                                             self.qkv_proj_qweight,
+                                             self.q_proj.qtype)
+           query_states = qkv_states[:, :, :q_out_len]
+           key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+           value_states = qkv_states[:, :, q_out_len + k_out_len:]
+       else:
+           query_states = self.q_proj(hidden_states)
+           key_states = self.k_proj(hidden_states)
+           value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len,


@@ -142,6 +142,7 @@ def qwen_attention_forward_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
     qtype_check = decoding_fast_path_qtype_check(self.q_proj)
     decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k, cache_v = layer_past[0], layer_past[1]


@@ -91,6 +91,7 @@ def qwen_attention_forward_vl(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
     qtype_check = decoding_fast_path_qtype_check(self.q_proj)
     decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
+    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k, cache_v = layer_past[0], layer_past[1]


@@ -43,7 +43,8 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256

 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
     return decoding_fast_path_qtype_check(proj) and \
-        use_fuse_rope and enough_kv_room and bs == 1
+        use_fuse_rope and enough_kv_room and bs == 1 \
+        and not proj.enable_xetla


 def should_use_fuse_rope(self, hidden_states, position_ids):