From 82255f972665c5e452844e4c9f91e9df4928c611 Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Mon, 11 Dec 2023 09:26:13 +0800
Subject: [PATCH] Enable fused layernorm (#9614)

* bloom layernorm

* fix

* layernorm

* fix

* fix

* fix

* style fix

* fix

* replace nn.LayerNorm
---
 python/llm/src/bigdl/llm/transformers/convert.py     |  6 ++++++
 .../llm/src/bigdl/llm/transformers/models/bloom.py   | 13 +++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 92944992..02730699 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -392,6 +392,12 @@ def _optimize_post(model, lightweight_bmm=False):
         # todo implement 4.28.0 ~ 4.30.2
         pass
 
+    # convert all nn.LayerNorm
+    from bigdl.llm.transformers.models.bloom import bloom_layer_norm_forward
+    convert_forward(model,
+                    nn.LayerNorm,
+                    bloom_layer_norm_forward)
+
     if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
         if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
             # chatglm2-6b-32k
diff --git a/python/llm/src/bigdl/llm/transformers/models/bloom.py b/python/llm/src/bigdl/llm/transformers/models/bloom.py
index e44a26c8..7daaba78 100644
--- a/python/llm/src/bigdl/llm/transformers/models/bloom.py
+++ b/python/llm/src/bigdl/llm/transformers/models/bloom.py
@@ -62,6 +62,19 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
     return out
 
 
+def bloom_layer_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        import linear_q4_0
+        hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
+                                                     [self.weight.size(0)],
+                                                     self.weight,
+                                                     self.bias,
+                                                     self.eps)
+        return hidden_states
+    else:
+        return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+
+
 def bloom_attention_forward(
     self,
     hidden_states: torch.Tensor,
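
Note: the convert_forward helper used above is defined elsewhere in convert.py and is not part of this diff. Below is a minimal, self-contained sketch of how such a hook could work, assuming it simply walks the module tree and rebinds forward on every matching sub-module. The names convert_forward_sketch and layer_norm_forward_sketch are hypothetical, not BigDL APIs, and the sketch always takes the F.layer_norm fallback path because the linear_q4_0.fused_layer_norm kernel is only available in a BigDL XPU build.

# Minimal sketch only; assumes convert_forward-style patching works by rebinding
# the forward method of matching sub-modules in place. Not the BigDL implementation.
import types

import torch
import torch.nn as nn
import torch.nn.functional as F


def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    """Rebind `forward` on every sub-module that is an instance of `target_cls`."""
    for module in model.modules():
        if isinstance(module, target_cls):
            # MethodType binds the plain function so `self` resolves to this module.
            module.forward = types.MethodType(new_forward, module)


def layer_norm_forward_sketch(self, hidden_states):
    """CPU-only stand-in for bloom_layer_norm_forward: always the fallback branch."""
    return F.layer_norm(hidden_states, self.normalized_shape,
                        self.weight, self.bias, self.eps)


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
    convert_forward_sketch(model, nn.LayerNorm, layer_norm_forward_sketch)
    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])

Rebinding forward with types.MethodType leaves the module's parameters and state_dict untouched, which is presumably why the patch can route nn.LayerNorm through the fused kernel without re-registering any weights.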