diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 6f05cb59..518c8a5d 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -173,6 +173,15 @@ def optimize(model):
                         module.SelfAttention,
                         chatglm_attention_forward
                         )
+    elif "mpt" in model.config._name_or_path:
+        modeling_module_name = model.__class__.__module__
+        attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
+        module = importlib.import_module(attention_module_name)
+        from bigdl.llm.transformers.models.mpt import mpt_multihead_attention_forward
+        convert_forward(model,
+                        module.MultiheadAttention,
+                        mpt_multihead_attention_forward
+                        )
     elif "gptj" in model.config.model_type:
         # dolly-v1-6b
         modeling_module_name = model.__class__.__module__
@@ -263,5 +272,4 @@ def optimize(model):
                         transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention,
                         gptneox_attention_forward
                         )
-
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/models/mpt.py b/python/llm/src/bigdl/llm/transformers/models/mpt.py
new file mode 100644
index 00000000..fd8e28b7
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/mpt.py
@@ -0,0 +1,149 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py
+#
+
+import warnings
+import torch
+from einops import rearrange
+import math
+import torch.nn.functional as F
+from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
+
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+
+
+def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
+                                    attention_mask=None, is_causal=True, needs_weights=False):
+    qkv = self.Wqkv(x)
+    if self.clip_qkv:
+        qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+    (query, key, value) = qkv.chunk(3, dim=2)
+    key_padding_mask = attention_mask
+    if self.qk_ln:
+        dtype = query.dtype
+        query = self.q_ln(query).to(dtype)
+        key = self.k_ln(key).to(dtype)
+    (context, attn_weights, past_key_value) = \
+        mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
+                                                   past_key_value=past_key_value,
+                                                   softmax_scale=self.softmax_scale,
+                                                   attn_bias=attn_bias,
+                                                   key_padding_mask=key_padding_mask,
+                                                   is_causal=is_causal,
+                                                   dropout_p=self.attn_dropout_p,
+                                                   training=self.training,
+                                                   needs_weights=needs_weights)
+    return (self.out_proj(context), attn_weights, past_key_value)
+
+
+def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
+                                               past_key_value=None,
+                                               softmax_scale=None,
+                                               attn_bias=None,
+                                               key_padding_mask=None,
+                                               is_causal=False,
+                                               dropout_p=0.0,
+                                               training=False,
+                                               needs_weights=False,
+                                               multiquery=False):
+    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
+    bsz, n_heads, q_len, head_dim = q.size()
+    device = q.device
+    kv_n_heads = 1 if multiquery else n_heads
+    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
+    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
+    kv_seq_len = k.shape[-1]
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            # k = torch.cat([past_key_value[0], k], dim=3)
+            # v = torch.cat([past_key_value[1], v], dim=2)
+            cache_k = past_key_value[0].transpose(2, 3)
+            cache_v = past_key_value[1]
+            kv_seq_len += cache_k.shape[-2]
+            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+                # allocate new
+                new_cache_k, new_cache_v = extend_kv_cache(bsz,
+                                                           kv_n_heads,  # Support GQA
+                                                           head_dim,
+                                                           cache_k.size(2),
+                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                           dtype=cache_k.dtype,
+                                                           device=device)
+                new_cache_k[:] = cache_k
+                new_cache_v[:] = cache_v
+                cache_k = new_cache_k
+                cache_v = new_cache_v
+            key_states, value_states = append_kv_cache(cache_k, cache_v, k.transpose(2, 3), v)
+            k = key_states.transpose(2, 3)
+            v = value_states
+        else:
+            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_key_states, new_value_states = init_kv_cache(bsz,
+                                                             kv_n_heads,
+                                                             head_dim,
+                                                             kv_seq_len,
+                                                             max_cache_length,
+                                                             dtype=k.dtype,
+                                                             device=device)
+            new_key_states[:] = k.transpose(2, 3)
+            new_value_states[:] = v
+            k = new_key_states.transpose(2, 3)
+            v = new_value_states
+        past_key_value = (k, v)
+    (b, _, s_q, d) = q.shape
+    s_k = k.size(-1)
+    if softmax_scale is None:
+        softmax_scale = 1 / math.sqrt(d)
+    attn_weight = q.matmul(k) * softmax_scale
+    if attn_bias is not None:
+        _s_q = max(0, attn_bias.size(2) - s_q)
+        _s_k = max(0, attn_bias.size(3) - s_k)
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
+        if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k \
+                or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
+            invalidInputError(False, f'attn_bias (shape: {attn_bias.shape}) '
+                              f'is expected to broadcast to shape: {attn_weight.shape}.')
+        attn_weight = attn_weight + attn_bias
+    min_val = torch.finfo(q.dtype).min
+    if key_padding_mask is not None:
+        if attn_bias is not None:
+            warnings.warn('Propagating key_padding_mask to the attention module '
+                          'and applying it within the attention module can cause '
+                          'unnecessary computation/memory usage. Consider integrating '
+                          'into attn_bias once and passing that to each attention '
+                          'module instead.')
+        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
+    if is_causal and (not q.size(2) == 1):
+        s = max(s_q, s_k)
+        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
+        causal_mask = causal_mask.tril()
+        causal_mask = causal_mask.to(torch.bool)
+        causal_mask = ~causal_mask
+        causal_mask = causal_mask[-s_q:, -s_k:]
+        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    if dropout_p:
+        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p,
+                                                  training=training, inplace=True)
+    out = attn_weight.to(v.dtype).matmul(v)
+    out = rearrange(out, 'b h s d -> b s (h d)')
+    if needs_weights:
+        return (out, attn_weight, past_key_value)
+    return (out, None, past_key_value)
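
Usage note (not part of the patch): the new MPT branch in optimize() is exercised whenever a checkpoint whose _name_or_path contains "mpt" is loaded through bigdl-llm with optimization enabled. A minimal sketch of how that typically happens is below, assuming the default optimize_model=True path of bigdl-llm's 4-bit loading API; the checkpoint id, prompt, and generation settings are illustrative only.

# Minimal sketch: loading an MPT checkpoint with load_in_4bit=True runs the
# conversion in convert.py, which now swaps MultiheadAttention.forward for
# mpt_multihead_attention_forward (and its pre-allocated KV cache handling).
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "mosaicml/mpt-7b-chat"  # illustrative; any MPT checkpoint with "mpt" in its name
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)

prompt = "What is AI?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.inference_mode():
    # Decoding reuses the cache managed by init_kv_cache/append_kv_cache above.
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))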