add sdp fp8 for qwen llama436 baichuan mistral baichuan2 (#10485)
* add sdp fp8
* fix style
* fix qwen
* fix baichuan 13
* revert baichuan 13b and baichuan2-13b
* fix style
* update
parent 30f111cd32
commit dba7ddaab3
5 changed files with 137 additions and 160 deletions
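All five model paths in this diff follow the same dispatch pattern: single-token decode on an XPU device is routed to the fused linear_q4_0.sdp_fp8 kernel, while prefill (or any non-XPU device) dequantizes the fp8 KV cache with restore_fp8_kv_cache and runs plain matmul attention. The sketch below is a minimal illustration of that pattern, assembled only from the helper names visible in the hunks that follow; the wrapper name fp8_sdp_dispatch is hypothetical, and the import for restore_fp8_kv_cache is assumed to come from this repo's KV-cache utilities.

import math
import torch

def fp8_sdp_dispatch(query_states, key_states, value_states, attention_mask):
    # Sketch only: restore_fp8_kv_cache and linear_q4_0.sdp_fp8 are the repo helpers
    # used in the hunks below; their exact signatures are assumed here.
    if query_states.size(2) != 1 or query_states.device.type != 'xpu':
        # Prefill / CPU path: dequantize the fp8 KV cache and run standard attention.
        key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                        query_states.dtype)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = attn_weights / math.sqrt(query_states.size(-1))
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                                   dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
    else:
        # Single-token decode on XPU: one fused fp8 scaled-dot-product call
        # replaces the old query_key_fp8_matmul + attn_value_fp8_matmul pair.
        import linear_q4_0
        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
                                          attention_mask)
        attn_weights = None
    return attn_output, attn_weights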
@@ -126,53 +126,49 @@ def baichuan_attention_forward_7b_quantized(
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
device=device
device=device, new_layout=True
)
key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
past_key_value = (key_states, value_states)
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
key_states, value_states, new_layout=True)
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)

attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
f", but is {attn_weights.size()}"
)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
f", but is {attn_weights.size()}"
)
if attention_mask is not None:
invalidInputError(
attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
f"but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights,
torch.tensor(torch.finfo(attn_weights.dtype).min))

if attention_mask is not None:
invalidInputError(
attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
f"but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights,
torch.tensor(torch.finfo(attn_weights.dtype).min))

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None

invalidInputError(
attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
@@ -143,12 +143,12 @@ def baichuan_attention_forward_7b_quantized(
kv_seq_len = key_states.shape[-2]
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
device=device
device=device, new_layout=True
)
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
key_states, value_states, new_layout=True)

past_key_value = (key_states, value_states) if use_cache else None
@@ -161,20 +161,17 @@ def baichuan_attention_forward_7b_quantized(
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
else:
import linear_q4_0
attn_output = linear_q4_0.query_key_fp8_matmul(query_states * scaling_factor, key_states)

if attention_mask is not None:
attn_output += attention_mask
attn_output = torch.softmax(attn_output, -1)
attn_output = attn_output.to(hidden_states.dtype)
if query_states.size(2) != 1 or device.type != 'xpu':
if attention_mask is not None:
attn_output += attention_mask
attn_output = torch.softmax(attn_output, -1)
attn_output = attn_output.to(hidden_states.dtype)
attn_output = torch.matmul(attn_output, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_output,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None
attn_output = attn_output.transpose(1, 2)

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
@@ -1001,11 +1001,13 @@ def llama_attention_forward_4_36_quantized(
if use_cache:
cache_kwargs = None
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
else:
cache_kwargs = None # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
kv_seq_len = key_states.shape[-2]
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -1015,35 +1017,32 @@ def llama_attention_forward_4_36_quantized(
value_states = repeat_kv(value_states, self.num_key_value_groups)\
.to(device, dtype=query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
f" but is {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
f"Attention weights should be of size"
f" {(bsz, self.num_heads, q_len, kv_seq_len)},"
f" but is {attn_weights.size()}"
)
attn_weights = attn_weights + attention_mask

# at inference time, for memory considerations, may not need to upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

if query_states.size(2) != 1 or query_states.device.type != 'xpu':
# at inference time, for memory considerations, may not need to upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
@@ -295,7 +295,7 @@ def mistral_attention_forward_quantized(
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
device=query_states.device
device=query_states.device, new_layout=True
)
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
@@ -303,7 +303,7 @@ def mistral_attention_forward_quantized(
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
key_states, value_states, new_layout=True)
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)
@@ -311,38 +311,35 @@ def mistral_attention_forward_quantized(
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)

attn_weights = attn_weights / math.sqrt(self.head_dim)
attn_weights = attn_weights / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)

if query_states.size(2) != 1 or query_states.device.type != 'xpu':
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None

attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:
@@ -658,48 +655,47 @@ def mistral_attention_forward_4_36_quantized(
if use_cache:
cache_kwargs = None
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
else:
cache_kwargs = None # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
kv_seq_len = key_states.shape[-2]
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)

attn_weights = attn_weights / math.sqrt(self.head_dim)
attn_weights = attn_weights / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)

if query_states.size(2) != 1 or query_states.device.type != 'xpu':
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None

attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:
@@ -439,34 +439,26 @@ def qwen_attention_forward_quantized(
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
k_cache, v_cache = init_fp8_kv_cache(
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
device=query.device,
device=query.device, new_layout=True
)
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
else:
if decoding_fast_path:
k_cache, v_cache = layer_past[0], layer_past[1]
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]

key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)

attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)

else:
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
k_cache, v_cache = layer_past[0], layer_past[1]
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]

key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]

attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=True)

attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)

context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
@@ -489,44 +481,41 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
key, value = restore_fp8_kv_cache(key, value, query.dtype)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query, key)

if self.scale_attn_weights:
if self.use_cache_quantization:
size_temp = value[0].size(-1)
if self.scale_attn_weights:
if self.use_cache_quantization:
size_temp = value[0].size(-1)
else:
size_temp = value.size(-1)
attn_weights = attn_weights / (size_temp ** 0.5)

mask_value = torch.finfo(attn_weights.dtype).min
if causal_mask is not None:
attn_weights = torch.where(
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
)

if attention_mask is not None:
attn_weights = attn_weights + attention_mask

if self.softmax_in_fp32:
attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
else:
size_temp = value.size(-1)
attn_weights = attn_weights / (size_temp ** 0.5)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

mask_value = torch.finfo(attn_weights.dtype).min
if causal_mask is not None:
attn_weights = torch.where(
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
)
attn_weights = attn_weights.type(query.dtype)
attn_weights = self.attn_dropout(attn_weights)

if attention_mask is not None:
attn_weights = attn_weights + attention_mask
if head_mask is not None:
attn_weights = attn_weights * head_mask

if self.softmax_in_fp32:
attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
else:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

attn_weights = attn_weights.type(query.dtype)
attn_weights = self.attn_dropout(attn_weights)

if head_mask is not None:
attn_weights = attn_weights * head_mask

if query.size(2) != 1 or query.device.type != 'xpu':
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
attn_output = torch.matmul(attn_weights, value)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value.transpose(-1, -2))

attn_output = linear_q4_0.sdp_fp8(query, key, value,
attention_mask)
attn_weights = None
attn_output = attn_output.transpose(1, 2)

return attn_output, attn_weights