support fp8 in xetla (#10555)
* support fp8 in xetla * change name * adjust model file * support convert back to cpu * factor * fix bug * fix style
This commit is contained in:
parent
7c43ac0164
commit
5a1f446d3c
6 changed files with 209 additions and 84 deletions
|
|
@ -75,10 +75,12 @@ IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
|
|||
Q2_K = ggml_tensor_qtype["q2_k"]
|
||||
IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]
|
||||
|
||||
|
||||
# For sym_int4
|
||||
# The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
|
||||
#
|
||||
# The returning weight is row major and packs two rows at a stride of 16//2.
|
||||
# 16 is the tile_size_y used in mm_int4, so that we can do something like
|
||||
# 16 is the tile_size_y used in mm_xetla, so that we can do something like
|
||||
# new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
|
||||
#
|
||||
# A more complex packing strategy is to permute the weight so that the
|
||||
|
|
@ -87,43 +89,97 @@ IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]
|
|||
#
|
||||
# Note this format cannot be used directly in IPEX-LLM's mm_int4, which expects
|
||||
# row major but packing two consecutive columns.
|
||||
#
|
||||
# For fp8, just remove the scales (which are all ones) and transpose
|
||||
def ggml_xpu_to_ipex_llm_xetla(ggml_weight, weight_shape, qtype):
|
||||
if qtype == ggml_tensor_qtype["sym_int4"]:
|
||||
from ipex_llm.transformers.low_bit_linear import get_block_size
|
||||
Q4_0 = get_block_size("sym_int4")
|
||||
|
||||
n, k = weight_shape
|
||||
ggml_weight_only = ggml_weight[:n*k//2]
|
||||
ggml_scales = ggml_weight[n*k//2:]
|
||||
|
||||
qweight = ggml_weight_only.clone()
|
||||
scales = ggml_scales.view(torch.float16).clone()
|
||||
|
||||
qweight_0 = qweight & 0x0F
|
||||
qweight_1 = qweight >> 4
|
||||
|
||||
qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
|
||||
qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
|
||||
qweight = torch.cat([qweight_0, qweight_1], dim=-1)
|
||||
qweight = qweight.reshape(n, k//16, 2, 8)
|
||||
qweight = qweight.bitwise_left_shift(
|
||||
torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
|
||||
|
||||
qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
|
||||
qweight = qweight.reshape(n, k//2)
|
||||
qweight = qweight.transpose(0, 1).contiguous()
|
||||
|
||||
scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
|
||||
|
||||
# 119 is the value of 0x77
|
||||
zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
|
||||
|
||||
qweight_bytes = qweight.view(torch.uint8).view(-1)
|
||||
scales_bytes = scales.view(torch.uint8).view(-1)
|
||||
zeros_bytes = zeros.view(torch.uint8).view(-1)
|
||||
|
||||
weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
|
||||
elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
|
||||
n, k = weight_shape
|
||||
weight = ggml_weight[:n*k].view(n, k).transpose(0, 1).contiguous()
|
||||
else:
|
||||
invalidInputError(False, f"Unsupported qtype {qtype}")
|
||||
return weight
|
||||
|
||||
|
||||
def q4_0_xpu_transpose(ggml_weight, weight_shape):
|
||||
def ipex_llm_xetla_to_ggml_xpu(xetla_weight, weight_shape, qtype):
|
||||
from ipex_llm.transformers.low_bit_linear import get_block_size
|
||||
Q4_0 = get_block_size("sym_int4")
|
||||
if qtype == ggml_tensor_qtype["sym_int4"]:
|
||||
Q4_0 = get_block_size("sym_int4")
|
||||
n, k = weight_shape
|
||||
weight_size = n*k//2
|
||||
zeros_size = n*k//Q4_0//2
|
||||
scales_size = n*k//Q4_0 * 2
|
||||
xetla_weight_only = xetla_weight[:weight_size]
|
||||
scales_start = weight_size + zeros_size
|
||||
xetla_scales = xetla_weight[scales_start:scales_start+scales_size]
|
||||
|
||||
n, k = weight_shape
|
||||
ggml_weight_only = ggml_weight[:n*k//2]
|
||||
ggml_scales = ggml_weight[n*k//2:]
|
||||
qweight = xetla_weight_only.clone()
|
||||
scales = xetla_scales.view(torch.float16).clone()
|
||||
|
||||
qweight = ggml_weight_only.clone()
|
||||
scales = ggml_scales.view(torch.float16).clone()
|
||||
qweight_0 = qweight & 0x0F
|
||||
qweight_1 = qweight >> 4
|
||||
qweight_0 = qweight_0.reshape(-1, 8, n)
|
||||
qweight_1 = qweight_1.reshape(-1, 8, n)
|
||||
qweight = torch.cat([qweight_0, qweight_1], dim=1)
|
||||
|
||||
qweight_0 = qweight & 0x0F
|
||||
qweight_1 = qweight >> 4
|
||||
qweight = qweight.reshape(k, n).transpose(0, 1).contiguous().reshape(n, k//Q4_0,
|
||||
2, Q4_0//2)
|
||||
qweight = qweight.bitwise_left_shift(
|
||||
torch.tensor([0, 4], dtype=torch.uint8,
|
||||
device=xetla_weight_only.device).reshape(1, 1, 2, 1))
|
||||
|
||||
qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
|
||||
qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
|
||||
qweight = torch.cat([qweight_0, qweight_1], dim=-1)
|
||||
qweight = qweight.reshape(n, k//16, 2, 8)
|
||||
qweight = qweight.bitwise_left_shift(
|
||||
torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
|
||||
qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
|
||||
qweight = qweight.reshape(n, k//2)
|
||||
|
||||
qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
|
||||
qweight = qweight.reshape(n, k//2)
|
||||
qweight = qweight.transpose(0, 1).contiguous()
|
||||
scales = scales.reshape(k//Q4_0, n).transpose(0, 1).contiguous()
|
||||
|
||||
scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
|
||||
|
||||
# 119 is the value of 0x77
|
||||
zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
|
||||
|
||||
qweight_bytes = qweight.view(torch.uint8).view(-1)
|
||||
scales_bytes = scales.view(torch.uint8).view(-1)
|
||||
zeros_bytes = zeros.view(torch.uint8).view(-1)
|
||||
|
||||
weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
|
||||
qweight_bytes = qweight.view(torch.uint8).view(-1)
|
||||
scales_bytes = scales.view(torch.uint8).view(-1)
|
||||
weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
|
||||
elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
|
||||
Q8_0 = get_block_size("fp8_e5m2")
|
||||
n, k = weight_shape
|
||||
qweight = xetla_weight[:n*k].transpose(0, 1).contiguous()
|
||||
scales = torch.ones([n*k//Q8_0], dtype=torch.float, device=xetla_weight.device)
|
||||
qweight_bytes = qweight.view(torch.uint8).view(-1)
|
||||
scales_bytes = scales.view(torch.uint8).view(-1)
|
||||
weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
|
||||
else:
|
||||
invalidInputError(False, f"Unsupported qtype {qtype}")
|
||||
return weight
|
||||
|
||||
|
||||
|
|
@ -373,7 +429,7 @@ class FP4Params(torch.nn.Parameter):
|
|||
reduce(mul, self._shape, 1),
|
||||
self.qtype)
|
||||
if self.enable_xetla:
|
||||
self.data = q4_0_xpu_transpose(self.data, self._shape)
|
||||
self.data = ggml_xpu_to_ipex_llm_xetla(self.data, self._shape, self.qtype)
|
||||
new_param = FP4Params(super().to(device=device,
|
||||
dtype=dtype,
|
||||
non_blocking=non_blocking),
|
||||
|
|
@ -397,9 +453,12 @@ class FP4Params(torch.nn.Parameter):
|
|||
qtype=self.qtype,
|
||||
enable_xetla=self.enable_xetla)
|
||||
if self.enable_xetla:
|
||||
invalidInputError(False,
|
||||
"xetla is not supported on CPUs but got enable_xetla=True")
|
||||
new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
|
||||
ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
|
||||
new_param._shape,
|
||||
new_param.qtype)
|
||||
else:
|
||||
ggml_xpu = new_param.data
|
||||
new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
|
||||
reduce(mul, new_param._shape, 1),
|
||||
new_param.qtype)
|
||||
return new_param
|
||||
|
|
@ -610,7 +669,7 @@ class LowBitLinear(nn.Linear):
|
|||
input_seq_size)
|
||||
elif self.enable_xetla:
|
||||
x_2d = x_2d.half()
|
||||
result = linear_q4_0.mm_int4(x_2d, self.weight.data)
|
||||
result = linear_q4_0.mm_xetla(x_2d, self.weight.data, self.qtype)
|
||||
else:
|
||||
# inference path
|
||||
# current workaround to reduce first token latency of fp32 input
|
||||
|
|
|
|||
|
|
@ -273,33 +273,46 @@ def llama_decoder_forward(
|
|||
return outputs
|
||||
|
||||
|
||||
def fuse_qkv_weight(q_proj, k_proj, v_proj):
|
||||
weight_size = q_proj.out_len * q_proj.in_len // 2
|
||||
zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64
|
||||
zeros_end = weight_size + zeros_size
|
||||
weight_byte_shape = (q_proj.in_len//2, q_proj.out_len)
|
||||
zeros_byte_shape = (q_proj.in_len//64, q_proj.out_len//2)
|
||||
scales_byte_shape = (q_proj.in_len//64, q_proj.out_len*2)
|
||||
qweight = torch.concat([q_proj.weight.data[:weight_size].reshape(weight_byte_shape),
|
||||
k_proj.weight.data[:weight_size].reshape(weight_byte_shape),
|
||||
v_proj.weight.data[:weight_size].reshape(weight_byte_shape),
|
||||
], dim=-1).reshape(-1)
|
||||
qzeros = torch.concat([q_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
|
||||
k_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
|
||||
v_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
|
||||
], dim=-1).reshape(-1)
|
||||
qscales = torch.concat([q_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
|
||||
k_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
|
||||
v_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
|
||||
], dim=-1).reshape(-1)
|
||||
q_proj.weight.data = torch.empty(0)
|
||||
k_proj.weight.data = torch.empty(0)
|
||||
v_proj.weight.data = torch.empty(0)
|
||||
return torch.cat([qweight, qzeros, qscales], dim=0)
|
||||
def fuse_qkv_weight_xetla(q_proj, k_proj, v_proj, qtype):
|
||||
if qtype == SYM_INT4:
|
||||
weight_size = q_proj.out_len * q_proj.in_len // 2
|
||||
zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64
|
||||
zeros_end = weight_size + zeros_size
|
||||
weight_byte_shape = (q_proj.in_len//2, q_proj.out_len)
|
||||
zeros_byte_shape = (q_proj.in_len//64, q_proj.out_len//2)
|
||||
scales_byte_shape = (q_proj.in_len//64, q_proj.out_len*2)
|
||||
qweight = torch.concat([q_proj.weight.data[:weight_size].reshape(weight_byte_shape),
|
||||
k_proj.weight.data[:weight_size].reshape(weight_byte_shape),
|
||||
v_proj.weight.data[:weight_size].reshape(weight_byte_shape),
|
||||
], dim=-1).reshape(-1)
|
||||
qzeros = torch.concat([q_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
|
||||
k_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
|
||||
v_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape),
|
||||
], dim=-1).reshape(-1)
|
||||
qscales = torch.concat([q_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
|
||||
k_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
|
||||
v_proj.weight.data[zeros_end:].reshape(scales_byte_shape),
|
||||
], dim=-1).reshape(-1)
|
||||
q_proj.weight.data = torch.empty(0)
|
||||
k_proj.weight.data = torch.empty(0)
|
||||
v_proj.weight.data = torch.empty(0)
|
||||
return torch.cat([qweight, qzeros, qscales], dim=0)
|
||||
elif qtype == FP8E5:
|
||||
result = torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=1).contiguous()
|
||||
q_proj.weight.data = torch.empty(0)
|
||||
k_proj.weight.data = torch.empty(0)
|
||||
v_proj.weight.data = torch.empty(0)
|
||||
return result
|
||||
else:
|
||||
invalidInputError(False, f"Unsupported qtype {qtype}")
|
||||
|
||||
|
||||
def should_use_mm_int4_qkv(self, device):
|
||||
return device.type == "xpu" and self.q_proj.qtype == SYM_INT4 and self.q_proj.enable_xetla
|
||||
def should_use_xetla_mm_qkv(self, device):
|
||||
full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
|
||||
supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
|
||||
supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
|
||||
enable_xetla = self.q_proj.enable_xetla
|
||||
return device.type == "xpu" and enable_xetla and supported_qtype
|
||||
|
||||
|
||||
def llama_attention_forward_4_31(
|
||||
|
|
@ -352,6 +365,7 @@ def llama_attention_forward_4_31_quantized(
|
|||
no_tp = not self.config.pretraining_tp > 1
|
||||
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
|
||||
and enough_kv_room and bsz * q_len == 1)
|
||||
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
|
||||
|
||||
# single batch decoding fast path
|
||||
# forward_qkv takes will perform QKV projection, rotary position embedding
|
||||
|
|
@ -553,16 +567,21 @@ def llama_attention_forward_4_31_original(
|
|||
query_states, key_states, value_states
|
||||
)
|
||||
else:
|
||||
if should_use_mm_int4_qkv(self, device):
|
||||
if should_use_xetla_mm_qkv(self, device):
|
||||
if not hasattr(self, "qkv_proj_qweight"):
|
||||
self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj)
|
||||
self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.q_proj.weight.qtype,)
|
||||
import linear_q4_0
|
||||
qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
|
||||
query_states = qkv_states[:, :, :hidden_size]
|
||||
key_states = qkv_states[:, :, hidden_size:2*hidden_size]
|
||||
value_states = qkv_states[:, :, 2*hidden_size:]
|
||||
q_out_len = self.q_proj.out_len
|
||||
k_out_len = self.k_proj.out_len
|
||||
v_out_len = self.v_proj.out_len
|
||||
qkv_states = linear_q4_0.mm_xetla(hidden_states, self.qkv_proj_qweight,
|
||||
self.q_proj.weight.qtype)
|
||||
query_states = qkv_states[:, :, :q_out_len]
|
||||
key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
|
||||
value_states = qkv_states[:, :, q_out_len + k_out_len:]
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
|
|
@ -932,6 +951,7 @@ def llama_attention_forward_4_36_quantized(
|
|||
no_tp = not self.config.pretraining_tp > 1
|
||||
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
|
||||
and enough_kv_room and bsz * q_len == 1)
|
||||
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
tmp_cache_k, tmp_cache_v = init_kv_cache(
|
||||
|
|
@ -1196,16 +1216,22 @@ def llama_attention_forward_4_36_original(
|
|||
query_states, key_states, value_states
|
||||
)
|
||||
else:
|
||||
if should_use_mm_int4_qkv(self, device):
|
||||
if should_use_xetla_mm_qkv(self, device):
|
||||
if not hasattr(self, "qkv_proj_qweight"):
|
||||
self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj)
|
||||
self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.q_proj.weight.qtype,)
|
||||
import linear_q4_0
|
||||
qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight)
|
||||
query_states = qkv_states[:, :, :hidden_size]
|
||||
key_states = qkv_states[:, :, hidden_size:2*hidden_size]
|
||||
value_states = qkv_states[:, :, 2*hidden_size:]
|
||||
q_out_len = self.q_proj.out_len
|
||||
k_out_len = self.k_proj.out_len
|
||||
v_out_len = self.v_proj.out_len
|
||||
qkv_states = linear_q4_0.mm_xetla(hidden_states,
|
||||
self.qkv_proj_qweight,
|
||||
self.q_proj.weight.qtype)
|
||||
query_states = qkv_states[:, :, :q_out_len]
|
||||
key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
|
||||
value_states = qkv_states[:, :, q_out_len + k_out_len:]
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
|||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
|
||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
||||
from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
|
||||
from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
|
||||
from ipex_llm.transformers.models.llama import fuse_qkv_weight_xetla
|
||||
try:
|
||||
from transformers.cache_utils import Cache
|
||||
except ImportError:
|
||||
|
|
@ -84,7 +86,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
|
|||
|
||||
def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
|
||||
return llama_decoding_fast_path_qtype_check(proj) and \
|
||||
use_fuse_rope and enough_kv_room and bs == 1
|
||||
use_fuse_rope and enough_kv_room and bs == 1 and \
|
||||
not proj.enable_xetla
|
||||
|
||||
|
||||
def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
|
||||
|
|
@ -382,7 +385,6 @@ def mistral_attention_forward_original(
|
|||
use_fuse_rope,
|
||||
enough_kv_room,
|
||||
bsz * q_len)
|
||||
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
|
||||
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
|
|
@ -402,9 +404,27 @@ def mistral_attention_forward_original(
|
|||
self.head_dim)
|
||||
kv_seq_len += 1
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
if should_use_xetla_mm_qkv(self, device):
|
||||
if not hasattr(self, "qkv_proj_qweight"):
|
||||
self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.q_proj.qtype)
|
||||
import linear_q4_0
|
||||
q_out_len = self.q_proj.out_len
|
||||
k_out_len = self.k_proj.out_len
|
||||
v_out_len = self.v_proj.out_len
|
||||
qkv_states = linear_q4_0.mm_xetla(hidden_states,
|
||||
self.qkv_proj_qweight,
|
||||
self.q_proj.qtype)
|
||||
query_states = qkv_states[:, :, :q_out_len]
|
||||
key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
|
||||
value_states = qkv_states[:, :, q_out_len + k_out_len:]
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len,
|
||||
|
|
@ -769,9 +789,26 @@ def mistral_attention_forward_4_36_original(
|
|||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
if should_use_xetla_mm_qkv(self, device):
|
||||
if not hasattr(self, "qkv_proj_qweight"):
|
||||
self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.q_proj.qtype)
|
||||
import linear_q4_0
|
||||
q_out_len = self.q_proj.out_len
|
||||
k_out_len = self.k_proj.out_len
|
||||
v_out_len = self.v_proj.out_len
|
||||
qkv_states = linear_q4_0.mm_xetla(hidden_states,
|
||||
self.qkv_proj_qweight,
|
||||
self.q_proj.qtype)
|
||||
query_states = qkv_states[:, :, :q_out_len]
|
||||
key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
|
||||
value_states = qkv_states[:, :, q_out_len + k_out_len:]
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len,
|
||||
|
|
|
|||
|
|
@ -142,6 +142,7 @@ def qwen_attention_forward_original(
|
|||
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
||||
qtype_check = decoding_fast_path_qtype_check(self.q_proj)
|
||||
decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
|
||||
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ def qwen_attention_forward_vl(
|
|||
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
||||
qtype_check = decoding_fast_path_qtype_check(self.q_proj)
|
||||
decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
|
||||
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
||||
|
|
|
|||
|
|
@ -43,7 +43,8 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
|||
|
||||
def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
|
||||
return decoding_fast_path_qtype_check(proj) and \
|
||||
use_fuse_rope and enough_kv_room and bs == 1
|
||||
use_fuse_rope and enough_kv_room and bs == 1 \
|
||||
and not proj.enable_xetla
|
||||
|
||||
|
||||
def should_use_fuse_rope(self, hidden_states, position_ids):
|
||||
|
|
|
|||
Loading…
Reference in a new issue