optimize Decilm 7b (#9922)
* optimize deci * update * decilm attension forward
This commit is contained in:
parent
bcaeb05272
commit
97f0cd8975
2 changed files with 193 additions and 0 deletions
|
|
@ -934,6 +934,19 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model,
|
||||
module.RwkvSelfAttention,
|
||||
rwkv_attention_forward)
|
||||
elif model.config.model_type == "deci":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.decilm import decilm_attention_forward_4_35_2
|
||||
convert_forward(model,
|
||||
module.LlamaRMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
convert_forward(model,
|
||||
module.LlamaMLP,
|
||||
llama_mlp_forward)
|
||||
convert_forward(model,
|
||||
module.DeciLMAttention,
|
||||
decilm_attention_forward_4_35_2, )
|
||||
elif model.config.model_type == "rwkv5":
|
||||
# rwkv v5
|
||||
modeling_module_name = model.__class__.__module__
|
||||
|
|
|
|||
180
python/llm/src/bigdl/llm/transformers/models/decilm.py
Normal file
180
python/llm/src/bigdl/llm/transformers/models/decilm.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py
|
||||
# which is licensed under Apache License 2.0:
|
||||
#
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
import torch.nn.functional as F
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
||||
apply_rotary_pos_emb
|
||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||
from bigdl.llm.transformers.models.llama import should_use_fuse_rope, repeat_kv
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
||||
|
||||
def decilm_attention_forward_4_35_2(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
is_decode = past_key_value is not None
|
||||
device = hidden_states.device
|
||||
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
|
||||
self.config.pretraining_tp)
|
||||
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
|
||||
// self.config.pretraining_tp, dim=0)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i])
|
||||
for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i])
|
||||
for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i])
|
||||
for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len,
|
||||
self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
if use_fuse_rope:
|
||||
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||
key_states,
|
||||
position_ids,
|
||||
"llama")
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, position_ids, "llama")
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
cache_k = past_key_value[0]
|
||||
cache_v = past_key_value[1]
|
||||
if not enough_kv_room:
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
self.num_key_value_heads, # Support GQA
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=cache_k.dtype,
|
||||
device=device)
|
||||
new_cache_k[:] = cache_k
|
||||
new_cache_v[:] = cache_v
|
||||
cache_k = new_cache_k
|
||||
cache_v = new_cache_v
|
||||
|
||||
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
|
||||
|
||||
elif use_cache:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
self.num_key_value_heads,
|
||||
self.head_dim,
|
||||
kv_seq_len,
|
||||
max_cache_length,
|
||||
dtype=key_states.dtype,
|
||||
device=device)
|
||||
new_key_states[:] = key_states
|
||||
new_value_states[:] = value_states
|
||||
key_states = new_key_states
|
||||
value_states = new_value_states
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if is_decode:
|
||||
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
|
||||
is_causal=False,
|
||||
attn_mask=attention_mask)
|
||||
attn_output = attn_output.contiguous().view(bsz, q_len, self.hidden_size)
|
||||
|
||||
else:
|
||||
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
|
||||
is_causal=attention_mask is None,
|
||||
attn_mask=attention_mask)
|
||||
|
||||
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
||||
f"`attn_output` should be of size "
|
||||
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}")
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
|
||||
dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
Loading…
Reference in a new issue