diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py
index 134ca1a9..b25964fe 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -422,8 +422,7 @@ def baichuan_attention_forward_13b_quantized(
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
index 51c16608..a5848e6e 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
@@ -395,8 +395,7 @@ def baichuan_attention_forward_13b_quantized(
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
index 466f2706..c7b242c7 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -405,8 +405,7 @@ def qwen2moe_attention_forward_quantized(
         if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
                 and not hidden_states.requires_grad:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
         else:
             attn_output = torch.matmul(attn_weights, value_states)
 
diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index 4c8a6904..b91b40a9 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -447,8 +447,7 @@ def stablelm_attention_forward_quantized(
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     invalidInputError(attn_output.size() == attn_output_size,
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 1494a657..3691886b 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -104,18 +104,12 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
 
-    if new_layout:
-        v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
-                                      dtype=torch.uint8, device=device)
-        v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
-                                             v_cache_storage.stride(), storage_offset=0)
-        return k_cache, v_cache
-    else:
-        v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
-                                      dtype=torch.uint8, device=device)
-        v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
-                                             v_cache_storage.stride(), storage_offset=0)
-        return k_cache, v_cache.transpose(-1, -2)
+    # ignore `new_layout`, will remove it in next PR
+    v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
+                                  dtype=torch.uint8, device=device)
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
+                                         v_cache_storage.stride(), storage_offset=0)
+    return k_cache, v_cache
 
 
 def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
@@ -134,23 +128,22 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
     new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
     new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
 
-    fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
-    new_k_cache[:, :, cur_length:new_length, :] = fp8_key
-    fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
-    new_v_cache[:, :, cur_length:new_length, :] = fp8_value
+    import linear_q4_0
+    linear_q4_0.quantize_key_value(key, value,
+                                   new_k_cache[:, :, cur_length:new_length, :],
+                                   new_v_cache[:, :, cur_length:new_length, :])
 
     return new_k_cache, new_v_cache
 
 
 def restore_fp8_kv_cache(k_cache, v_cache, dtype):
-    new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
-    new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
-    new_k_cache = new_k_cache.view(torch.half)
-    new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
-    new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
-    new_v_cache = new_v_cache.view(torch.half)
+    key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
+    value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
 
-    return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
+    import linear_q4_0
+    linear_q4_0.dequantize_key_value(k_cache, v_cache, key_states, value_states)
+
+    return key_states, value_states
 
 
 def rotate_half(x):
diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py
index ad753710..6e0674ef 100644
--- a/python/llm/src/ipex_llm/transformers/models/yuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -293,8 +293,7 @@ def yuan_attention_forward_quantized(
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
 
     invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                       "`attn_output` should be of size "
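Note (reviewer sketch, not part of the diff): the lines removed from `append_fp8_kv_cache` and `restore_fp8_kv_cache` implemented the FP8 KV cache in pure PyTorch by keeping only the high byte of each fp16 value (an e5m2-style encoding); this PR swaps that path for the native `linear_q4_0.quantize_key_value` / `dequantize_key_value` kernels and keeps the value cache in the same (batch, heads, seq_len, head_dim) layout as the key cache, which is why the `value_states.transpose(-1, -2)` calls before `attn_value_fp8_matmul` disappear. The standalone Python sketch below round-trips a tensor through the removed byte-truncation scheme to show roughly what the new kernels stand in for; the helper names are invented for illustration and it assumes little-endian hardware.

import torch


def quantize_fp8_sketch(x: torch.Tensor) -> torch.Tensor:
    # Keep only byte 1 of every fp16 element (sign, exponent, top 2 mantissa
    # bits on little-endian machines), mirroring the removed
    # `key.half().view(torch.uint8)[:, :, :, 1::2]` line.
    return x.half().view(torch.uint8)[..., 1::2]


def restore_fp8_sketch(cache: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    # 128 == 0x0080 pre-fills the dropped low mantissa byte with its midpoint;
    # the saved high byte is then written back and the 16-bit pair is
    # reinterpreted as fp16, mirroring the removed torch.full(..., 128) path.
    full = torch.full(cache.shape, 128, dtype=torch.int16, device=cache.device)
    full.view(torch.uint8)[..., 1::2] = cache
    return full.view(torch.half).to(dtype)


v = torch.randn(1, 2, 4, 8)        # (batch, heads, seq_len, head_dim)
v_fp8 = quantize_fp8_sketch(v)     # uint8, same shape as `v`
v_back = restore_fp8_sketch(v_fp8)
print((v - v_back).abs().max())    # error from keeping only 2 mantissa bits per value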