use new quantize kv (#10888)
This commit is contained in:
parent
751f6d11d8
commit
46ba962168
6 changed files with 21 additions and 33 deletions
|
|
@ -422,8 +422,7 @@ def baichuan_attention_forward_13b_quantized(
|
|||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||
value_states.transpose(-1, -2))
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
|
|
|||
|
|
@ -395,8 +395,7 @@ def baichuan_attention_forward_13b_quantized(
|
|||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||
value_states.transpose(-1, -2))
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
|
|
|||
|
|
@ -405,8 +405,7 @@ def qwen2moe_attention_forward_quantized(
|
|||
if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
|
||||
and not hidden_states.requires_grad:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||
value_states.transpose(-1, -2))
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
|
||||
else:
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
|
|
|
|||
|
|
@ -447,8 +447,7 @@ def stablelm_attention_forward_quantized(
|
|||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||
value_states.transpose(-1, -2))
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
|
||||
|
||||
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
|
||||
invalidInputError(attn_output.size() == attn_output_size,
|
||||
|
|
|
|||
|
|
@ -104,18 +104,12 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, n
|
|||
k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
|
||||
k_cache_storage.stride(), storage_offset=0)
|
||||
|
||||
if new_layout:
|
||||
v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
|
||||
dtype=torch.uint8, device=device)
|
||||
v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
|
||||
v_cache_storage.stride(), storage_offset=0)
|
||||
return k_cache, v_cache
|
||||
else:
|
||||
v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
|
||||
dtype=torch.uint8, device=device)
|
||||
v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
|
||||
v_cache_storage.stride(), storage_offset=0)
|
||||
return k_cache, v_cache.transpose(-1, -2)
|
||||
# ignore `new_layout`, will remove it in next PR
|
||||
v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
|
||||
dtype=torch.uint8, device=device)
|
||||
v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
|
||||
v_cache_storage.stride(), storage_offset=0)
|
||||
return k_cache, v_cache
|
||||
|
||||
|
||||
def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
|
||||
|
|
@ -134,23 +128,22 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
|
|||
new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
|
||||
new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
|
||||
|
||||
fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
|
||||
new_k_cache[:, :, cur_length:new_length, :] = fp8_key
|
||||
fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
|
||||
new_v_cache[:, :, cur_length:new_length, :] = fp8_value
|
||||
import linear_q4_0
|
||||
linear_q4_0.quantize_key_value(key, value,
|
||||
new_k_cache[:, :, cur_length:new_length, :],
|
||||
new_v_cache[:, :, cur_length:new_length, :])
|
||||
|
||||
return new_k_cache, new_v_cache
|
||||
|
||||
|
||||
def restore_fp8_kv_cache(k_cache, v_cache, dtype):
|
||||
new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
|
||||
new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
|
||||
new_k_cache = new_k_cache.view(torch.half)
|
||||
new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
|
||||
new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
|
||||
new_v_cache = new_v_cache.view(torch.half)
|
||||
key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
|
||||
value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
|
||||
|
||||
return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
|
||||
import linear_q4_0
|
||||
linear_q4_0.dequantize_key_value(k_cache, v_cache, key_states, value_states)
|
||||
|
||||
return key_states, value_states
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
|
|
|
|||
|
|
@ -293,8 +293,7 @@ def yuan_attention_forward_quantized(
|
|||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||
value_states.transpose(-1, -2))
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
|
||||
|
||||
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
||||
"`attn_output` should be of size "
|
||||
|
|
|
|||
Loading…
Reference in a new issue