use new quantize kv (#10888)
parent 751f6d11d8
commit 46ba962168

6 changed files with 21 additions and 33 deletions

@@ -422,8 +422,7 @@ def baichuan_attention_forward_13b_quantized(
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
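Each of the attention forwards touched by this commit gets the same one-line change: `linear_q4_0.attn_value_fp8_matmul` now receives `value_states` in its natural (bsz, num_heads, kv_len, head_dim) layout instead of a view pre-transposed to (bsz, num_heads, head_dim, kv_len), because the value cache is now allocated in the same layout as the key cache (see the `init_fp8_kv_cache` hunk further down). The sketch below is only a plain-PyTorch restatement of the product the fused kernel presumably computes; the shapes are illustrative and not taken from the commit.

```python
import torch

# Illustrative decode-step shapes (not from the commit):
# 1 sequence, 2 heads, a single query token attending over 16 cached positions.
bsz, num_heads, q_len, kv_len, head_dim = 1, 2, 1, 16, 64

attn_weights = torch.rand(bsz, num_heads, q_len, kv_len)
# Stand-in for the dequantized value cache in its new, key-cache-like layout;
# the real kernel works on the uint8 fp8 cache directly.
value_states = torch.rand(bsz, num_heads, kv_len, head_dim)

# What attn_value_fp8_matmul(attn_weights, value_states) amounts to:
# a batched matmul over the last two dimensions.
attn_output = torch.matmul(attn_weights, value_states)
print(attn_output.shape)  # torch.Size([1, 2, 1, 64])

# Before this commit the value cache was stored as (..., head_dim, kv_len),
# so every call site had to pass value_states.transpose(-1, -2) first.
```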
				
			
@@ -395,8 +395,7 @@ def baichuan_attention_forward_13b_quantized(
         attn_output = torch.matmul(attn_weights, value_states)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -405,8 +405,7 @@ def qwen2moe_attention_forward_quantized(
     if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
             and not hidden_states.requires_grad:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
     else:
         attn_output = torch.matmul(attn_weights, value_states)
 
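The unchanged context lines above also document when the fused fp8 kernel is used at all: only for single-token decode (`q_len == 1`) on an XPU device, outside training and with no gradient required; every other case falls back to a plain `torch.matmul`. A hedged restatement of that gate (the helper name `use_fused_fp8_path` is hypothetical, not a function in the repository):

```python
import torch

def use_fused_fp8_path(query_states: torch.Tensor, hidden_states: torch.Tensor,
                       q_len: int, training: bool) -> bool:
    """Hypothetical helper restating the condition visible in the diff context."""
    return (q_len == 1
            and query_states.device.type == 'xpu'
            and not training
            and not hidden_states.requires_grad)

# Usage sketch: a decode step (q_len == 1) on CPU still takes the
# torch.matmul fallback, because the device check fails.
query_states = torch.rand(1, 2, 1, 64)    # (bsz, num_heads, q_len, head_dim)
hidden_states = torch.rand(1, 1, 4096)    # (bsz, q_len, hidden_size)
print(use_fused_fp8_path(query_states, hidden_states, q_len=1, training=False))  # False
```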
				
			
@@ -447,8 +447,7 @@ def stablelm_attention_forward_quantized(
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     invalidInputError(attn_output.size() == attn_output_size,

@@ -104,18 +104,12 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, n
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
 
-    if new_layout:
-        v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
-                                      dtype=torch.uint8, device=device)
-        v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
-                                             v_cache_storage.stride(), storage_offset=0)
-        return k_cache, v_cache
-    else:
-        v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
-                                      dtype=torch.uint8, device=device)
-        v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
-                                             v_cache_storage.stride(), storage_offset=0)
-        return k_cache, v_cache.transpose(-1, -2)
+    # ignore `new_layout`, will remove it in next PR
+    v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
+                                  dtype=torch.uint8, device=device)
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
+                                         v_cache_storage.stride(), storage_offset=0)
+    return k_cache, v_cache
 
 
 def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
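After this hunk the value cache is always allocated as `(batch_size, num_heads, max_length, head_dim)`, the same layout as the key cache, so the branch on `new_layout` disappears and callers no longer receive a transposed view. Both caches use the same trick: a full-size uint8 buffer is allocated once, and a zero-length view along the sequence dimension is handed out; appending only re-strides the view over the same storage. A small layout-only sketch (shapes are illustrative and the fp8 encoding is left out):

```python
import torch

batch_size, num_heads, max_length, head_dim = 1, 2, 8, 4

# Allocate the full uint8 buffer once; hand out a zero-length view along the
# sequence dimension. Growing the cache later only changes the view's size,
# it never reallocates or copies the underlying storage.
v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                              dtype=torch.uint8)
v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                     v_cache_storage.stride(), storage_offset=0)
print(v_cache.shape)   # torch.Size([1, 2, 0, 4])

# How append_fp8_kv_cache grows it to hold, say, 3 cached tokens:
new_v_cache = v_cache_storage.as_strided((batch_size, num_heads, 3, head_dim),
                                         v_cache_storage.stride(), storage_offset=0)
print(new_v_cache.shape)                                      # torch.Size([1, 2, 3, 4])
print(new_v_cache.data_ptr() == v_cache_storage.data_ptr())   # True: same memory
```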
				
			
@@ -134,23 +128,22 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
         new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
         new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
 
-    fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
-    new_k_cache[:, :, cur_length:new_length, :] = fp8_key
-    fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
-    new_v_cache[:, :, cur_length:new_length, :] = fp8_value
+    import linear_q4_0
+    linear_q4_0.quantize_key_value(key, value,
+                                   new_k_cache[:, :, cur_length:new_length, :],
+                                   new_v_cache[:, :, cur_length:new_length, :])
 
     return new_k_cache, new_v_cache
 
 
 def restore_fp8_kv_cache(k_cache, v_cache, dtype):
-    new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
-    new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
-    new_k_cache = new_k_cache.view(torch.half)
-    new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
-    new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
-    new_v_cache = new_v_cache.view(torch.half)
+    key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
+    value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
 
-    return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
+    import linear_q4_0
+    linear_q4_0.dequantize_key_value(k_cache, v_cache, key_states, value_states)
+
+    return key_states, value_states
 
 
 def rotate_half(x):
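This is the core of the commit: `append_fp8_kv_cache` and `restore_fp8_kv_cache` stop packing bytes by hand in Python and instead call the new native `linear_q4_0.quantize_key_value` / `dequantize_key_value` kernels. For reference, the removed lines implemented a simple truncate-to-high-byte scheme: keep only the upper byte of each fp16 value (sign, exponent and the top two mantissa bits, i.e. an e5m2-style fp8) and, on restore, refill the dropped low byte with 0x80, the midpoint of the truncated range. A standalone sketch of that old scheme (function names are mine; a little-endian fp16 layout is assumed, as the removed code did):

```python
import torch

def pack_fp16_high_byte(x: torch.Tensor) -> torch.Tensor:
    # Keep only the high byte of each fp16 value: sign, 5 exponent bits and
    # the top 2 mantissa bits survive (an e5m2-style fp8).
    return x.half().view(torch.uint8)[..., 1::2].clone()

def unpack_fp16_high_byte(packed: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # Refill the dropped low byte with 0x80 (torch.full(..., 128, dtype=torch.int16)
    # yields bytes 0x80, 0x00 per element on little-endian), then reinterpret as fp16.
    out = torch.full(packed.shape, 128, dtype=torch.int16, device=packed.device)
    out.view(torch.uint8)[..., 1::2] = packed
    return out.view(torch.half).to(dtype)

x = torch.tensor([[0.1234, -3.5, 42.0]])
print(unpack_fp16_high_byte(pack_fp16_high_byte(x)))
# Roughly recovers x, but with only ~2 mantissa bits of precision, e.g. 42.0 -> 44.0.
```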
				
			
@@ -293,8 +293,7 @@ def yuan_attention_forward_quantized(
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states)
 
         invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                           "`attn_output` should be of size "