189 lines
9.3 KiB
Python
189 lines
9.3 KiB
Python
#
|
|
# Copyright 2016 The BigDL Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# Some parts of this file are adapted from
|
|
# https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py
|
|
#
|
|
|
|
import warnings
|
|
import torch
|
|
from einops import rearrange
|
|
import math
|
|
import torch.nn.functional as F
|
|
from ipex_llm.utils.common import invalidInputError
|
|
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
|
|
|
|
import os
|
|
|
|
# Number of spare positions added whenever a KV cache buffer is allocated or
# grown. Overridable through the KV_CACHE_ALLOC_BLOCK_LENGTH environment
# variable, which is read once when this module is imported.
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.getenv("KV_CACHE_ALLOC_BLOCK_LENGTH", "256"))
|
|
|
|
|
|
def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
                                    attention_mask=None, is_causal=True,
                                    needs_weights=False, rotary_emb_w_meta_info=None,
                                    **kwargs):
    """Forward pass for an MPT multi-head attention module with ipex-llm KV caching.

    Projects ``x`` to query/key/value, optionally clips them, applies QK
    layer-norm and rotary embeddings, then delegates to
    ``mpt_scaled_multihead_dot_product_attention`` and applies ``out_proj``.

    Args:
        self: the attention module. Reads ``Wqkv``, ``clip_qkv``, ``qk_ln``,
            ``q_ln``/``k_ln``, ``head_dim``, ``n_heads``, ``softmax_scale``,
            ``attn_dropout_p``, ``training``, ``out_proj`` and (on the 'dail'
            rotary path only) ``kv_n_heads``.
        x: input hidden states. ``Wqkv(x)`` is split into three equal chunks
            along dim 2 (presumably (batch, seq, 3 * d_model) -- confirm).
        past_key_value: forwarded to the attention kernel. ``None`` disables
            caching; an empty sequence starts a new cache.
        attn_bias: additive bias added to the attention scores.
        attention_mask: used as the key padding mask; positions where it is
            False are masked out.
        is_causal: request a causal mask.
        needs_weights: also return the attention probabilities.
        rotary_emb_w_meta_info: optional dict with keys 'rotary_emb',
            'seq_len', 'offset_info' and 'impl' ('dail' or 'hf').
        **kwargs: accepted and ignored.

    Returns:
        ``(out_proj(context), attn_weights or None, past_key_value)``.
    """
    qkv = self.Wqkv(x)
    if self.clip_qkv:
        qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
    (query, key, value) = qkv.chunk(3, dim=2)
    key_padding_mask = attention_mask
    if self.qk_ln:
        # Cast the normalized q/k back to their pre-norm dtype.
        dtype = query.dtype
        query = self.q_ln(query).to(dtype)
        key = self.k_ln(key).to(dtype)

    if rotary_emb_w_meta_info is not None:
        rotary_emb = rotary_emb_w_meta_info['rotary_emb']
        seq_len = rotary_emb_w_meta_info['seq_len']
        offset_info = rotary_emb_w_meta_info['offset_info']
        bsz, seqlen = query.shape[:2]
        # Split the hidden dim into heads: (b, s, h, head_dim).
        query = query.view(bsz, seqlen, -1, self.head_dim)
        key = key.view(bsz, seqlen, -1, self.head_dim)

        if rotary_emb_w_meta_info['impl'] == 'dail':
            value = value.view(bsz, seqlen, -1, self.head_dim)
            kv = torch.stack([key, value], dim=2)
            query, kv = rotary_emb(query,
                                   kv,
                                   seqlen_offset=offset_info,
                                   max_seqlen=seq_len)
            [key, value] = torch.unbind(kv, dim=2)
            value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
        elif rotary_emb_w_meta_info['impl'] == 'hf':
            # NOTE(review): `is_transformers_version_gte` and `apply_rotary_pos_emb`
            # are not imported or defined in this module, so this branch raises
            # NameError as written -- confirm where they should come from.
            (cos, sin) = rotary_emb(value, seq_len)
            if is_transformers_version_gte('4.36'):
                query, key = apply_rotary_pos_emb(query,
                                                  key,
                                                  cos,
                                                  sin,
                                                  offset_info,
                                                  unsqueeze_dim=2)
            else:
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)
                query, key = apply_rotary_pos_emb(query, key, cos, sin,
                                                  offset_info)
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)

        # Merge the heads back into the hidden dim. The attention kernel uses
        # rearrange('b s (h d) -> ...'), which rejects 4-D input. Use reshape
        # rather than view because the transposes above may leave the
        # tensors non-contiguous.
        query = query.reshape(bsz, seqlen, -1)
        key = key.reshape(bsz, seqlen, -1)

    (context, attn_weights, past_key_value) = \
        mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                   past_key_value=past_key_value,
                                                   softmax_scale=self.softmax_scale,
                                                   attn_bias=attn_bias,
                                                   key_padding_mask=key_padding_mask,
                                                   is_causal=is_causal,
                                                   dropout_p=self.attn_dropout_p,
                                                   training=self.training,
                                                   needs_weights=needs_weights)
    return (self.out_proj(context), attn_weights, past_key_value)
