From 5a1f446d3c65e4491b7570e8654e72c128abd47c Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 8 Apr 2024 13:22:09 -0700 Subject: [PATCH] support fp8 in xetla (#10555) * support fp8 in xetla * change name * adjust model file * support convert back to cpu * factor * fix bug * fix style --- .../ipex_llm/transformers/low_bit_linear.py | 127 +++++++++++++----- .../src/ipex_llm/transformers/models/llama.py | 108 +++++++++------ .../ipex_llm/transformers/models/mistral.py | 53 ++++++-- .../src/ipex_llm/transformers/models/qwen.py | 1 + .../ipex_llm/transformers/models/qwen_vl.py | 1 + .../src/ipex_llm/transformers/models/yuan.py | 3 +- 6 files changed, 209 insertions(+), 84 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index 96f57167..0acca42a 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -75,10 +75,12 @@ IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"] Q2_K = ggml_tensor_qtype["q2_k"] IQ1_S = ggml_tensor_qtype["gguf_iq1_s"] + +# For sym_int4 # The ggml_weight is col major and packs two rows at a stride of Q4_0//2. # # The returning weight is row major and packs two rows at a stride of 16//2. -# 16 is the tile_size_y used in mm_int4, so that we can do something like +# 16 is the tile_size_y used in mm_xetla, so that we can do something like # new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4). # # A more complex packing strategy is to permute the weight so that the @@ -87,43 +89,97 @@ IQ1_S = ggml_tensor_qtype["gguf_iq1_s"] # # Note this format cannot be used directly in IPEX-LLM's mm_int4, which expects # row major but packing two consecutive columns. +# +# For fp8, just remove the scales (which are all ones) and transpose +def ggml_xpu_to_ipex_llm_xetla(ggml_weight, weight_shape, qtype): + if qtype == ggml_tensor_qtype["sym_int4"]: + from ipex_llm.transformers.low_bit_linear import get_block_size + Q4_0 = get_block_size("sym_int4") + + n, k = weight_shape + ggml_weight_only = ggml_weight[:n*k//2] + ggml_scales = ggml_weight[n*k//2:] + + qweight = ggml_weight_only.clone() + scales = ggml_scales.view(torch.float16).clone() + + qweight_0 = qweight & 0x0F + qweight_1 = qweight >> 4 + + qweight_0 = qweight_0.reshape(n, -1, Q4_0//2) + qweight_1 = qweight_1.reshape(n, -1, Q4_0//2) + qweight = torch.cat([qweight_0, qweight_1], dim=-1) + qweight = qweight.reshape(n, k//16, 2, 8) + qweight = qweight.bitwise_left_shift( + torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1)) + + qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :]) + qweight = qweight.reshape(n, k//2) + qweight = qweight.transpose(0, 1).contiguous() + + scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous() + + # 119 is the value of 0x77 + zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119) + + qweight_bytes = qweight.view(torch.uint8).view(-1) + scales_bytes = scales.view(torch.uint8).view(-1) + zeros_bytes = zeros.view(torch.uint8).view(-1) + + weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0) + elif qtype == ggml_tensor_qtype["fp8_e5m2"]: + n, k = weight_shape + weight = ggml_weight[:n*k].view(n, k).transpose(0, 1).contiguous() + else: + invalidInputError(False, f"Unsupported qtype {qtype}") + return weight -def q4_0_xpu_transpose(ggml_weight, weight_shape): +def ipex_llm_xetla_to_ggml_xpu(xetla_weight, weight_shape, qtype): from 
ipex_llm.transformers.low_bit_linear import get_block_size - Q4_0 = get_block_size("sym_int4") + if qtype == ggml_tensor_qtype["sym_int4"]: + Q4_0 = get_block_size("sym_int4") + n, k = weight_shape + weight_size = n*k//2 + zeros_size = n*k//Q4_0//2 + scales_size = n*k//Q4_0 * 2 + xetla_weight_only = xetla_weight[:weight_size] + scales_start = weight_size + zeros_size + xetla_scales = xetla_weight[scales_start:scales_start+scales_size] - n, k = weight_shape - ggml_weight_only = ggml_weight[:n*k//2] - ggml_scales = ggml_weight[n*k//2:] + qweight = xetla_weight_only.clone() + scales = xetla_scales.view(torch.float16).clone() - qweight = ggml_weight_only.clone() - scales = ggml_scales.view(torch.float16).clone() + qweight_0 = qweight & 0x0F + qweight_1 = qweight >> 4 + qweight_0 = qweight_0.reshape(-1, 8, n) + qweight_1 = qweight_1.reshape(-1, 8, n) + qweight = torch.cat([qweight_0, qweight_1], dim=1) - qweight_0 = qweight & 0x0F - qweight_1 = qweight >> 4 + qweight = qweight.reshape(k, n).transpose(0, 1).contiguous().reshape(n, k//Q4_0, + 2, Q4_0//2) + qweight = qweight.bitwise_left_shift( + torch.tensor([0, 4], dtype=torch.uint8, + device=xetla_weight_only.device).reshape(1, 1, 2, 1)) - qweight_0 = qweight_0.reshape(n, -1, Q4_0//2) - qweight_1 = qweight_1.reshape(n, -1, Q4_0//2) - qweight = torch.cat([qweight_0, qweight_1], dim=-1) - qweight = qweight.reshape(n, k//16, 2, 8) - qweight = qweight.bitwise_left_shift( - torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1)) + qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :]) + qweight = qweight.reshape(n, k//2) - qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :]) - qweight = qweight.reshape(n, k//2) - qweight = qweight.transpose(0, 1).contiguous() + scales = scales.reshape(k//Q4_0, n).transpose(0, 1).contiguous() - scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous() - - # 119 is the value of 0x77 - zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119) - - qweight_bytes = qweight.view(torch.uint8).view(-1) - scales_bytes = scales.view(torch.uint8).view(-1) - zeros_bytes = zeros.view(torch.uint8).view(-1) - - weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0) + qweight_bytes = qweight.view(torch.uint8).view(-1) + scales_bytes = scales.view(torch.uint8).view(-1) + weight = torch.concat([qweight_bytes, scales_bytes], dim=0) + elif qtype == ggml_tensor_qtype["fp8_e5m2"]: + Q8_0 = get_block_size("fp8_e5m2") + n, k = weight_shape + qweight = xetla_weight[:n*k].transpose(0, 1).contiguous() + scales = torch.ones([n*k//Q8_0], dtype=torch.float, device=xetla_weight.device) + qweight_bytes = qweight.view(torch.uint8).view(-1) + scales_bytes = scales.view(torch.uint8).view(-1) + weight = torch.concat([qweight_bytes, scales_bytes], dim=0) + else: + invalidInputError(False, f"Unsupported qtype {qtype}") return weight @@ -373,7 +429,7 @@ class FP4Params(torch.nn.Parameter): reduce(mul, self._shape, 1), self.qtype) if self.enable_xetla: - self.data = q4_0_xpu_transpose(self.data, self._shape) + self.data = ggml_xpu_to_ipex_llm_xetla(self.data, self._shape, self.qtype) new_param = FP4Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), @@ -397,9 +453,12 @@ class FP4Params(torch.nn.Parameter): qtype=self.qtype, enable_xetla=self.enable_xetla) if self.enable_xetla: - invalidInputError(False, - "xetla is not supported on CPUs but got enable_xetla=True") - new_param.data = 
ggml_q_format_convet_xpu2cpu(new_param.data, + ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data, + new_param._shape, + new_param.qtype) + else: + ggml_xpu = new_param.data + new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu, reduce(mul, new_param._shape, 1), new_param.qtype) return new_param @@ -610,7 +669,7 @@ class LowBitLinear(nn.Linear): input_seq_size) elif self.enable_xetla: x_2d = x_2d.half() - result = linear_q4_0.mm_int4(x_2d, self.weight.data) + result = linear_q4_0.mm_xetla(x_2d, self.weight.data, self.qtype) else: # inference path # current workaround to reduce first token latency of fp32 input diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index ee367131..f7531561 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -273,33 +273,46 @@ def llama_decoder_forward( return outputs -def fuse_qkv_weight(q_proj, k_proj, v_proj): - weight_size = q_proj.out_len * q_proj.in_len // 2 - zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64 - zeros_end = weight_size + zeros_size - weight_byte_shape = (q_proj.in_len//2, q_proj.out_len) - zeros_byte_shape = (q_proj.in_len//64, q_proj.out_len//2) - scales_byte_shape = (q_proj.in_len//64, q_proj.out_len*2) - qweight = torch.concat([q_proj.weight.data[:weight_size].reshape(weight_byte_shape), - k_proj.weight.data[:weight_size].reshape(weight_byte_shape), - v_proj.weight.data[:weight_size].reshape(weight_byte_shape), - ], dim=-1).reshape(-1) - qzeros = torch.concat([q_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape), - k_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape), - v_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape), - ], dim=-1).reshape(-1) - qscales = torch.concat([q_proj.weight.data[zeros_end:].reshape(scales_byte_shape), - k_proj.weight.data[zeros_end:].reshape(scales_byte_shape), - v_proj.weight.data[zeros_end:].reshape(scales_byte_shape), - ], dim=-1).reshape(-1) - q_proj.weight.data = torch.empty(0) - k_proj.weight.data = torch.empty(0) - v_proj.weight.data = torch.empty(0) - return torch.cat([qweight, qzeros, qscales], dim=0) +def fuse_qkv_weight_xetla(q_proj, k_proj, v_proj, qtype): + if qtype == SYM_INT4: + weight_size = q_proj.out_len * q_proj.in_len // 2 + zeros_size = q_proj.in_len * q_proj.out_len // 2 // 64 + zeros_end = weight_size + zeros_size + weight_byte_shape = (q_proj.in_len//2, q_proj.out_len) + zeros_byte_shape = (q_proj.in_len//64, q_proj.out_len//2) + scales_byte_shape = (q_proj.in_len//64, q_proj.out_len*2) + qweight = torch.concat([q_proj.weight.data[:weight_size].reshape(weight_byte_shape), + k_proj.weight.data[:weight_size].reshape(weight_byte_shape), + v_proj.weight.data[:weight_size].reshape(weight_byte_shape), + ], dim=-1).reshape(-1) + qzeros = torch.concat([q_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape), + k_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape), + v_proj.weight.data[weight_size:zeros_end].reshape(zeros_byte_shape), + ], dim=-1).reshape(-1) + qscales = torch.concat([q_proj.weight.data[zeros_end:].reshape(scales_byte_shape), + k_proj.weight.data[zeros_end:].reshape(scales_byte_shape), + v_proj.weight.data[zeros_end:].reshape(scales_byte_shape), + ], dim=-1).reshape(-1) + q_proj.weight.data = torch.empty(0) + k_proj.weight.data = torch.empty(0) + v_proj.weight.data = torch.empty(0) + return torch.cat([qweight, qzeros, qscales], dim=0) + elif qtype == 
FP8E5: + result = torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=1).contiguous() + q_proj.weight.data = torch.empty(0) + k_proj.weight.data = torch.empty(0) + v_proj.weight.data = torch.empty(0) + return result + else: + invalidInputError(False, f"Unsupported qtype {qtype}") -def should_use_mm_int4_qkv(self, device): - return device.type == "xpu" and self.q_proj.qtype == SYM_INT4 and self.q_proj.enable_xetla +def should_use_xetla_mm_qkv(self, device): + full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len + supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn + supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5 + enable_xetla = self.q_proj.enable_xetla + return device.type == "xpu" and enable_xetla and supported_qtype def llama_attention_forward_4_31( @@ -352,6 +365,7 @@ def llama_attention_forward_4_31_quantized( no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1) + decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla # single batch decoding fast path # forward_qkv takes will perform QKV projection, rotary position embedding @@ -553,16 +567,21 @@ def llama_attention_forward_4_31_original( query_states, key_states, value_states ) else: - if should_use_mm_int4_qkv(self, device): + if should_use_xetla_mm_qkv(self, device): if not hasattr(self, "qkv_proj_qweight"): - self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj, - self.k_proj, - self.v_proj) + self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj, + self.k_proj, + self.v_proj, + self.q_proj.weight.qtype,) import linear_q4_0 - qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight) - query_states = qkv_states[:, :, :hidden_size] - key_states = qkv_states[:, :, hidden_size:2*hidden_size] - value_states = qkv_states[:, :, 2*hidden_size:] + q_out_len = self.q_proj.out_len + k_out_len = self.k_proj.out_len + v_out_len = self.v_proj.out_len + qkv_states = linear_q4_0.mm_xetla(hidden_states, self.qkv_proj_qweight, + self.q_proj.weight.qtype) + query_states = qkv_states[:, :, :q_out_len] + key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len] + value_states = qkv_states[:, :, q_out_len + k_out_len:] else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -932,6 +951,7 @@ def llama_attention_forward_4_36_quantized( no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1) + decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla if decoding_fast_path: hidden_states = hidden_states.view(1, -1) tmp_cache_k, tmp_cache_v = init_kv_cache( @@ -1196,16 +1216,22 @@ def llama_attention_forward_4_36_original( query_states, key_states, value_states ) else: - if should_use_mm_int4_qkv(self, device): + if should_use_xetla_mm_qkv(self, device): if not hasattr(self, "qkv_proj_qweight"): - self.qkv_proj_qweight = fuse_qkv_weight(self.q_proj, - self.k_proj, - self.v_proj) + self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj, + self.k_proj, + self.v_proj, + self.q_proj.weight.qtype,) import linear_q4_0 - qkv_states = linear_q4_0.mm_int4(hidden_states, self.qkv_proj_qweight) - query_states = qkv_states[:, :, :hidden_size] - key_states = qkv_states[:, :, hidden_size:2*hidden_size] - value_states = qkv_states[:, :, 2*hidden_size:] + q_out_len = self.q_proj.out_len + k_out_len = self.k_proj.out_len + v_out_len = 
self.v_proj.out_len + qkv_states = linear_q4_0.mm_xetla(hidden_states, + self.qkv_proj_qweight, + self.q_proj.weight.qtype) + query_states = qkv_states[:, :, :q_out_len] + key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len] + value_states = qkv_states[:, :, q_out_len + k_out_len:] else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py index 818b98f3..e57ba7f6 100644 --- a/python/llm/src/ipex_llm/transformers/models/mistral.py +++ b/python/llm/src/ipex_llm/transformers/models/mistral.py @@ -54,6 +54,8 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \ from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check +from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv +from ipex_llm.transformers.models.llama import fuse_qkv_weight_xetla try: from transformers.cache_utils import Cache except ImportError: @@ -84,7 +86,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids): def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs): return llama_decoding_fast_path_qtype_check(proj) and \ - use_fuse_rope and enough_kv_room and bs == 1 + use_fuse_rope and enough_kv_room and bs == 1 and \ + not proj.enable_xetla def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len, @@ -382,7 +385,6 @@ def mistral_attention_forward_original( use_fuse_rope, enough_kv_room, bsz * q_len) - decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla if decoding_fast_path: hidden_states = hidden_states.view(1, -1) @@ -402,9 +404,27 @@ def mistral_attention_forward_original( self.head_dim) kv_seq_len += 1 else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + + if should_use_xetla_mm_qkv(self, device): + if not hasattr(self, "qkv_proj_qweight"): + self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj, + self.k_proj, + self.v_proj, + self.q_proj.qtype) + import linear_q4_0 + q_out_len = self.q_proj.out_len + k_out_len = self.k_proj.out_len + v_out_len = self.v_proj.out_len + qkv_states = linear_q4_0.mm_xetla(hidden_states, + self.qkv_proj_qweight, + self.q_proj.qtype) + query_states = qkv_states[:, :, :q_out_len] + key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len] + value_states = qkv_states[:, :, q_out_len + k_out_len:] + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, @@ -769,9 +789,26 @@ def mistral_attention_forward_4_36_original( past_key_value.value_cache[self.layer_idx] = value_states else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if should_use_xetla_mm_qkv(self, device): + if not hasattr(self, "qkv_proj_qweight"): + self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj, + self.k_proj, + self.v_proj, + self.q_proj.qtype) + import linear_q4_0 + q_out_len = self.q_proj.out_len + k_out_len = self.k_proj.out_len + v_out_len = self.v_proj.out_len + 
qkv_states = linear_q4_0.mm_xetla(hidden_states, + self.qkv_proj_qweight, + self.q_proj.qtype) + query_states = qkv_states[:, :, :q_out_len] + key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len] + value_states = qkv_states[:, :, q_out_len + k_out_len:] + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py index 09709136..88ab88f3 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen.py @@ -142,6 +142,7 @@ def qwen_attention_forward_original( use_fuse_rope = should_use_fuse_rope(self, hidden_states) qtype_check = decoding_fast_path_qtype_check(self.q_proj) decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1) + decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla if decoding_fast_path: hidden_states = hidden_states.view(1, -1) cache_k, cache_v = layer_past[0], layer_past[1] diff --git a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py index 34f79052..4e33423b 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py @@ -91,6 +91,7 @@ def qwen_attention_forward_vl( use_fuse_rope = should_use_fuse_rope(self, hidden_states) qtype_check = decoding_fast_path_qtype_check(self.q_proj) decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1) + decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla if decoding_fast_path: hidden_states = hidden_states.view(1, -1) cache_k, cache_v = layer_past[0], layer_past[1] diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py index 71f4d817..b80ab209 100644 --- a/python/llm/src/ipex_llm/transformers/models/yuan.py +++ b/python/llm/src/ipex_llm/transformers/models/yuan.py @@ -43,7 +43,8 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs): return decoding_fast_path_qtype_check(proj) and \ - use_fuse_rope and enough_kv_room and bs == 1 + use_fuse_rope and enough_kv_room and bs == 1 \ + and not proj.enable_xetla def should_use_fuse_rope(self, hidden_states, position_ids):
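
For reference, below is a minimal standalone sketch of the fp8_e5m2 layout conversion that ggml_xpu_to_ipex_llm_xetla and ipex_llm_xetla_to_ggml_xpu implement above. It is not part of the patch: it runs on CPU with made-up tensor sizes, and BLOCK_SIZE stands in for whatever get_block_size("fp8_e5m2") returns in ipex_llm.transformers.low_bit_linear. The point it illustrates is that the fp8 path only drops the all-ones scales and transposes the weight, so converting back to the GGML byte layout is a lossless round trip.

# Hypothetical illustration only (not part of the patch): mimics the fp8_e5m2
# branches of ggml_xpu_to_ipex_llm_xetla / ipex_llm_xetla_to_ggml_xpu above,
# using CPU tensors and made-up sizes so the layout change is easy to see.
import torch

n, k = 8, 64                 # made-up (out_features, in_features)
BLOCK_SIZE = 32              # assumed quantization block size, for illustration only

# GGML-style storage in this patch: n*k fp8 weight bytes followed by one fp32
# scale per block (the scales are all ones for fp8_e5m2, which is why the
# XeTLA layout can drop them).
qweight_bytes = torch.randint(0, 256, (n * k,), dtype=torch.uint8)
scales = torch.ones(n * k // BLOCK_SIZE, dtype=torch.float)
ggml_weight = torch.cat([qweight_bytes, scales.view(torch.uint8)])

# GGML -> XeTLA: keep only the weight bytes, reshape to (n, k), then transpose
# so the kernel sees a (k, n) layout.
xetla_weight = ggml_weight[:n * k].view(n, k).transpose(0, 1).contiguous()
assert xetla_weight.shape == (k, n)

# XeTLA -> GGML: transpose back to (n, k) and re-append the all-ones scales so
# the CPU/GGML path gets the byte stream it expects.
restored = torch.cat([
    xetla_weight.transpose(0, 1).contiguous().view(-1),
    torch.ones(n * k // BLOCK_SIZE, dtype=torch.float).view(torch.uint8),
])
assert torch.equal(restored, ggml_weight)   # lossless round trip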