fix starcoder (#9975)
This commit is contained in:
parent be5836bee1
commit f82782cd3b
3 changed files with 40 additions and 1 deletion
@@ -955,4 +955,14 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.RwkvSelfAttention,
                         rwkv_attention_forward)
+    elif model.config.model_type == "gpt_bigcode":
+        # starcoder
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.gptbigcode import _attn_wrapper
+        _attn = _attn_wrapper(module.GPTBigCodeAttention._attn)
+        replace_func(model,
+                     module.GPTBigCodeAttention,
+                     "_attn",
+                     _attn)
     return model
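Note that replace_func is defined elsewhere in bigdl-llm and is not part of this diff. For readers unfamiliar with the pattern, here is a minimal sketch of class-level method patching as it is used above, with hypothetical names throughout; it is not the real replace_func:

# Illustrative sketch only: the real replace_func lives elsewhere in
# bigdl-llm and is not shown in this diff; names here are hypothetical.
def patch_method_sketch(target_class, method_name, new_method):
    # Methods are looked up on the class, so rebinding the attribute
    # affects every existing and future instance.
    setattr(target_class, method_name, new_method)

class _FakeAttention:
    def _attn(self, q):
        return q * 2

def _attn_wrapper_sketch(origin_attn):
    # Capture the original implementation before patching, exactly as
    # _attn_wrapper above captures GPTBigCodeAttention._attn.
    def _attn(self, q):
        return origin_attn(self, q) + 1  # hypothetical extra step
    return _attn

patch_method_sketch(_FakeAttention, "_attn",
                    _attn_wrapper_sketch(_FakeAttention._attn))
assert _FakeAttention()._attn(3) == 7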
python/llm/src/bigdl/llm/transformers/models/gptbigcode.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+def _attn_wrapper(origin_attn):
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_output, attn_weights = origin_attn(self,
+                                                query=query,
+                                                key=key,
+                                                value=value,
+                                                attention_mask=attention_mask,
+                                                head_mask=head_mask)
+        if query.device.type == 'xpu' and 1 < query.numel() // query.size(-1) <= 64:
+            attn_output = attn_output.clone()
+        return attn_output, attn_weights
+    return _attn
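The wrapper leaves the original _attn untouched except on XPU, where it clones the attention output whenever the query, viewed as a (-1, head_dim) matrix, has between 2 and 64 rows. The fresh buffer from clone() is presumably the actual StarCoder fix; the diff itself gives no rationale beyond the commit title. A quick check of the gating arithmetic, using a hypothetical shape:

import torch

# Sanity check of the gate in _attn_wrapper: numel() // size(-1) is the
# row count of the tensor viewed as (-1, last_dim), computed from the
# shape alone. The example shape below is hypothetical.
query = torch.randn(1, 16, 4, 64)          # (batch, heads, seq, head_dim)
rows = query.numel() // query.size(-1)     # 1 * 16 * 4 = 64
assert rows == query.reshape(-1, query.size(-1)).size(0)
print(1 < rows <= 64)                      # True: the clone() branch fires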
@@ -310,6 +310,6 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
         and x.device.type == 'xpu'
         and (
             get_xpu_device_type(x) not in ["arc", "flex"]
-            or x.reshape(-1, x.size(-1)).size(0) == 1
+            or x.numel() // x.size(-1) == 1
         )
     )
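Both sides of this one-line change compute the same row count, but the new form reads only the shape, while reshape on a non-contiguous tensor can materialize a copy just to answer a size query. A small equivalence check with an illustrative shape:

import torch

# Both checks count the rows of x viewed as (-1, last_dim); the shape
# below is illustrative.
x = torch.randn(4, 8, 32).transpose(0, 1)  # deliberately non-contiguous
old = x.reshape(-1, x.size(-1)).size(0)    # may materialize a copy
new = x.numel() // x.size(-1)              # integer shape arithmetic only
assert old == new == 32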