From bd64488b2ac0042da3bc3cae9ec91e6ac260bbcc Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 15 Mar 2024 17:36:52 +0800
Subject: [PATCH] add mask support for llama/chatglm fp8 sdp (#10433)

* add mask support for fp8 sdp

* fix chatglm2 dtype

* update
---
 .../bigdl/llm/transformers/models/chatglm2.py | 25 +++++++++++--------
 .../bigdl/llm/transformers/models/llama.py    |  3 ++-
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index a919ee51..1db8424a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -97,7 +97,7 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).to(self.weight.dtype).contiguous()
+        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
         output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
         if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
             output = output.clone()
@@ -254,6 +254,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
             context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
         else:
             context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+        context_layer = context_layer.to(query_layer.dtype)

     if use_cache:
         k_cache, v_cache = init_fp8_kv_cache(batch_size,
@@ -272,25 +273,29 @@ def chatglm2_quantized_attention_forward_8eb45c(
         k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer,
                                                new_layout=True)

+        if attention_mask is not None:
+            attention_mask = ~attention_mask
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                    device=query_layer.device)
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        else:
+            attn_bias = None
+
         if seq_len != 1:
             key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype)
             key, value = repeat_kv(key, value, n_head)
             attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
-            if attention_mask is not None:
-                attention_mask = ~attention_mask
-                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
-                                        device=query_layer.device)
-                if attention_mask.dtype == torch.bool:
-                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
-                else:
-                    attn_bias += attention_mask
+            if attn_bias is not None:
                 attn += attn_bias
             attn = F.softmax(attn, dim=-1, dtype=torch.float32)
             context_layer = torch.matmul(attn.to(value.dtype), value)
         else:
             key, value = k_cache, v_cache
             import linear_q4_0
-            context_layer = linear_q4_0.sdp_fp8(query_layer, key, value)
+            context_layer = linear_q4_0.sdp_fp8(query_layer, key, value, attn_bias)

     # context_layer's shape: [bs, n_head, seq_len, head_dim] -> [seq_len, bs, n_head * head_dim]
     context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(seq_len, batch_size, -1)
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 378a7cb5..b6c3e9cc 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -418,7 +418,8 @@ def llama_attention_forward_4_31_quantized(
                                                    self.head_dim, self.num_heads)
         else:
             import linear_q4_0
-            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states)
+            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                              attention_mask)
         attn_weights = None

     attn_output = attn_output.transpose(1, 2).contiguous()
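
For context, the chatglm2 change hoists the mask-to-bias conversion out of the seq_len != 1 branch so that both the unfused matmul path and the fused linear_q4_0.sdp_fp8 XPU kernel receive the same attn_bias. The sketch below is a minimal pure-PyTorch illustration of that conversion and of the unfused reference path only; make_attn_bias and naive_sdp are hypothetical helper names used for this example, not functions in bigdl-llm, and the fused fp8 kernel itself is not reproduced here.

    # Minimal sketch of the mask handling this patch adds before calling the
    # fused fp8 SDP kernel. Helper names are illustrative only; the real decode
    # path passes attn_bias to linear_q4_0.sdp_fp8 on XPU instead.
    import math
    import torch
    import torch.nn.functional as F

    def make_attn_bias(attention_mask, dtype, device):
        # Mirrors the patch: ChatGLM2 passes a boolean mask where True marks a
        # masked-out position; it is inverted first, then turned into an
        # additive bias. Non-boolean masks are added as-is, as in the patch.
        if attention_mask is None:
            return None
        attention_mask = ~attention_mask
        attn_bias = torch.zeros(attention_mask.shape, dtype=dtype, device=device)
        if attention_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attention_mask
        return attn_bias

    def naive_sdp(query, key, value, attn_bias):
        # Unfused reference path, equivalent to the patch's seq_len != 1 branch.
        head_dim = query.size(-1)
        attn = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
        if attn_bias is not None:
            attn += attn_bias
        attn = F.softmax(attn, dim=-1, dtype=torch.float32)
        return torch.matmul(attn.to(value.dtype), value)

    if __name__ == "__main__":
        bs, n_head, q_len, kv_len, head_dim = 1, 2, 1, 8, 16
        q = torch.randn(bs, n_head, q_len, head_dim)
        k = torch.randn(bs, n_head, kv_len, head_dim)
        v = torch.randn(bs, n_head, kv_len, head_dim)
        # Boolean mask in ChatGLM2 convention: True = position is masked out.
        mask = torch.zeros(bs, 1, q_len, kv_len, dtype=torch.bool)
        mask[..., :2] = True    # mask out the first two kv positions
        bias = make_attn_bias(mask, q.dtype, q.device)
        out = naive_sdp(q, k, v, bias)
        print(out.shape)        # torch.Size([1, 2, 1, 16])

Hoisting the bias construction in this way is what lets the decode path (seq_len == 1) forward the mask to the fused kernel instead of silently ignoring it, which is also why the llama.py call site now passes attention_mask to linear_q4_0.sdp_fp8.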