add sdp fp8 for qwen llama436 baichuan mistral baichuan2 (#10485)

* add sdp fp8

* fix style

* fix qwen

* fix baichuan 13

* revert baichuan 13b and baichuan2-13b

* fix style

* update
Xin Qiu 2024-03-21 17:23:05 +08:00 committed by GitHub
parent 30f111cd32
commit dba7ddaab3
5 changed files with 137 additions and 160 deletions
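
All five files get the same two changes: the FP8 KV-cache helpers are now called with new_layout=True, and the single-token decode path on XPU, which previously chained two FP8 kernels (linear_q4_0.query_key_fp8_matmul, then softmax, then linear_q4_0.attn_value_fp8_matmul), now calls one fused linear_q4_0.sdp_fp8 kernel. A minimal sketch of the resulting dispatch, following the diffs below; restore_fp8_kv_cache is the library's own cache helper (passed in as an argument here only to keep the sketch self-contained) and linear_q4_0 is its XPU extension module, so this is illustrative rather than a drop-in implementation:

    import math
    import torch

    def quantized_attention(query_states, key_states, value_states, attention_mask,
                            head_dim, restore_fp8_kv_cache):
        # query_states: [bsz, num_heads, q_len, head_dim]; key/value come from the FP8 KV cache.
        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
            # Prefill, or not on XPU: dequantize the cache and run the plain reference path.
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                            query_states.dtype)
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
            attn_weights = attn_weights / math.sqrt(head_dim)
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask
            attn_weights = torch.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)
        else:
            # Single-token decode on XPU: one fused FP8 SDP kernel replaces the former
            # query_key_fp8_matmul + softmax + attn_value_fp8_matmul sequence.
            import linear_q4_0
            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
                                              attention_mask)
            attn_weights = None
        return attn_output, attn_weights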

View file

@@ -126,24 +126,20 @@ def baichuan_attention_forward_7b_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device
+                device=device, new_layout=True
             )
             key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
             past_key_value = (key_states, value_states)
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states)
+                                                       key_states, value_states, new_layout=True)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
     if query_states.size(2) != 1 or query_states.device.type != 'xpu':
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                         query_states.dtype)
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-    else:
-        import linear_q4_0
-        attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
         attn_weights = attn_weights / math.sqrt(self.head_dim)
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -167,12 +163,12 @@ def baichuan_attention_forward_7b_quantized(
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                              dtype=torch.float32).to(query_states.dtype)
-    if query_states.size(2) != 1 or query_states.device.type != 'xpu':
         attn_output = torch.matmul(attn_weights, value_states)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                          attention_mask)
+        attn_weights = None
     invalidInputError(
         attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),

View file

@@ -143,12 +143,12 @@ def baichuan_attention_forward_7b_quantized(
             kv_seq_len = key_states.shape[-2]
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device
+                device=device, new_layout=True
             )
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states)
+                                                       key_states, value_states, new_layout=True)
     past_key_value = (key_states, value_states) if use_cache else None
@@ -161,20 +161,17 @@ def baichuan_attention_forward_7b_quantized(
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                         query_states.dtype)
         attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
-    else:
-        import linear_q4_0
-        attn_output = linear_q4_0.query_key_fp8_matmul(query_states * scaling_factor, key_states)
         if attention_mask is not None:
             attn_output += attention_mask
         attn_output = torch.softmax(attn_output, -1)
         attn_output = attn_output.to(hidden_states.dtype)
-    if query_states.size(2) != 1 or device.type != 'xpu':
         attn_output = torch.matmul(attn_output, value_states)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_output,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                          attention_mask)
+        attn_weights = None
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

View file

@@ -1001,11 +1001,13 @@ def llama_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
+                                                             self.layer_idx, cache_kwargs,
+                                                             new_layout=True)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs)
+                                                         self.layer_idx, cache_kwargs,
+                                                         new_layout=True)
         kv_seq_len = key_states.shape[-2]
         if query_states.size(2) != 1 or query_states.device.type != 'xpu':
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -1015,14 +1017,12 @@ def llama_attention_forward_4_36_quantized(
             value_states = repeat_kv(value_states, self.num_key_value_groups)\
                 .to(device, dtype=query_states.dtype)
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-        else:
-            import linear_q4_0
-            attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
             attn_weights = attn_weights / math.sqrt(self.head_dim)
             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 invalidInputError(
                     False,
-                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
+                    f"Attention weights should be of size"
+                    f" {(bsz, self.num_heads, q_len, kv_seq_len)},"
                     f" but is {attn_weights.size()}"
                 )
@@ -1037,13 +1037,12 @@ def llama_attention_forward_4_36_quantized(
             # at inference time, for memory considerations, may not need to upcast attention to fp32
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                              attention_mask)
+            attn_weights = None
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             invalidInputError(

View file

@@ -295,7 +295,7 @@ def mistral_attention_forward_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=query_states.device
+                device=query_states.device, new_layout=True
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)
@@ -303,7 +303,7 @@ def mistral_attention_forward_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states)
+                                                       key_states, value_states, new_layout=True)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
@@ -311,9 +311,6 @@ def mistral_attention_forward_quantized(
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                         query_states.dtype)
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-    else:
-        import linear_q4_0
-        attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
         attn_weights = attn_weights / math.sqrt(self.head_dim)
@@ -337,12 +334,12 @@ def mistral_attention_forward_quantized(
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                              dtype=torch.float32).to(query_states.dtype)
-    if query_states.size(2) != 1 or query_states.device.type != 'xpu':
         attn_output = torch.matmul(attn_weights, value_states)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
+        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                          attention_mask)
+        attn_weights = None
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
@@ -658,19 +655,18 @@ def mistral_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
+                                                             self.layer_idx, cache_kwargs,
+                                                             new_layout=True)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs)
+                                                         self.layer_idx, cache_kwargs,
+                                                         new_layout=True)
         kv_seq_len = key_states.shape[-2]
         if query_states.size(2) != 1 or query_states.device.type != 'xpu':
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-        else:
-            import linear_q4_0
-            attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
             attn_weights = attn_weights / math.sqrt(self.head_dim)
@@ -694,12 +690,12 @@ def mistral_attention_forward_4_36_quantized(
             attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                  dtype=torch.float32).to(query_states.dtype)
-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
+            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                              attention_mask)
+            attn_weights = None
         attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
         if attn_output.size() != attn_output_size:
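
The new_layout=True flag appears at every cache-creation and decode-append site in these hunks (init_fp8_kv_cache, the decode-path append_fp8_kv_cache, and past_key_value.update for the transformers-4.36 code paths), presumably so that the FP8 cache is written in the layout the fused sdp_fp8 kernel reads; the flag's exact semantics are not shown in this diff. A hedged sketch of the cache-handling pattern as it appears above, with the helper names taken from the diff:

    def update_quantized_kv_cache(past_key_value, key_states, value_states,
                                  bsz, num_heads, kv_seq_len, head_dim, device):
        # Mirrors the hunks above: allocate the FP8 cache in the new layout for the
        # first token, and append with the same layout flag on later decode steps.
        # init_fp8_kv_cache / append_fp8_kv_cache are the library helpers named in the diff.
        if past_key_value is None:
            k_cache, v_cache = init_fp8_kv_cache(
                bsz, num_heads, kv_seq_len, head_dim,
                device=device, new_layout=True
            )
            key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                           key_states, value_states)
        else:
            k_cache, v_cache = past_key_value
            key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                           key_states, value_states,
                                                           new_layout=True)
        # The caller keeps (key_states, value_states) as past_key_value for the next step.
        return key_states, value_states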

View file

@@ -439,30 +439,22 @@ def qwen_attention_forward_quantized(
             max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fp8_kv_cache(
                 query.size(0), self.num_heads, kv_seq_len, self.head_dim,
-                device=query.device,
+                device=query.device, new_layout=True
             )
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
     else:
         if decoding_fast_path:
             k_cache, v_cache = layer_past[0], layer_past[1]
-            k_cache = k_cache.transpose(1, 2)
-            v_cache = v_cache.transpose(1, 2)
             # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
-            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
-            attn_output, attn_weight = core_attn(
-                self, query, key, value, causal_mask, attention_mask, head_mask
-            )
         else:
             query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
             k_cache, v_cache = layer_past[0], layer_past[1]
             k_cache = k_cache.transpose(1, 2)
             v_cache = v_cache.transpose(1, 2)
             # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
-            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
+            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=True)
         attn_output, attn_weight = core_attn(
             self, query, key, value, causal_mask, attention_mask, head_mask
@@ -489,9 +481,6 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he
         # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
         key, value = restore_fp8_kv_cache(key, value, query.dtype)
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
-    else:
-        import linear_q4_0
-        attn_weights = linear_q4_0.query_key_fp8_matmul(query, key)
     if self.scale_attn_weights:
         if self.use_cache_quantization:
@@ -520,13 +509,13 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
-    if query.size(2) != 1 or query.device.type != 'xpu':
         # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
         attn_output = torch.matmul(attn_weights, value)
     else:
         import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value.transpose(-1, -2))
+        attn_output = linear_q4_0.sdp_fp8(query, key, value,
+                                          attention_mask)
+        attn_weights = None
     attn_output = attn_output.transpose(1, 2)
     return attn_output, attn_weights