add mask support for llama/chatglm fp8 sdp (#10433)

* add mask support for fp8 sdp

* fix chatglm2 dtype

* update

commit bd64488b2a
parent 444b11af22
Author: Yishuo Wang
Date:   2024-03-15 17:36:52 +08:00 (committed by GitHub)

2 changed files with 17 additions and 11 deletions


@@ -97,7 +97,7 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).to(self.weight.dtype).contiguous()
+        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
         output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
         if 1 < x_2d.size(0) <= 64:  # may use XMX, need copy
             output = output.clone()
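
Aside, not part of the patch: linear_q4_0.rms_norm is the fused XPU kernel this forward dispatches to. As a reference for the dtype handling the hunk above touches, a plain-PyTorch sketch of the RMSNorm computation being offloaded (standard ChatGLM-style RMSNorm; the helper name and the explicit fp32 accumulation are my own) would look like this:

import torch

def rms_norm_reference(weight: torch.Tensor, x_2d: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize each row by its root mean square, then scale by the learned weight.
    # Accumulate in fp32 and cast back to the input dtype at the end.
    x_fp32 = x_2d.float()
    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    x_normed = x_fp32 * torch.rsqrt(variance + eps)
    return (weight * x_normed).to(x_2d.dtype)
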
@@ -254,6 +254,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
             context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
         else:
             context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+        context_layer = context_layer.to(query_layer.dtype)
 
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(batch_size,
@@ -272,10 +273,6 @@ def chatglm2_quantized_attention_forward_8eb45c(
         k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer,
                                                new_layout=True)
 
-        if seq_len != 1:
-            key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype)
-            key, value = repeat_kv(key, value, n_head)
-            attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
         if attention_mask is not None:
             attention_mask = ~attention_mask
             attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
@@ -284,13 +281,21 @@ def chatglm2_quantized_attention_forward_8eb45c(
                 attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
             else:
                 attn_bias += attention_mask
+        else:
+            attn_bias = None
+
+        if seq_len != 1:
+            key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype)
+            key, value = repeat_kv(key, value, n_head)
+            attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
+            if attn_bias is not None:
                 attn += attn_bias
             attn = F.softmax(attn, dim=-1, dtype=torch.float32)
             context_layer = torch.matmul(attn.to(value.dtype), value)
         else:
             key, value = k_cache, v_cache
             import linear_q4_0
-            context_layer = linear_q4_0.sdp_fp8(query_layer, key, value)
+            context_layer = linear_q4_0.sdp_fp8(query_layer, key, value, attn_bias)
 
         # context_layer's shape: [bs, n_head, seq_len, head_dim] -> [seq_len, bs, n_head * head_dim]
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(seq_len, batch_size, -1)
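
Reading the chatglm2 change as a whole: the mask is now converted into an additive bias once, before the seq_len branch, so the same attn_bias serves both the explicit prefill math and the fused single-token kernel. A simplified sketch of that bias construction, outside the diff (the function name is mine; the patch inlines this logic after chatglm2's ~attention_mask inversion):

import torch

def build_attn_bias(attention_mask, dtype, device):
    # attention_mask here is the already-inverted mask from the code above:
    # boolean True means "may attend", False means "masked out";
    # a floating-point mask is treated as an additive bias directly.
    if attention_mask is None:
        return None
    attn_bias = torch.zeros(attention_mask.shape, dtype=dtype, device=device)
    if attention_mask.dtype == torch.bool:
        attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
    else:
        attn_bias += attention_mask
    return attn_bias

In the prefill path the bias is added to the query-key scores before the softmax; in the decode path it is simply forwarded as the new fourth argument of linear_q4_0.sdp_fp8.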


@@ -418,7 +418,8 @@ def llama_attention_forward_4_31_quantized(
                                                  self.head_dim, self.num_heads)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states)
+            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                              attention_mask)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
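
For intuition about what the extra argument is expected to do, a sanity-check sketch (made-up shapes, stock PyTorch SDPA instead of the XPU fp8 kernel, which is not exercised here): an additive bias passed as attn_mask reproduces the masked softmax that the prefill branch computes by hand.

import math
import torch
import torch.nn.functional as F

# Made-up decode-step shapes, [batch, n_head, seq, head_dim]: one new query token
# attending over 64 cached positions, the first 8 of which we pretend are padding.
q = torch.randn(1, 4, 1, 32)
k = torch.randn(1, 4, 64, 32)
v = torch.randn(1, 4, 64, 32)
attn_bias = torch.zeros(1, 1, 1, 64)
attn_bias[..., :8] = float("-inf")

# Explicit form, as in the prefill branch of the chatglm2 code above.
scores = q @ k.transpose(2, 3) / math.sqrt(q.size(-1)) + attn_bias
out_explicit = torch.softmax(scores, dim=-1) @ v

# Fused form: the same additive mask handed to PyTorch's SDPA.
out_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

print(torch.allclose(out_explicit, out_fused, atol=1e-5))  # expected: True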