MPT model optimize for long sequence (#9020)
* mpt_long_seq * update * update * update * style * style2 * update
This commit is contained in:
parent
9126abdf9b
commit
028a6d9383
2 changed files with 158 additions and 1 deletions
|
|
@ -173,6 +173,15 @@ def optimize(model):
|
|||
module.SelfAttention,
|
||||
chatglm_attention_forward
|
||||
)
|
||||
elif "mpt" in model.config._name_or_path:
|
||||
modeling_module_name = model.__class__.__module__
|
||||
attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
|
||||
module = importlib.import_module(attention_module_name)
|
||||
from bigdl.llm.transformers.models.mpt import mpt_multihead_attention_forward
|
||||
convert_forward(model,
|
||||
module.MultiheadAttention,
|
||||
mpt_multihead_attention_forward
|
||||
)
|
||||
elif "gptj" in model.config.model_type:
|
||||
# dolly-v1-6b
|
||||
modeling_module_name = model.__class__.__module__
|
||||
|
|
@ -263,5 +272,4 @@ def optimize(model):
|
|||
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention,
|
||||
gptneox_attention_forward
|
||||
)
|
||||
|
||||
return model
|
||||
|
|
|
|||
149
python/llm/src/bigdl/llm/transformers/models/mpt.py
Normal file
149
python/llm/src/bigdl/llm/transformers/models/mpt.py
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py
|
||||
#
|
||||
|
||||
import warnings
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
|
||||
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
||||
|
||||
def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
|
||||
attention_mask=None, is_causal=True, needs_weights=False):
|
||||
qkv = self.Wqkv(x)
|
||||
if self.clip_qkv:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
(query, key, value) = qkv.chunk(3, dim=2)
|
||||
key_padding_mask = attention_mask
|
||||
if self.qk_ln:
|
||||
dtype = query.dtype
|
||||
query = self.q_ln(query).to(dtype)
|
||||
key = self.k_ln(key).to(dtype)
|
||||
(context, attn_weights, past_key_value) = \
|
||||
mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
|
||||
past_key_value=past_key_value,
|
||||
softmax_scale=self.softmax_scale,
|
||||
attn_bias=attn_bias,
|
||||
key_padding_mask=key_padding_mask,
|
||||
is_causal=is_causal,
|
||||
dropout_p=self.attn_dropout_p,
|
||||
training=self.training,
|
||||
needs_weights=needs_weights)
|
||||
return (self.out_proj(context), attn_weights, past_key_value)
|
||||
|
||||
|
||||
def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
|
||||
past_key_value=None,
|
||||
softmax_scale=None,
|
||||
attn_bias=None,
|
||||
key_padding_mask=None,
|
||||
is_causal=False,
|
||||
dropout_p=0.0,
|
||||
training=False,
|
||||
needs_weights=False,
|
||||
multiquery=False):
|
||||
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
||||
bsz, n_heads, q_len, head_dim = q.size()
|
||||
device = q.device
|
||||
kv_n_heads = 1 if multiquery else n_heads
|
||||
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
|
||||
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
|
||||
kv_seq_len = k.shape[-1]
|
||||
if past_key_value is not None:
|
||||
if len(past_key_value) != 0:
|
||||
# k = torch.cat([past_key_value[0], k], dim=3)
|
||||
# v = torch.cat([past_key_value[1], v], dim=2)
|
||||
cache_k = past_key_value[0].transpose(2, 3)
|
||||
cache_v = past_key_value[1]
|
||||
kv_seq_len += cache_k.shape[-2]
|
||||
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||
# allocate new
|
||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||
kv_n_heads, # Support GQA
|
||||
head_dim,
|
||||
cache_k.size(2),
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=cache_k.dtype,
|
||||
device=device)
|
||||
new_cache_k[:] = cache_k
|
||||
new_cache_v[:] = cache_v
|
||||
cache_k = new_cache_k
|
||||
cache_v = new_cache_v
|
||||
key_states, value_states = append_kv_cache(cache_k, cache_v, k.transpose(2, 3), v)
|
||||
k = key_states.transpose(2, 3)
|
||||
v = value_states
|
||||
else:
|
||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
||||
kv_n_heads,
|
||||
head_dim,
|
||||
kv_seq_len,
|
||||
max_cache_length,
|
||||
dtype=k.dtype,
|
||||
device=device)
|
||||
new_key_states[:] = k.transpose(2, 3)
|
||||
new_value_states[:] = v
|
||||
k = new_key_states.transpose(2, 3)
|
||||
v = new_value_states
|
||||
past_key_value = (k, v)
|
||||
(b, _, s_q, d) = q.shape
|
||||
s_k = k.size(-1)
|
||||
if softmax_scale is None:
|
||||
softmax_scale = 1 / math.sqrt(d)
|
||||
attn_weight = q.matmul(k) * softmax_scale
|
||||
if attn_bias is not None:
|
||||
_s_q = max(0, attn_bias.size(2) - s_q)
|
||||
_s_k = max(0, attn_bias.size(3) - s_k)
|
||||
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
||||
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k \
|
||||
or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
||||
invalidInputError(False, f'attn_bias (shape: {attn_bias.shape}) '
|
||||
f'is expected to broadcast to shape: {attn_weight.shape}.')
|
||||
attn_weight = attn_weight + attn_bias
|
||||
min_val = torch.finfo(q.dtype).min
|
||||
if key_padding_mask is not None:
|
||||
if attn_bias is not None:
|
||||
warnings.warn('Propogating key_padding_mask to the attention module '
|
||||
+ 'and applying it within the attention module can cause '
|
||||
+ 'unneccessary computation/memory usage. Consider integrating '
|
||||
+ 'into attn_bias once and passing that to each attention '
|
||||
+ 'module instead.')
|
||||
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
||||
if is_causal and (not q.size(2) == 1):
|
||||
s = max(s_q, s_k)
|
||||
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
||||
causal_mask = causal_mask.tril()
|
||||
causal_mask = causal_mask.to(torch.bool)
|
||||
causal_mask = ~causal_mask
|
||||
causal_mask = causal_mask[-s_q:, -s_k:]
|
||||
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
if dropout_p:
|
||||
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p,
|
||||
training=training, inplace=True)
|
||||
out = attn_weight.to(v.dtype).matmul(v)
|
||||
out = rearrange(out, 'b h s d -> b s (h d)')
|
||||
if needs_weights:
|
||||
return (out, attn_weight, past_key_value)
|
||||
return (out, None, past_key_value)
|
||||
Loading…
Reference in a new issue