From 7bbb98abb63c76fac4b78602429ddfc0dca6ee4c Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 18 Jan 2024 16:22:12 +0800
Subject: [PATCH] Disable fused layer norm when using XMX to fix mpt UT (#9933)

---
 .../llm/src/bigdl/llm/transformers/models/bloom.py |  3 ++-
 .../llm/src/bigdl/llm/transformers/models/utils.py | 12 ++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/bloom.py b/python/llm/src/bigdl/llm/transformers/models/bloom.py
index 88c02fc2..ff9d4b6a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/bloom.py
+++ b/python/llm/src/bigdl/llm/transformers/models/bloom.py
@@ -37,6 +37,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
+from bigdl.llm.transformers.models.utils import use_fused_layer_norm
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 
 
@@ -63,7 +64,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
 
 
 def bloom_layer_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+    if use_fused_layer_norm(hidden_states, self.training):
         import linear_q4_0
         result = linear_q4_0.fused_layer_norm(hidden_states,
                                               [self.weight.size(0)],
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 10c2ffca..620f9102 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -283,3 +283,15 @@ def use_xmx(x: torch.Tensor, qtype: int):
             1 < x.size(0) <= 8
         )
     )
+
+
+def use_fused_layer_norm(x: torch.Tensor, training: bool):
+    return (
+        not training
+        and not x.requires_grad
+        and x.device.type == 'xpu'
+        and (
+            get_xpu_device_type(x) not in ["arc", "flex"]
+            or x.reshape(-1, x.size(-1)).size(0) == 1
+        )
+    )
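
Note (not part of the patch): the sketch below mirrors the condition added as use_fused_layer_norm so its effect is easier to see. It stubs out get_xpu_device_type and drops the x.device.type == 'xpu' check so it can run on a machine without an Intel XPU; the stub name, helper name, and sample shapes are assumptions for illustration only, not the shipped code.

import torch


def stub_get_xpu_device_type(x: torch.Tensor) -> str:
    # Stand-in for get_xpu_device_type() in models/utils.py; pretend we are
    # on an Arc GPU, i.e. one of the device types for which XMX is used.
    return "arc"


def use_fused_layer_norm_sketch(x: torch.Tensor, training: bool) -> bool:
    # Same gating condition as the patched helper; the real one additionally
    # requires x.device.type == 'xpu', omitted here so the sketch runs on CPU.
    return (
        not training
        and not x.requires_grad
        and (
            stub_get_xpu_device_type(x) not in ["arc", "flex"]
            or x.reshape(-1, x.size(-1)).size(0) == 1
        )
    )


# Single-row input (e.g. one decoding token): the fused layer norm kernel is
# still taken, even on Arc/Flex.
print(use_fused_layer_norm_sketch(torch.randn(1, 1, 4096), training=False))   # True

# Multi-row input (e.g. a prompt being prefilled): on Arc/Flex the fused
# kernel is now skipped, which is the behavior change the commit subject
# refers to as fixing the mpt UT.
print(use_fused_layer_norm_sketch(torch.randn(1, 32, 4096), training=False))  # False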