Optimize MPT model for long sequences (#9020)
* mpt_long_seq
* update
* update
* update
* style
* style2
* update

parent 9126abdf9b
commit 028a6d9383

2 changed files with 158 additions and 1 deletion
@@ -173,6 +173,15 @@ def optimize(model):
                            module.SelfAttention,
                            chatglm_attention_forward
                            )
    elif "mpt" in model.config._name_or_path:
        modeling_module_name = model.__class__.__module__
        attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
        module = importlib.import_module(attention_module_name)
        from bigdl.llm.transformers.models.mpt import mpt_multihead_attention_forward
        convert_forward(model,
                        module.MultiheadAttention,
                        mpt_multihead_attention_forward
                        )
    elif "gptj" in model.config.model_type:
        # dolly-v1-6b
        modeling_module_name = model.__class__.__module__
@@ -263,5 +272,4 @@ def optimize(model):
                        transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention,
                        gptneox_attention_forward
                        )
    return model
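The hunk above only registers mpt_multihead_attention_forward; the convert_forward helper itself lives elsewhere in bigdl.llm.transformers.convert and is not part of this diff. As a rough sketch only, assuming convert_forward does nothing more than rebind forward on every submodule of the target class (the name patch_forward below is hypothetical, not the repo's implementation):

import types
import torch.nn as nn

def patch_forward(model: nn.Module, target_class, new_forward):
    # Walk the model and rebind `forward` on every instance of target_class,
    # so the replacement runs with `self` bound to that submodule.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)
    return model

# e.g. patch_forward(model, module.MultiheadAttention, mpt_multihead_attention_forward)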
python/llm/src/bigdl/llm/transformers/models/mpt.py (new file, 149 lines)

@@ -0,0 +1,149 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py
#

import warnings
import torch
from einops import rearrange
import math
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache


KV_CACHE_ALLOC_BLOCK_LENGTH = 256


def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
                                    attention_mask=None, is_causal=True, needs_weights=False):
    qkv = self.Wqkv(x)
    if self.clip_qkv:
        qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
    (query, key, value) = qkv.chunk(3, dim=2)
    key_padding_mask = attention_mask
    if self.qk_ln:
        dtype = query.dtype
        query = self.q_ln(query).to(dtype)
        key = self.k_ln(key).to(dtype)
    (context, attn_weights, past_key_value) = \
        mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                   past_key_value=past_key_value,
                                                   softmax_scale=self.softmax_scale,
                                                   attn_bias=attn_bias,
                                                   key_padding_mask=key_padding_mask,
                                                   is_causal=is_causal,
                                                   dropout_p=self.attn_dropout_p,
                                                   training=self.training,
                                                   needs_weights=needs_weights)
    return (self.out_proj(context), attn_weights, past_key_value)


def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
                                               past_key_value=None,
                                               softmax_scale=None,
                                               attn_bias=None,
                                               key_padding_mask=None,
                                               is_causal=False,
                                               dropout_p=0.0,
                                               training=False,
                                               needs_weights=False,
                                               multiquery=False):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    bsz, n_heads, q_len, head_dim = q.size()
    device = q.device
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
    kv_seq_len = k.shape[-1]
    if past_key_value is not None:
        if len(past_key_value) != 0:
            # k = torch.cat([past_key_value[0], k], dim=3)
            # v = torch.cat([past_key_value[1], v], dim=2)
            cache_k = past_key_value[0].transpose(2, 3)
            cache_v = past_key_value[1]
            kv_seq_len += cache_k.shape[-2]
            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
                # allocate new
                new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                           kv_n_heads,  # Support GQA
                                                           head_dim,
                                                           cache_k.size(2),
                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                           dtype=cache_k.dtype,
                                                           device=device)
                new_cache_k[:] = cache_k
                new_cache_v[:] = cache_v
                cache_k = new_cache_k
                cache_v = new_cache_v
            key_states, value_states = append_kv_cache(cache_k, cache_v, k.transpose(2, 3), v)
            k = key_states.transpose(2, 3)
            v = value_states
        else:
            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
            new_key_states, new_value_states = init_kv_cache(bsz,
                                                             kv_n_heads,
                                                             head_dim,
                                                             kv_seq_len,
                                                             max_cache_length,
                                                             dtype=k.dtype,
                                                             device=device)
            new_key_states[:] = k.transpose(2, 3)
            new_value_states[:] = v
            k = new_key_states.transpose(2, 3)
            v = new_value_states
        past_key_value = (k, v)
    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
    attn_weight = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k \
                or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
            invalidInputError(False, f'attn_bias (shape: {attn_bias.shape}) '
                              f'is expected to broadcast to shape: {attn_weight.shape}.')
        attn_weight = attn_weight + attn_bias
    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn('Propagating key_padding_mask to the attention module '
                          + 'and applying it within the attention module can cause '
                          + 'unnecessary computation/memory usage. Consider integrating '
                          + 'into attn_bias once and passing that to each attention '
                          + 'module instead.')
        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
    if is_causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p,
                                                  training=training, inplace=True)
    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')
    if needs_weights:
        return (out, attn_weight, past_key_value)
    return (out, None, past_key_value)
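The core of the long-sequence optimization is visible in the commented-out torch.cat pair: instead of concatenating the key/value cache on every decoding step, the cache is preallocated and grown in KV_CACHE_ALLOC_BLOCK_LENGTH-sized blocks via init_kv_cache / extend_kv_cache / append_kv_cache from bigdl.llm.transformers.models.utils, which are not shown in this diff. The following is a self-contained sketch of the same idea; the helper names are illustrative, not the real utils:

import torch

BLOCK = 256  # plays the role of KV_CACHE_ALLOC_BLOCK_LENGTH

def init_cache_sketch(bsz, n_heads, head_dim, cur_len, max_len, dtype, device):
    # Allocate max_len slots up front but expose only the first cur_len of them.
    storage = torch.empty(bsz, n_heads, max_len, head_dim, dtype=dtype, device=device)
    return storage[:, :, :cur_len, :]

def append_sketch(cache, new_states):
    # Widen the narrowed view over the same storage and copy the new tokens in,
    # instead of torch.cat (which reallocates and copies the whole cache).
    bsz, n_heads, cur_len, head_dim = cache.shape
    new_len = cur_len + new_states.size(2)
    widened = cache.as_strided((bsz, n_heads, new_len, head_dim), cache.stride())
    widened[:, :, cur_len:, :] = new_states
    return widened

# Decoding loop: one new token's K (or V) per step; no reallocation happens
# until the reserved BLOCK slots are exhausted.
bsz, n_heads, head_dim = 1, 32, 128
cache = init_cache_sketch(bsz, n_heads, head_dim, cur_len=0, max_len=BLOCK,
                          dtype=torch.float32, device='cpu')
for _ in range(16):
    cache = append_sketch(cache, torch.randn(bsz, n_heads, 1, head_dim))
print(cache.shape)  # torch.Size([1, 32, 16, 128])

Read this way, the `cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)` check in the diff is the "is the reserved block full?" test: once the stride along the head dimension no longer exceeds the visible sequence length times head_dim, there is no spare capacity left and a larger buffer must be allocated via extend_kv_cache.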
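Finally, a usage sketch (not part of the commit) of how the patched forward gets exercised end to end, assuming the default optimize_model behavior of bigdl.llm's from_pretrained; the model path and generation settings are placeholders:

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "mosaicml/mpt-7b-chat"  # placeholder; any MPT checkpoint or local path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Loading through bigdl.llm runs the optimize() pass shown above, so
# MultiheadAttention.forward is replaced by mpt_multihead_attention_forward.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)

prompt = "Summarize the history of deep learning in a few paragraphs."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.inference_mode():
    # A long generation hits the KV-cache path on every step; the cache grows in
    # KV_CACHE_ALLOC_BLOCK_LENGTH-sized blocks rather than concatenating per token.
    output = model.generate(input_ids, max_new_tokens=512)
print(tokenizer.decode(output[0], skip_special_tokens=True))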