# ipex-llm/python/llm/src/ipex_llm/transformers/models/mpt.py
# (189 lines, 9.3 KiB, Python)

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py
#
import warnings
import torch
from einops import rearrange
import math
import torch.nn.functional as F
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
import os
# Extra sequence slots added on top of the needed length whenever a KV cache
# is allocated or grown in this module; override through the
# KV_CACHE_ALLOC_BLOCK_LENGTH environment variable (default: 256).
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.getenv("KV_CACHE_ALLOC_BLOCK_LENGTH", "256"))
def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
                                    attention_mask=None, is_causal=True,
                                    needs_weights=False, rotary_emb_w_meta_info=None,
                                    **kwargs):
    """Forward pass for an MPT multi-head attention module.

    Pipeline, in order:

    1. project ``x`` through the fused ``self.Wqkv`` layer;
    2. if ``self.clip_qkv`` is truthy, clamp the projection in place to
       ``[-clip_qkv, clip_qkv]``;
    3. split the projection into three equal chunks along dim 2
       (query, key, value);
    4. if ``self.qk_ln`` is truthy, run ``self.q_ln`` / ``self.k_ln`` on
       query / key and cast the results back to the query's dtype;
    5. if ``rotary_emb_w_meta_info`` is given, apply rotary embeddings;
    6. run ``mpt_scaled_multihead_dot_product_attention``;
    7. project the context through ``self.out_proj``.

    Args:
        x: Input handed to ``self.Wqkv``.
        past_key_value: Forwarded unchanged to the attention routine.
        attn_bias: Forwarded unchanged to the attention routine.
        attention_mask: Forwarded to the attention routine as
            ``key_padding_mask``.
        is_causal: Forwarded unchanged to the attention routine.
        needs_weights: Forwarded unchanged to the attention routine.
        rotary_emb_w_meta_info: Optional mapping with the keys
            ``'rotary_emb'``, ``'seq_len'``, ``'offset_info'`` and ``'impl'``
            (the code handles ``'dail'`` and ``'hf'``).
        **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(self.out_proj(context), attn_weights, past_key_value)``,
        where the last two items are whatever the attention routine returned.
    """
    qkv = self.Wqkv(x)
    if self.clip_qkv:
        qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
    (query, key, value) = qkv.chunk(3, dim=2)
    key_padding_mask = attention_mask

    if self.qk_ln:
        # Layer norm may change the dtype; restore the projection's dtype.
        dtype = query.dtype
        query = self.q_ln(query).to(dtype)
        key = self.k_ln(key).to(dtype)

    if rotary_emb_w_meta_info is not None:
        rotary_emb = rotary_emb_w_meta_info['rotary_emb']
        seq_len = rotary_emb_w_meta_info['seq_len']
        offset_info = rotary_emb_w_meta_info['offset_info']
        bsz, seqlen = query.shape[:2]
        # Rotary embeddings operate per head: split the last dim into
        # (heads, head_dim).
        query = query.view(bsz, seqlen, -1, self.head_dim)
        key = key.view(bsz, seqlen, -1, self.head_dim)

        if rotary_emb_w_meta_info['impl'] == 'dail':
            # This implementation takes key and value stacked along a new
            # dim 2 and hands them back in the same stacked form.
            value = value.view(bsz, seqlen, -1, self.head_dim)
            kv = torch.stack([key, value], dim=2)
            query, kv = rotary_emb(query,
                                   kv,
                                   seqlen_offset=offset_info,
                                   max_seqlen=seq_len)
            [key, value] = torch.unbind(kv, dim=2)
            value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
        elif rotary_emb_w_meta_info['impl'] == 'hf':
            # NOTE(review): `is_transformers_version_gte` and
            # `apply_rotary_pos_emb` are not imported in the visible part of
            # this file - confirm they are in scope before relying on this
            # branch.
            (cos, sin) = rotary_emb(value, seq_len)
            if is_transformers_version_gte('4.36'):
                query, key = apply_rotary_pos_emb(query,
                                                  key,
                                                  cos,
                                                  sin,
                                                  offset_info,
                                                  unsqueeze_dim=2)
            else:
                # Without `unsqueeze_dim` the helper is called on tensors
                # with dims 1 and 2 swapped; swap them back afterwards.
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)
                query, key = apply_rotary_pos_emb(query, key, cos, sin,
                                                  offset_info)
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)

        # The attention routine rearranges its inputs with the 3-D pattern
        # 'b s (h d)', so fold the per-head split back into the last dim.
        # `reshape` is used rather than `view` because the transposes above
        # can leave the tensors non-contiguous.
        query = query.reshape(bsz, seqlen, -1)
        key = key.reshape(bsz, seqlen, -1)

    (context, attn_weights, past_key_value) = \
        mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                   past_key_value=past_key_value,
                                                   softmax_scale=self.softmax_scale,
                                                   attn_bias=attn_bias,
                                                   key_padding_mask=key_padding_mask,
                                                   is_causal=is_causal,
                                                   dropout_p=self.attn_dropout_p,
                                                   training=self.training,
                                                   needs_weights=needs_weights)
    return (self.out_proj(context), attn_weights, past_key_value)