|
|
|
|
|
|
def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
                                               past_key_value=None,
                                               softmax_scale=None,
                                               attn_bias=None,
                                               key_padding_mask=None,
                                               is_causal=False,
                                               dropout_p=0.0,
                                               training=False,
                                               needs_weights=False,
                                               multiquery=False):
    """Scaled dot-product attention with bias, masks and a pre-allocated KV cache.

    Args:
        query: (batch, s_q, n_heads * head_dim), split per the rearrange pattern.
        key, value: (batch, s_k, kv_n_heads * head_dim).
        n_heads: number of query heads.
        past_key_value: ``None`` disables caching. An empty sequence starts a
            new cache. Otherwise it is the ``(k, v)`` tuple returned by a
            previous call, with k as (b, h, head_dim, s) and v as
            (b, h, s, head_dim).
        softmax_scale: score scale; defaults to ``1 / sqrt(head_dim)``.
        attn_bias: additive bias. It is sliced to its trailing (s_q, s_k)
            window and must broadcast to the (b, h, s_q, s_k) scores.
        key_padding_mask: boolean (batch, s_k) mask; False positions are
            filled with the dtype's minimum before softmax.
        is_causal: apply a bottom-right-aligned causal mask when s_q != 1.
        dropout_p, training: dropout applied to the attention probabilities.
        needs_weights: also return the attention probabilities.
        multiquery: use a single KV head shared by all query heads.

    Returns:
        ``(out, attn_weight or None, past_key_value)``, where ``out`` is
        (batch, s_q, n_heads * head_dim).

    Reports via ``invalidInputError`` when ``attn_bias`` cannot broadcast.
    """
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    # Multi-query attention shares one KV head across all query heads.
    kv_n_heads = 1 if multiquery else n_heads
    # Keys are laid out as (b, h, d, s) so q.matmul(k) directly yields the
    # (b, h, s_q, s_k) score matrix.
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
    (b, _, s_q, d) = q.shape

    if past_key_value is not None:
        k, v = _mpt_update_kv_cache(k, v, past_key_value,
                                    b, kv_n_heads, d, q.device)
        past_key_value = (k, v)

    s_k = k.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
    attn_weight = q.matmul(k) * softmax_scale

    if attn_bias is not None:
        # Keep only the trailing (s_q, s_k) window of a possibly larger bias.
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if (attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k) \
                or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
            invalidInputError(False, f'attn_bias (shape: {attn_bias.shape}) '
                                     f'is expected to broadcast to shape: {attn_weight.shape}.')
        attn_weight = attn_weight + attn_bias

    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn('Propogating key_padding_mask to the attention module '
                          + 'and applying it within the attention module can cause '
                          + 'unneccessary computation/memory usage. Consider integrating '
                          + 'into attn_bias once and passing that to each attention '
                          + 'module instead.')
        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)

    # With a single query row the bottom-right slice of the causal mask is all
    # False, so the mask is skipped in that case.
    if is_causal and s_q != 1:
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = ~causal_mask.tril().to(torch.bool)
        # Align to the bottom-right corner so the last query sees all s_k keys.
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)

    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p,
                                                  training=training, inplace=True)
    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')
    return (out, attn_weight if needs_weights else None, past_key_value)


def _mpt_update_kv_cache(k, v, past_key_value, bsz, kv_n_heads, head_dim, device):
    """Merge this step's keys/values with ``past_key_value`` in a pre-allocated cache.

    ``k`` is (b, h, head_dim, s) and ``v`` is (b, h, s, head_dim); the returned
    full keys/values use the same layouts.

    An empty ``past_key_value`` starts a new buffer with
    ``KV_CACHE_ALLOC_BLOCK_LENGTH`` spare positions. Otherwise the new entries
    are appended. If the cached key buffer has no room, it is first
    re-allocated, again with ``KV_CACHE_ALLOC_BLOCK_LENGTH`` spare positions.
    """
    kv_seq_len = k.shape[-1]
    if len(past_key_value) == 0:
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(bsz,
                                                         kv_n_heads,
                                                         head_dim,
                                                         kv_seq_len,
                                                         max_cache_length,
                                                         dtype=k.dtype,
                                                         device=device)
        new_key_states[:] = k.transpose(2, 3)
        new_value_states[:] = v
        return new_key_states.transpose(2, 3), new_value_states

    # Cached keys are kept as (b, h, d, s); bring them back to (b, h, s, d).
    cache_k = past_key_value[0].transpose(2, 3)
    cache_v = past_key_value[1]
    kv_seq_len += cache_k.shape[-2]
    # Assumes the buffer's head stride equals (allocated positions * head_dim)
    # -- TODO confirm against init_kv_cache/extend_kv_cache. Grow the buffer
    # when it cannot hold kv_seq_len positions.
    if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
        new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                   kv_n_heads,  # 1 for multi-query, else n_heads
                                                   head_dim,
                                                   cache_k.size(2),
                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                   dtype=cache_k.dtype,
                                                   device=device)
        new_cache_k[:] = cache_k
        new_cache_v[:] = cache_v
        cache_k, cache_v = new_cache_k, new_cache_v

    key_states, value_states = append_kv_cache(cache_k, cache_v, k.transpose(2, 3), v)
    return key_states.transpose(2, 3), value_states
|