add sdp fp8 for qwen llama436 baichuan mistral baichuan2 (#10485)

* add sdp fp8

* fix style

* fix qwen

* fix baichuan 13

* revert baichuan 13b and baichuan2-13b

* fix style

* update
Xin Qiu 2024-03-21 17:23:05 +08:00 committed by GitHub
parent 30f111cd32
commit dba7ddaab3
5 changed files with 137 additions and 160 deletions
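What these five diffs have in common: on XPU, single-token decode now goes through one fused fp8 scaled-dot-product kernel, `linear_q4_0.sdp_fp8(query, key, value, attention_mask)`, instead of the separate `query_key_fp8_matmul` / softmax / `attn_value_fp8_matmul` steps, and the fp8 KV-cache calls gain a `new_layout=True` argument. Below is a minimal sketch of the dispatch pattern shared by the changed functions, assuming the repo's `restore_fp8_kv_cache` helper and the `linear_q4_0` XPU extension are importable; the function name is only for this sketch, and it is an illustration rather than a drop-in replacement for any one of the functions in the diff.

```python
import math
import torch
import torch.nn as nn


def fp8_sdp_sketch(query_states, key_states, value_states, attention_mask, head_dim):
    # key_states / value_states are the fp8-quantized KV-cache tensors.
    if query_states.size(2) != 1 or query_states.device.type != 'xpu':
        # Prefill, or a non-XPU device: dequantize the fp8 cache and run the
        # usual matmul -> mask -> softmax -> matmul path in the query dtype.
        # restore_fp8_kv_cache comes from this repo's KV-cache utilities.
        key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                        query_states.dtype)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = attn_weights / math.sqrt(head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                             dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
    else:
        # Single-token decode on XPU: one fused kernel replaces the separate
        # query_key_fp8_matmul / softmax / attn_value_fp8_matmul calls.
        import linear_q4_0
        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
                                          attention_mask)
        # The fused kernel does not materialize the attention probabilities.
        attn_weights = None
    return attn_output, attn_weights
```

Because the fused path never materializes `attn_weights`, callers that ask for attention probabilities get `None` back on that branch, as in the updated functions below.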

@@ -126,53 +126,49 @@ def baichuan_attention_forward_7b_quantized(
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
device=device
device=device, new_layout=True
)
key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
past_key_value = (key_states, value_states)
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
key_states, value_states, new_layout=True)
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
f", but is {attn_weights.size()}"
)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
f", but is {attn_weights.size()}"
)
if attention_mask is not None:
invalidInputError(
attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
f"but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights,
torch.tensor(torch.finfo(attn_weights.dtype).min))
if attention_mask is not None:
invalidInputError(
attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
f"but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights,
torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None
invalidInputError(
attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),

@@ -143,12 +143,12 @@ def baichuan_attention_forward_7b_quantized(
kv_seq_len = key_states.shape[-2]
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
device=device
device=device, new_layout=True
)
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
key_states, value_states, new_layout=True)
past_key_value = (key_states, value_states) if use_cache else None
@@ -161,20 +161,17 @@ def baichuan_attention_forward_7b_quantized(
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
else:
import linear_q4_0
attn_output = linear_q4_0.query_key_fp8_matmul(query_states * scaling_factor, key_states)
if attention_mask is not None:
attn_output += attention_mask
attn_output = torch.softmax(attn_output, -1)
attn_output = attn_output.to(hidden_states.dtype)
if query_states.size(2) != 1 or device.type != 'xpu':
if attention_mask is not None:
attn_output += attention_mask
attn_output = torch.softmax(attn_output, -1)
attn_output = attn_output.to(hidden_states.dtype)
attn_output = torch.matmul(attn_output, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_output,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
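The other recurring change is the `new_layout=True` flag on the fp8 KV-cache helpers. The cache setup that the baichuan hunks above (and the mistral hunks below) share looks roughly like the sketch that follows; the helper names are the ones used in the diff, and the surrounding variables (`bsz`, `device`, `use_cache`, `self`, ...) are assumed to be in scope as they are in those functions.

```python
if past_key_value is None:
    # First token: allocate an fp8 cache in the new layout.
    kv_seq_len = key_states.shape[-2]
    k_cache, v_cache = init_fp8_kv_cache(
        bsz, self.num_heads, kv_seq_len, self.head_dim,
        device=device, new_layout=True,
    )
else:
    # Later tokens: reuse the cached fp8 tensors.
    k_cache, v_cache = past_key_value
# Quantize and append the new key/value states, keeping the same layout.
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                               key_states, value_states,
                                               new_layout=True)
past_key_value = (key_states, value_states) if use_cache else None
```

(The two changes always appear together in these files, which suggests the new layout is the cache format the fused `sdp_fp8` kernel expects; the diff does not state this explicitly.)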

@@ -1001,11 +1001,13 @@ def llama_attention_forward_4_36_quantized(
if use_cache:
cache_kwargs = None
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
else:
cache_kwargs = None # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
kv_seq_len = key_states.shape[-2]
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -1015,35 +1017,32 @@ def llama_attention_forward_4_36_quantized(
value_states = repeat_kv(value_states, self.num_key_value_groups)\
.to(device, dtype=query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
f" but is {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
f"Attention weights should be of size"
f" {(bsz, self.num_heads, q_len, kv_seq_len)},"
f" but is {attn_weights.size()}"
)
attn_weights = attn_weights + attention_mask
# at inference time, for memory considerations, may not need to upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
# at inference time, for memory considerations, may not need to upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(

@@ -295,7 +295,7 @@ def mistral_attention_forward_quantized(
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
device=query_states.device
device=query_states.device, new_layout=True
)
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
@@ -303,7 +303,7 @@ def mistral_attention_forward_quantized(
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
key_states, value_states, new_layout=True)
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)
@@ -311,38 +311,35 @@ def mistral_attention_forward_quantized(
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:
@@ -658,48 +655,47 @@ def mistral_attention_forward_4_36_quantized(
if use_cache:
cache_kwargs = None
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
else:
cache_kwargs = None # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
kv_seq_len = key_states.shape[-2]
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:

@@ -439,34 +439,26 @@ def qwen_attention_forward_quantized(
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
k_cache, v_cache = init_fp8_kv_cache(
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
device=query.device,
device=query.device, new_layout=True
)
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
else:
if decoding_fast_path:
k_cache, v_cache = layer_past[0], layer_past[1]
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)
else:
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
k_cache, v_cache = layer_past[0], layer_past[1]
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=True)
attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
@@ -489,44 +481,41 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
key, value = restore_fp8_kv_cache(key, value, query.dtype)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query, key)
if self.scale_attn_weights:
if self.use_cache_quantization:
size_temp = value[0].size(-1)
if self.scale_attn_weights:
if self.use_cache_quantization:
size_temp = value[0].size(-1)
else:
size_temp = value.size(-1)
attn_weights = attn_weights / (size_temp ** 0.5)
mask_value = torch.finfo(attn_weights.dtype).min
if causal_mask is not None:
attn_weights = torch.where(
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
if self.softmax_in_fp32:
attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
else:
size_temp = value.size(-1)
attn_weights = attn_weights / (size_temp ** 0.5)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
mask_value = torch.finfo(attn_weights.dtype).min
if causal_mask is not None:
attn_weights = torch.where(
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
)
attn_weights = attn_weights.type(query.dtype)
attn_weights = self.attn_dropout(attn_weights)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
if head_mask is not None:
attn_weights = attn_weights * head_mask
if self.softmax_in_fp32:
attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
else:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(query.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
if query.size(2) != 1 or query.device.type != 'xpu':
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
attn_output = torch.matmul(attn_weights, value)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query, key, value,
attention_mask)
attn_weights = None
attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
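For the transformers-4.36 code paths (`llama_attention_forward_4_36_quantized` and `mistral_attention_forward_4_36_quantized` above), the KV cache is managed through the transformers `Cache.update()` interface, here a repo-specific fp8 cache, since the stock `Cache.update()` takes no `new_layout` argument, so the same flag is threaded through that call instead of the init/append helpers. The update step as it appears in those hunks:

```python
# transformers-4.36 style: the past_key_value Cache object owns the fp8
# tensors, so the new layout is requested through update() rather than
# the init_fp8_kv_cache / append_fp8_kv_cache helpers.
cache_kwargs = None  # specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
                                                 self.layer_idx, cache_kwargs,
                                                 new_layout=True)
kv_seq_len = key_states.shape[-2]
```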