From 8d7326ae03768e4339c43374a58f337290e035c2 Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Tue, 16 Jan 2024 11:29:13 +0800
Subject: [PATCH] LLM: fix chatglm3 sdp to support speculative decoding (#9900)

* fix chatglm3

* fix

* update

* meet code review

* fix
---
 .../src/bigdl/llm/transformers/models/chatglm2.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index f29a0f49..60213df8 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -373,11 +373,15 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                                                             key_layer,
                                                             value_layer,
                                                             is_causal=True)
-        elif attention_mask is None:
+        else:
             head_dim = query_layer.size(-1)
             attn = torch.matmul(query_layer,
                                 key_layer.transpose(2, 3)) / math.sqrt(head_dim)
-            if L == S:
+            if attention_mask is not None:
+                attention_mask = ~attention_mask
+                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                attn += attention_mask
+            elif L == S:
                 # first token, need attention mask
                 attn_bias = torch.zeros(L, S, dtype=query_layer.dtype,
                                         device=query_layer.device)
@@ -388,12 +392,6 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                 attn += attn_bias
             attn = torch.softmax(attn, -1)
             context_layer = torch.matmul(attn, value_layer)
-        else:
-            attention_mask = ~attention_mask
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                             key_layer,
-                                                                             value_layer,
-                                                                             attention_mask)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
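
Note (not part of the patch): the hunks above route the masked case (attention_mask is not None, which is hit during speculative-decoding verification when several draft tokens are scored against the cache at once, so q_len != kv_len) through the manual matmul/softmax path instead of scaled_dot_product_attention. Below is a minimal standalone sketch of that masking approach, assuming a boolean mask where True marks positions that must NOT be attended (the convention implied by the `~attention_mask` inversion in the repo code); the helper name manual_masked_attention and the toy shapes are illustrative, not code from the repo. Unlike the patch, the sketch builds the additive bias in a separate float tensor rather than filling the boolean mask in place.

    # Minimal sketch; assumptions noted in the comment above.
    import math
    import torch

    def manual_masked_attention(query, key, value, attention_mask=None):
        # query: [batch, heads, q_len, head_dim]; key/value: [batch, heads, kv_len, head_dim]
        head_dim = query.size(-1)
        attn = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            # Turn the boolean mask into an additive float bias: 0 where attention
            # is allowed, -inf where the mask forbids it, then add it to the scores.
            attn_bias = torch.zeros_like(attn)
            attn_bias.masked_fill_(attention_mask, float("-inf"))
            attn = attn + attn_bias
        attn = torch.softmax(attn, dim=-1)
        return torch.matmul(attn, value)

    if __name__ == "__main__":
        # Toy shapes: 1 batch, 2 heads, 3 query tokens attending over 5 cached positions,
        # roughly the verification case where q_len != kv_len.
        q = torch.randn(1, 2, 3, 8)
        k = torch.randn(1, 2, 5, 8)
        v = torch.randn(1, 2, 5, 8)
        mask = torch.zeros(1, 1, 3, 5, dtype=torch.bool)  # True would block a position
        print(manual_masked_attention(q, k, v, mask).shape)  # torch.Size([1, 2, 3, 8])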