Optimize MPT model for long sequences (#9020)

* mpt_long_seq

* update

* update

* update

* style

* style2

* update
Jiao Wang 2023-09-21 21:27:23 -07:00 committed by GitHub
parent 9126abdf9b
commit 028a6d9383
2 changed files with 158 additions and 1 deletion


@@ -173,6 +173,15 @@ def optimize(model):
                        module.SelfAttention,
                        chatglm_attention_forward
                        )
    elif "mpt" in model.config._name_or_path:
        modeling_module_name = model.__class__.__module__
        attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
        module = importlib.import_module(attention_module_name)
        from bigdl.llm.transformers.models.mpt import mpt_multihead_attention_forward
        convert_forward(model,
                        module.MultiheadAttention,
                        mpt_multihead_attention_forward
                        )
    elif "gptj" in model.config.model_type:
        # dolly-v1-6b
        modeling_module_name = model.__class__.__module__
@@ -263,5 +272,4 @@ def optimize(model):
                        transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention,
                        gptneox_attention_forward
                        )
    return model
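For context, a minimal usage sketch (not part of this commit; the checkpoint id, prompt, and generation settings below are illustrative): loading an MPT checkpoint through bigdl.llm's AutoModelForCausalLM with load_in_4bit=True runs this optimize() hook on the converted model, so the "mpt" branch above should swap in mpt_multihead_attention_forward automatically.

# illustrative sketch only -- model id, prompt and max_new_tokens are placeholders
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "mosaicml/mpt-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# the 4-bit conversion path also applies the attention forwards registered in optimize()
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))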


@@ -0,0 +1,149 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py
#
import warnings
import torch
from einops import rearrange
import math
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache

KV_CACHE_ALLOC_BLOCK_LENGTH = 256


def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
                                    attention_mask=None, is_causal=True, needs_weights=False):
    qkv = self.Wqkv(x)
    if self.clip_qkv:
        qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
    (query, key, value) = qkv.chunk(3, dim=2)
    key_padding_mask = attention_mask
    if self.qk_ln:
        dtype = query.dtype
        query = self.q_ln(query).to(dtype)
        key = self.k_ln(key).to(dtype)
    (context, attn_weights, past_key_value) = \
        mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                   past_key_value=past_key_value,
                                                   softmax_scale=self.softmax_scale,
                                                   attn_bias=attn_bias,
                                                   key_padding_mask=key_padding_mask,
                                                   is_causal=is_causal,
                                                   dropout_p=self.attn_dropout_p,
                                                   training=self.training,
                                                   needs_weights=needs_weights)
    return (self.out_proj(context), attn_weights, past_key_value)


def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
                                               past_key_value=None,
                                               softmax_scale=None,
                                               attn_bias=None,
                                               key_padding_mask=None,
                                               is_causal=False,
                                               dropout_p=0.0,
                                               training=False,
                                               needs_weights=False,
                                               multiquery=False):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    bsz, n_heads, q_len, head_dim = q.size()
    device = q.device
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)

    # Grow the KV cache in preallocated blocks instead of torch.cat at every step.
    kv_seq_len = k.shape[-1]
    if past_key_value is not None:
        if len(past_key_value) != 0:
            # k = torch.cat([past_key_value[0], k], dim=3)
            # v = torch.cat([past_key_value[1], v], dim=2)
            cache_k = past_key_value[0].transpose(2, 3)
            cache_v = past_key_value[1]
            kv_seq_len += cache_k.shape[-2]
            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
                # no spare capacity left in the cache buffer:
                # allocate new, larger buffers and copy the old cache over
                new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                           kv_n_heads,  # Support GQA
                                                           head_dim,
                                                           cache_k.size(2),
                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                           dtype=cache_k.dtype,
                                                           device=device)
                new_cache_k[:] = cache_k
                new_cache_v[:] = cache_v
                cache_k = new_cache_k
                cache_v = new_cache_v
            key_states, value_states = append_kv_cache(cache_k, cache_v, k.transpose(2, 3), v)
            k = key_states.transpose(2, 3)
            v = value_states
        else:
            # first call with an empty cache: preallocate with extra headroom
            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
            new_key_states, new_value_states = init_kv_cache(bsz,
                                                             kv_n_heads,
                                                             head_dim,
                                                             kv_seq_len,
                                                             max_cache_length,
                                                             dtype=k.dtype,
                                                             device=device)
            new_key_states[:] = k.transpose(2, 3)
            new_value_states[:] = v
            k = new_key_states.transpose(2, 3)
            v = new_value_states
        past_key_value = (k, v)

    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
    attn_weight = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if (attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k) \
                or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
            invalidInputError(False, f'attn_bias (shape: {attn_bias.shape}) '
                                     f'is expected to broadcast to shape: {attn_weight.shape}.')
        attn_weight = attn_weight + attn_bias
    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn('Propagating key_padding_mask to the attention module '
                          + 'and applying it within the attention module can cause '
                          + 'unnecessary computation/memory usage. Consider integrating '
                          + 'into attn_bias once and passing that to each attention '
                          + 'module instead.')
        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
    if is_causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p,
                                                  training=training, inplace=True)
    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')
    if needs_weights:
        return (out, attn_weight, past_key_value)
    return (out, None, past_key_value)
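
The long-sequence benefit in this file comes from treating the KV cache as a buffer preallocated with KV_CACHE_ALLOC_BLOCK_LENGTH positions of headroom and appending new tokens in place, instead of the torch.cat per decoding step shown in the commented-out lines. Below is a simplified, standalone sketch of that idea in plain torch; it only mirrors the roles of init_kv_cache / extend_kv_cache / append_kv_cache, is not the actual bigdl helpers, and the toy shapes are made up.

import torch

ALLOC_BLOCK = 256  # mirrors KV_CACHE_ALLOC_BLOCK_LENGTH


def init_cache(bsz, n_heads, head_dim, cur_len, max_len, dtype, device):
    # allocate max_len positions up front; expose only the first cur_len as a view
    storage = torch.empty(bsz, n_heads, max_len, head_dim, dtype=dtype, device=device)
    return storage, storage[:, :, :cur_len, :]


def append(storage, view, new_kv):
    # write new tokens into the spare capacity instead of torch.cat,
    # so the existing cache is never copied while headroom remains
    cur, new = view.size(2), new_kv.size(2)
    storage[:, :, cur:cur + new, :] = new_kv
    return storage[:, :, :cur + new, :]


# toy decode loop: one big allocation, then cheap in-place appends
bsz, n_heads, head_dim = 1, 32, 128
prompt_kv = torch.randn(bsz, n_heads, 17, head_dim)
storage, cache = init_cache(bsz, n_heads, head_dim, 0, 17 + ALLOC_BLOCK,
                            prompt_kv.dtype, prompt_kv.device)
cache = append(storage, cache, prompt_kv)        # prefill
for _ in range(8):                               # generation steps
    cache = append(storage, cache, torch.randn(bsz, n_heads, 1, head_dim))
print(cache.shape)  # torch.Size([1, 32, 25, 128])

The stride check in the real code (cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)) is what detects that such headroom has run out, at which point extend_kv_cache allocates a bigger block and the old cache is copied once.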