From 592f7aa61ee7c9c6829a0dbf2289b8718495a70c Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Wed, 12 Jun 2024 17:11:56 +0800 Subject: [PATCH] Refine glm1-4 sdp (#11276) * chatglm * update * update * change chatglm * update sdpa * update * fix style * fix * fix glm * update glm2-32k * update glm2-32k * fix cpu * update * change lower_bound --- .../ipex_llm/transformers/models/chatglm.py | 200 +++++----------- .../ipex_llm/transformers/models/chatglm2.py | 221 ++++++------------ .../transformers/models/chatglm2_32k.py | 8 +- .../ipex_llm/transformers/models/chatglm4.py | 47 ++-- .../test_transformers_api_attention.py | 2 +- 5 files changed, 144 insertions(+), 334 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm.py b/python/llm/src/ipex_llm/transformers/models/chatglm.py index 9b9888f4..51a2026a 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint import torch.nn.functional as F from typing import Optional, Tuple from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache +from ipex_llm.transformers.models.chatglm2 import glm_sdpa def rotate_half(x): @@ -103,156 +104,63 @@ def attention_fn( else: present = None - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2: - query_layer = query_layer.permute(1, 2, 0, 3) - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + query_layer = query_layer.permute(1, 2, 0, 3) + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - if torch.is_autocast_cpu_enabled(): - attention_mask = torch.ones(query_layer.shape[2], - key_layer.shape[2], - dtype=torch.bool).tril(diagonal=0) - attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), ) - attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype()) - query_layer = query_layer.to(torch.get_autocast_cpu_dtype()) - key_layer = key_layer.to(torch.get_autocast_cpu_dtype()) - value_layer = value_layer.to(torch.get_autocast_cpu_dtype()) - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - attention_mask, - is_causal=False) - else: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - attention_mask, - is_causal=True) + if torch.is_autocast_cpu_enabled(): + attention_mask = torch.ones(query_layer.shape[2], + key_layer.shape[2], + dtype=torch.bool).tril(diagonal=0) + attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), ) + attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype()) + query_layer = query_layer.to(torch.get_autocast_cpu_dtype()) + key_layer = key_layer.to(torch.get_autocast_cpu_dtype()) + value_layer = value_layer.to(torch.get_autocast_cpu_dtype()) + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + attention_mask, + is_causal=False) else: - # attention_mask is not None only when past_key_value is not None and q_len > 1 - if attention_mask is not None: - attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype, - device=query_layer.device) - attention_mask = ~attention_mask - if attention_mask.dtype == torch.bool: - attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf")) - else: - attn_bias += attention_mask - else: - attn_bias = None - if torch.is_autocast_cpu_enabled(): - query_layer = query_layer.to(torch.get_autocast_cpu_dtype()) - key_layer = key_layer.to(torch.get_autocast_cpu_dtype()) - value_layer = value_layer.to(torch.get_autocast_cpu_dtype()) - attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype()) - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - attention_mask) - else: - head_dim = query_layer.size(-1) - attn = torch.matmul(query_layer.to(key_layer.dtype), - key_layer.transpose(2, 3)) / math.sqrt(head_dim) - if attn_bias is not None: - attn += attn_bias - attn = F.softmax(attn, dim=-1, - dtype=torch.float32).to(value_layer.dtype) - context_layer = torch.matmul(attn, value_layer) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - attention_probs = None - + context_layer = glm_sdpa(query_layer, + key_layer, + value_layer, + attention_mask, + is_causal=True) else: - query_key_layer_scaling_coeff = float(layer_id + 1) - if scaling_attention_score: - query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) - - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), - query_layer.size(0), key_layer.size(2)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], - output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - matmul_result = torch.empty( - output_size[0] * output_size[1], - output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=1.0, - out=matmul_result) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - if self.scale_mask_softmax: - self.scale_mask_softmax.scale = query_key_layer_scaling_coeff - attention_probs = self.scale_mask_softmax(attention_scores, - attention_mask.contiguous()) + # attention_mask is not None only when past_key_value is not None and q_len > 1 + if attention_mask is not None: + attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype, + device=query_layer.device) + attention_mask = ~attention_mask + if attention_mask.dtype == torch.bool: + attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf")) + else: + attn_bias += attention_mask else: - if not (attention_mask == 0).all(): - # if auto-regressive, skip - attention_scores.masked_fill_(attention_mask, -10000.0) - dtype = attention_scores.dtype - attention_scores = attention_scores.float() - attention_scores = attention_scores * query_key_layer_scaling_coeff - - attention_probs = F.softmax(attention_scores, dim=-1) - - attention_probs = attention_probs.type(dtype) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(0), value_layer.size(1), - query_layer.size(0), value_layer.size(3)) - - # change view [sk, b * np, hn] - value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.empty( - output_size[0] * output_size[1], - output_size[2], value_layer.size(-1), dtype=value_layer.dtype, - device=query_layer.device) - torch.bmm(attention_probs, value_layer, out=context_layer) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) + attn_bias = None + if torch.is_autocast_cpu_enabled(): + query_layer = query_layer.to(torch.get_autocast_cpu_dtype()) + key_layer = key_layer.to(torch.get_autocast_cpu_dtype()) + value_layer = value_layer.to(torch.get_autocast_cpu_dtype()) + attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype()) + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + attention_mask) + else: + head_dim = query_layer.size(-1) + attn = torch.matmul(query_layer.to(key_layer.dtype), + key_layer.transpose(2, 3)) / math.sqrt(head_dim) + if attn_bias is not None: + attn += attn_bias + attn = F.softmax(attn, dim=-1, + dtype=torch.float32).to(value_layer.dtype) + context_layer = torch.matmul(attn, value_layer) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + attention_probs = None outputs = (context_layer, present, attention_probs) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 2078087b..983a6533 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from transformers.modeling_outputs import BaseModelOutputWithPast from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ - restore_fp8_kv_cache, use_quantize_kv_cache + restore_fp8_kv_cache, use_quantize_kv_cache, use_flash_attention from ipex_llm.transformers.models.utils import use_sdp @@ -60,6 +60,48 @@ def split_tensor_along_last_dim( return tensor_list +def glm_sdpa(query, key, value, attention_mask=None, is_causal=False): + if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu': + context_layer = F.scaled_dot_product_attention(query.to(key.dtype), + key, + value, + attention_mask, + is_causal=is_causal).to(key.dtype) + else: + # attention_mask is not None only when past_key_value is not None and q_len > 1 + if attention_mask is not None: + attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype, + device=query.device) + attention_mask = ~attention_mask + if attention_mask.dtype == torch.bool: + attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf")) + else: + attn_bias += attention_mask + elif is_causal: + L, S = query.size(-2), key.size(-2) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(key.dtype) + else: + attn_bias = None + if use_sdp(query.shape[2], key.shape[2], + query.shape[-1], query): + import xe_addons + attn_output = xe_addons.sdp(query, key, value, attn_bias) + context_layer = attn_output.view(query.shape) + else: + head_dim = query.size(-1) + attn = torch.matmul(query.to(key.dtype), + key.transpose(2, 3)) / math.sqrt(head_dim) + if attn_bias is not None: + attn += attn_bias + attn = F.softmax(attn, dim=-1, + dtype=torch.float32).to(value.dtype) + context_layer = torch.matmul(attn, value) + return context_layer + + @torch.jit.script def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: # x: [sq, b, np, hn] @@ -271,19 +313,11 @@ def chatglm2_quantized_attention_forward_8eb45c( value_split = torch.split(value, block_size, dim=1) results = [] for q, k, v in zip(query_split, key_split, value_split): - if attention_mask is None: - result = F.scaled_dot_product_attention(q, k, v, is_causal=True) - else: - result = F.scaled_dot_product_attention(q, k, v, attention_mask) + result = glm_sdpa(q, k, v, is_causal=True) results.append(result) context_layer = torch.cat(results, dim=1) else: - if attention_mask is None: - context_layer = F.scaled_dot_product_attention(query_layer, key, - value, is_causal=True) - else: - context_layer = F.scaled_dot_product_attention(query_layer, key, - value, attention_mask) + context_layer = glm_sdpa(query_layer, key, value, is_causal=True) context_layer = context_layer.to(query_layer.dtype) if use_cache: @@ -535,145 +569,34 @@ def chatglm2_attention_forward_8eb45c( def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2: - query_layer = query_layer.permute(1, 2, 0, 3) - L, S = query_layer.shape[2], key_layer.shape[2] - if attention_mask is None and L == S: - batch_size, n_head, seq_len, head_dim = query_layer.shape - if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len): - # split second dim to block size = 8 - block_size = 8 - query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1) - key_split = torch.split(key_layer, block_size, dim=1) - value_split = torch.split(value_layer, block_size, dim=1) - results = [] - for q, k, v in zip(query_split, key_split, value_split): - result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype) - results.append(result) - context_layer = torch.cat(results, dim=1) - else: - context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype), - key_layer, - value_layer, - is_causal=True).to(key_layer.dtype) + query_layer = query_layer.permute(1, 2, 0, 3) + L, S = query_layer.shape[2], key_layer.shape[2] + batch_size, n_head, seq_len, head_dim = query_layer.shape + if attention_mask is None and L == S: + if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len): + # split second dim to block size = 8 + block_size = 8 + query_layer = query_layer.to(key_layer.dtype) + query_split = torch.split(query_layer, block_size, dim=1) + key_split = torch.split(key_layer, block_size, dim=1) + value_split = torch.split(value_layer, block_size, dim=1) + results = [] + for q, k, v in zip(query_split, key_split, value_split): + result = glm_sdpa(q, k, v, is_causal=True) + results.append(result) + context_layer = torch.cat(results, dim=1) else: - # attention_mask is not None only when past_key_value is not None and q_len > 1 - if attention_mask is not None: - attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype, - device=query_layer.device) - attention_mask = ~attention_mask - if attention_mask.dtype == torch.bool: - attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf")) - else: - attn_bias += attention_mask - else: - attn_bias = None - - if use_sdp(query_layer.shape[2], key_layer.shape[2], - query_layer.shape[-1], query_layer): - import xe_addons - attn_output = xe_addons.sdp(query_layer, key_layer, value_layer, attn_bias) - context_layer = attn_output.view(query_layer.shape) - else: - head_dim = query_layer.size(-1) - attn = torch.matmul(query_layer.to(key_layer.dtype), - key_layer.transpose(2, 3)) / math.sqrt(head_dim) - if attn_bias is not None: - attn += attn_bias - attn = F.softmax(attn, dim=-1, - dtype=torch.float32).to(value_layer.dtype) - context_layer = torch.matmul(attn, value_layer) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (-1,) - context_layer = context_layer.reshape(*new_context_layer_shape) + context_layer = glm_sdpa(query_layer, + key_layer, + value_layer, + is_causal=True) else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), - query_layer.size(0), key_layer.size(2)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], - output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - matmul_result = torch.empty( - output_size[0] * output_size[1], - output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. [b * np, sq, sk] - torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - out=matmul_result - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(0), value_layer.size(1), - query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] - value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.empty( - output_size[0] * output_size[1], - output_size[2], value_layer.size(-1), dtype=value_layer.dtype, - device=value_layer.device, - ) - torch.bmm(attention_probs, value_layer, out=context_layer) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = glm_sdpa(query_layer, + key_layer, + value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.reshape(*new_context_layer_shape) return context_layer diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py b/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py index ea53aac3..0df6ad1b 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py @@ -21,6 +21,7 @@ import torch from typing import Optional, Tuple, Union, List, Callable, Dict, Any import torch.nn.functional as F from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache +from ipex_llm.transformers.models.chatglm2 import core_attn_forward_8eb45c import os @@ -179,10 +180,9 @@ def chatglm2_32k_attention_forward( key_layer = key_cache value_layer = value_cache - key_layer = key_layer.permute(2, 0, 1, 3) - value_layer = value_layer.permute(2, 0, 1, 3) - if use_cache: + key_layer = key_layer.permute(2, 0, 1, 3) + value_layer = value_layer.permute(2, 0, 1, 3) if kv_cache is None: kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1) @@ -195,7 +195,7 @@ def chatglm2_32k_attention_forward( # core attention computation # ================================== - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask) # ================= # Output. [sq, b, h] diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py index 5c7f156b..00a69dcd 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import use_quantize_kv_cache, apply_ipex_rotate_every_two from ipex_llm.transformers.models.utils import use_sdp -from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor +from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor, glm_sdpa from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim from transformers.modeling_outputs import BaseModelOutputWithPast @@ -300,51 +300,30 @@ def chatglm4_attention_forward( def core_attn_forward(query_layer, key_layer, value_layer, attention_mask): L, S = query_layer.shape[2], key_layer.shape[2] + batch_size, n_head, seq_len, head_dim = query_layer.shape if attention_mask is None and L == S: - batch_size, n_head, seq_len, head_dim = query_layer.shape if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len): # split second dim to block size = 8 block_size = 8 - query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1) + query_layer = query_layer.to(key_layer.dtype) + query_split = torch.split(query_layer, block_size, dim=1) key_split = torch.split(key_layer, block_size, dim=1) value_split = torch.split(value_layer, block_size, dim=1) results = [] for q, k, v in zip(query_split, key_split, value_split): - result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype) + result = glm_sdpa(q, k, v, is_causal=True) results.append(result) context_layer = torch.cat(results, dim=1) else: - context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype), - key_layer, - value_layer, - is_causal=True).to(key_layer.dtype) + context_layer = glm_sdpa(query_layer, + key_layer, + value_layer, + is_causal=True) else: - # attention_mask is not None only when past_key_value is not None and q_len > 1 - if attention_mask is not None: - attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype, - device=query_layer.device) - attention_mask = ~attention_mask - if attention_mask.dtype == torch.bool: - attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf")) - else: - attn_bias += attention_mask - else: - attn_bias = None - - if use_sdp(query_layer.shape[2], key_layer.shape[2], - query_layer.shape[-1], query_layer): - import xe_addons - attn_output = xe_addons.sdp(query_layer, key_layer, value_layer, attn_bias) - context_layer = attn_output.view(query_layer.shape) - else: - head_dim = query_layer.size(-1) - attn = torch.matmul(query_layer.to(key_layer.dtype), - key_layer.transpose(2, 3)) / math.sqrt(head_dim) - if attn_bias is not None: - attn += attn_bias - attn = F.softmax(attn, dim=-1, - dtype=torch.float32).to(value_layer.dtype) - context_layer = torch.matmul(attn, value_layer) + context_layer = glm_sdpa(query_layer, + key_layer, + value_layer, + attention_mask) context_layer = context_layer.transpose(1, 2).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.reshape(*new_context_layer_shape) diff --git a/python/llm/test/inference_gpu/test_transformers_api_attention.py b/python/llm/test/inference_gpu/test_transformers_api_attention.py index 1a3e5bf7..9e00d0c6 100644 --- a/python/llm/test/inference_gpu/test_transformers_api_attention.py +++ b/python/llm/test/inference_gpu/test_transformers_api_attention.py @@ -171,7 +171,7 @@ class Test_Optimize_Gpu_Model: # currently only need to compare the output of one self-attention layer. layer_norm = "transformer.encoder.layers.27.input_layernorm" self_attn = "transformer.encoder.layers.27.self_attention" - lower_bound = 4e-3 + lower_bound = 8e-3 self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound) def Mistral_gpu_model(self, Name, Model, Tokenizer, model_path):