def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
                                               past_key_value=None,
                                               softmax_scale=None,
                                               attn_bias=None,
                                               key_padding_mask=None,
                                               is_causal=False,
                                               dropout_p=0.0,
                                               training=False,
                                               needs_weights=False,
                                               multiquery=False):
    """Scaled dot-product attention with a pre-allocated KV cache.

    ``query``, ``key`` and ``value`` are rearranged from ``'b s (h d)'`` into
    per-head layouts; keys are kept as ``'b h d s'`` so that ``q.matmul(k)``
    yields the attention scores directly.

    Args:
        query: Tensor matching the pattern ``'b s (h d)'`` with
            ``h == n_heads``.
        key: Tensor matching ``'b s (h d)'`` with ``h`` equal to 1 when
            ``multiquery`` is true, otherwise ``n_heads``.
        value: Same layout as ``key``.
        n_heads: Number of query heads.
        past_key_value: ``None`` disables caching. An empty sequence asks for
            a new cache to be created. Otherwise item 0 holds keys in
            ``'b h d s'`` layout and item 1 holds values in ``'b h s d'``.
        softmax_scale: Score multiplier; defaults to ``1 / sqrt(head_dim)``.
        attn_bias: Optional additive bias. It is cropped to its trailing
            ``s_q`` rows and ``s_k`` columns and must then have size 1 or the
            full size in each of those two dims.
        key_padding_mask: Optional mask viewed as ``(b, 1, 1, s_k)``;
            positions where it is false are filled with the dtype minimum.
        is_causal: Apply a causal mask unless the query length is 1.
        dropout_p: Dropout probability applied to the attention weights.
        training: Passed to dropout as its ``training`` flag.
        needs_weights: Return the attention weights as the second item.
        multiquery: Treat ``key`` / ``value`` as having a single head.

    Returns:
        Tuple ``(out, attn_weight or None, past_key_value)`` with ``out`` in
        ``'b s (h d)'`` layout.
    """
    kv_n_heads = 1 if multiquery else n_heads
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)

    bsz, _, s_q, head_dim = q.shape
    device = q.device
    kv_seq_len = k.shape[-1]

    if past_key_value is not None:
        if len(past_key_value) == 0:
            # No cache yet: allocate one with spare room and copy the
            # current keys / values into it.
            capacity = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
            fresh_k, fresh_v = init_kv_cache(bsz,
                                             kv_n_heads,
                                             head_dim,
                                             kv_seq_len,
                                             capacity,
                                             dtype=k.dtype,
                                             device=device)
            fresh_k[:] = k.transpose(2, 3)
            fresh_v[:] = v
            k = fresh_k.transpose(2, 3)
            v = fresh_v
        else:
            # Cached keys are stored as 'b h d s'; the cache helpers are
            # called with both keys and values as 'b h s d'.
            cached_k = past_key_value[0].transpose(2, 3)
            cached_v = past_key_value[1]
            kv_seq_len += cached_k.shape[-2]
            # The per-head stride is compared with the space the longer
            # sequence needs (presumably it reflects the allocated
            # capacity - see the cache helpers); grow the cache when it is
            # too small.
            if cached_k.stride()[1] < kv_seq_len * cached_k.size(3):
                grown_k, grown_v = extend_kv_cache(bsz,
                                                   kv_n_heads,  # Support GQA
                                                   head_dim,
                                                   cached_k.size(2),
                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                   dtype=cached_k.dtype,
                                                   device=device)
                grown_k[:] = cached_k
                grown_v[:] = cached_v
                cached_k, cached_v = grown_k, grown_v
            merged_k, merged_v = append_kv_cache(cached_k, cached_v,
                                                 k.transpose(2, 3), v)
            k = merged_k.transpose(2, 3)
            v = merged_v
        past_key_value = (k, v)

    s_k = k.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(head_dim)
    attn_weight = q.matmul(k) * softmax_scale

    if attn_bias is not None:
        # Keep only the trailing s_q rows / s_k columns of the bias.
        row_start = max(0, attn_bias.size(2) - s_q)
        col_start = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, row_start:, col_start:]
        bad_cols = attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k
        bad_rows = attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q
        if bad_cols or bad_rows:
            invalidInputError(False, f'attn_bias (shape: {attn_bias.shape}) '
                                     f'is expected to broadcast to shape: {attn_weight.shape}.')
        attn_weight = attn_weight + attn_bias

    min_val = torch.finfo(q.dtype).min

    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn('Propogating key_padding_mask to the attention module '
                          'and applying it within the attention module can cause '
                          'unneccessary computation/memory usage. Consider integrating '
                          'into attn_bias once and passing that to each attention '
                          'module instead.')
        padding = ~key_padding_mask.view((bsz, 1, 1, s_k))
        attn_weight = attn_weight.masked_fill(padding, min_val)

    if is_causal and s_q != 1:
        # True above the diagonal marks key positions each query may not
        # attend to; crop to the trailing s_q x s_k corner.
        side = max(s_q, s_k)
        lower = attn_weight.new_ones(side, side, dtype=torch.float16).tril()
        causal_mask = (~lower.to(torch.bool))[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)

    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p,
                                                  training=training, inplace=True)

    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')
    return (out, attn_weight if needs_weights else None, past_key_value)