diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 1801cc64..bf579e45 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -955,4 +955,14 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.RwkvSelfAttention,
                         rwkv_attention_forward)
+    elif model.config.model_type == "gpt_bigcode":
+        # starcoder
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.gptbigcode import _attn_wrapper
+        _attn = _attn_wrapper(module.GPTBigCodeAttention._attn)
+        replace_func(model,
+                     module.GPTBigCodeAttention,
+                     "_attn",
+                     _attn)
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py b/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
new file mode 100644
index 00000000..600fbc1d
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
@@ -0,0 +1,29 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+def _attn_wrapper(origin_attn):
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_output, attn_weights = origin_attn(self,
+                                                query=query,
+                                                key=key,
+                                                value=value,
+                                                attention_mask=attention_mask,
+                                                head_mask=head_mask)
+        if query.device.type == 'xpu' and 1 < query.numel() // query.size(-1) <= 64:
+            attn_output = attn_output.clone()
+        return attn_output, attn_weights
+    return _attn
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index ca49bb8d..2b910cdc 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -310,6 +310,6 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
         and x.device.type == 'xpu'
         and (
             get_xpu_device_type(x) not in ["arc", "flex"]
-            or x.reshape(-1, x.size(-1)).size(0) == 1
+            or x.numel() // x.size(-1) == 1
         )
     )
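
Note (not part of the patch): the new _attn_wrapper is applied by monkey-patching GPTBigCodeAttention._attn via replace_func in convert.py; the utils.py change swaps x.reshape(-1, x.size(-1)).size(0) for the equivalent but reshape-free x.numel() // x.size(-1), which also counts the number of rows/tokens. Below is a minimal, self-contained sketch of how the wrapper behaves; _attn_wrapper mirrors the patch, while DummyAttention and the tensors are hypothetical stand-ins for GPTBigCodeAttention and real model inputs.

import torch


def _attn_wrapper(origin_attn):
    # Same wrapper as in the patch: delegate to the original _attn, then
    # clone the output on XPU when the token count (numel // last dim) is
    # in (1, 64], mirroring the condition in gptbigcode.py.
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_output, attn_weights = origin_attn(self,
                                                query=query,
                                                key=key,
                                                value=value,
                                                attention_mask=attention_mask,
                                                head_mask=head_mask)
        if query.device.type == 'xpu' and 1 < query.numel() // query.size(-1) <= 64:
            attn_output = attn_output.clone()
        return attn_output, attn_weights
    return _attn


class DummyAttention:
    # Hypothetical stand-in for GPTBigCodeAttention with a toy _attn.
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.softmax(query @ key.transpose(-1, -2), dim=-1)
        return attn_weights @ value, attn_weights


# Monkey-patch the method, mirroring what replace_func does in convert.py.
DummyAttention._attn = _attn_wrapper(DummyAttention._attn)

q = k = v = torch.randn(1, 4, 8)        # CPU tensors, so the clone branch is skipped
out, weights = DummyAttention()._attn(q, k, v)
print(out.shape, weights.shape)         # torch.Size([1, 4, 8]) torch.Size([1, 4, 4])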