use new quantize kv (#10888)

Yishuo Wang 2024-04-26 14:42:17 +08:00 committed by GitHub
parent 751f6d11d8
commit 46ba962168
6 changed files with 21 additions and 33 deletions


@@ -422,8 +422,7 @@ def baichuan_attention_forward_13b_quantized(
         attn_output = torch.matmul(attn_weights, value_states)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
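
The `.transpose(-1, -2)` removed here only compensated for the old value-cache layout: the previous kernel consumed `value_states` as (batch, heads, head_dim, kv_len), while the new cache keeps (batch, heads, kv_len, head_dim), the same layout the `torch.matmul` fallback above uses. The same one-line change is applied to the other attention forwards in this commit. A minimal sketch of what the fused call computes, where `dequant` is a hypothetical stand-in for the fp8-to-fp16 conversion the native kernel performs internally:

import torch

def attn_value_matmul_reference(attn_weights, value_states_fp8, dequant):
    # attn_weights: (batch, heads, q_len, kv_len), already softmax-ed
    # value_states_fp8: fp8 bytes in the new layout (batch, heads, kv_len, head_dim)
    # dequant: hypothetical stand-in for the kernel's internal fp8 -> fp16 step
    value_states = dequant(value_states_fp8)          # (batch, heads, kv_len, head_dim)
    return torch.matmul(attn_weights, value_states)   # (batch, heads, q_len, head_dim)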


@@ -395,8 +395,7 @@ def baichuan_attention_forward_13b_quantized(
         attn_output = torch.matmul(attn_weights, value_states)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)


@@ -405,8 +405,7 @@ def qwen2moe_attention_forward_quantized(
     if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
             and not hidden_states.requires_grad:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
     else:
         attn_output = torch.matmul(attn_weights, value_states)


@@ -447,8 +447,7 @@ def stablelm_attention_forward_quantized(
         attn_output = torch.matmul(attn_weights, value_states)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     invalidInputError(attn_output.size() == attn_output_size,


@@ -104,18 +104,12 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, n
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
-    if new_layout:
-        v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
-                                      dtype=torch.uint8, device=device)
-        v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
-                                             v_cache_storage.stride(), storage_offset=0)
-        return k_cache, v_cache
-    else:
-        v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
-                                      dtype=torch.uint8, device=device)
-        v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
-                                             v_cache_storage.stride(), storage_offset=0)
-        return k_cache, v_cache.transpose(-1, -2)
+    # ignore `new_layout`, will remove it in next PR
+    v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
+                                  dtype=torch.uint8, device=device)
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
+                                         v_cache_storage.stride(), storage_offset=0)
+    return k_cache, v_cache


 def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
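
With `new_layout` effectively retired, the key and value caches now share one pre-allocation pattern: reserve `max_length` slots up front, hand back a zero-length `as_strided` view, and let later appends grow that view in place. A standalone sketch of the pattern, using fp16 on CPU purely for readability (the real cache stores fp8 bytes as `torch.uint8`, typically on XPU, and the actual append is now done by the native `quantize_key_value` op in the next hunk):

import torch

def init_cache_sketch(batch, heads, max_length, head_dim,
                      dtype=torch.float16, device="cpu"):
    # Reserve the full max_length worth of storage once...
    storage = torch.empty(batch, heads, max_length, head_dim,
                          dtype=dtype, device=device)
    # ...but expose a zero-length view over it; appends re-stride this view
    # wider instead of re-allocating and copying already-cached tokens.
    return storage.as_strided((batch, heads, 0, head_dim),
                              storage.stride(), storage_offset=0)

def append_cache_sketch(cache, new_tokens):
    # new_tokens: (batch, heads, n_new, head_dim); must still fit in max_length
    cur, add = cache.size(2), new_tokens.size(2)
    grown = cache.as_strided((cache.size(0), cache.size(1), cur + add, cache.size(3)),
                             cache.stride(), storage_offset=0)
    grown[:, :, cur:cur + add, :] = new_tokens
    return grown
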
@@ -134,23 +128,22 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
     new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
     new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
-    fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
-    new_k_cache[:, :, cur_length:new_length, :] = fp8_key
-    fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
-    new_v_cache[:, :, cur_length:new_length, :] = fp8_value
+    import linear_q4_0
+    linear_q4_0.quantize_key_value(key, value,
+                                   new_k_cache[:, :, cur_length:new_length, :],
+                                   new_v_cache[:, :, cur_length:new_length, :])
     return new_k_cache, new_v_cache


 def restore_fp8_kv_cache(k_cache, v_cache, dtype):
-    new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
-    new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
-    new_k_cache = new_k_cache.view(torch.half)
-    new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
-    new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
-    new_v_cache = new_v_cache.view(torch.half)
+    key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
+    value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
-    return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
+    import linear_q4_0
+    linear_q4_0.dequantize_key_value(k_cache, v_cache, key_states, value_states)
+    return key_states, value_states


 def rotate_half(x):
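
For context, the Python fallback removed above implemented a byte-truncation fp8 scheme: quantization kept only the high byte of each fp16 value (on a little-endian machine, the sign, exponent, and top two mantissa bits), and dequantization refilled the dropped low byte with 0x80. The reference below reproduces only that removed behaviour; the native `quantize_key_value` / `dequantize_key_value` kernels replace it and may round or scale differently, so treat this as documentation of the old path rather than of the new ops:

import torch

def quantize_fp8_reference(x):
    # Keep the high byte of each fp16 value (index 1 of every little-endian
    # byte pair), i.e. an E5M2-style truncation, as the removed code did.
    return x.half().view(torch.uint8)[..., 1::2]

def dequantize_fp8_reference(x_uint8, dtype=torch.float32):
    # Rebuild fp16 values: restore the stored high byte and fill the dropped
    # low byte with 0x80 (128), the midpoint used by the removed code.
    out = torch.full(x_uint8.shape, 128, dtype=torch.int16, device=x_uint8.device)
    out.view(torch.uint8)[..., 1::2] = x_uint8
    return out.view(torch.half).to(dtype)

# Round trip: the error is bounded by the eight dropped mantissa bits.
x = torch.randn(2, 4, 3, 8, dtype=torch.float16)
x_restored = dequantize_fp8_reference(quantize_fp8_reference(x), dtype=torch.float16)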


@@ -293,8 +293,7 @@ def yuan_attention_forward_quantized(
         attn_output = torch.matmul(attn_weights, value_states)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
     invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                       "`attn_output` should be of size "