LLM: Fix Qwen kv_cache optimization (#9148)

* first commit

* ut pass

* accelerate rotate half by using common util function

* fix style
This commit is contained in:
Ruonan Wang 2023-10-12 15:49:42 +08:00 committed by GitHub
parent 69942d3826
commit b8aee7bb1b

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
# #
# Some parts of this file is adapted from # Some parts of this file is adapted from
# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py # https://huggingface.co/Qwen/Qwen-7B-Chat/blob/faf3ff60438d724a7eb78ebed7e2f7c7330c6bd8/modeling_qwen.py
# #
# Copyright (c) Alibaba Cloud. # Copyright (c) Alibaba Cloud.
# #
@ -37,6 +37,7 @@ except ImportError:
rearrange = None rearrange = None
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half
from bigdl.llm.utils.common import invalidInputError from bigdl.llm.utils.common import invalidInputError
apply_rotary_emb_func = None apply_rotary_emb_func = None
@ -48,34 +49,22 @@ logger = logging.get_logger(__name__)
KV_CACHE_ALLOC_BLOCK_LENGTH = 256 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def _rotate_half(x):
from einops import rearrange
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs): def apply_rotary_pos_emb(t, freqs):
if apply_rotary_emb_func is not None: cos, sin = freqs
t_ = t.float() rot_dim = freqs[0].shape[-1]
freqs = freqs.squeeze(0).squeeze(1) cos, sin = freqs
cos = freqs[:, : freqs.shape[-1] // 2].cos()
sin = freqs[:, : freqs.shape[-1] // 2].sin()
output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
return output
else:
rot_dim = freqs.shape[-1]
t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
t_ = t_.float() t_ = t_.float()
t_pass_ = t_pass_.float() t_pass_ = t_pass_.float()
t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin()) t_ = (t_ * cos) + (rotate_half(t_) * sin)
return torch.cat((t_, t_pass_), dim=-1).type_as(t) return torch.cat((t_, t_pass_), dim=-1).type_as(t)
def qwen_attention_forward( def qwen_attention_forward(
self, self,
hidden_states: Optional[Tuple[torch.FloatTensor]], hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
registered_causal_mask: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
@ -93,42 +82,33 @@ def qwen_attention_forward(
kv_seq_len = hidden_states.size()[1] kv_seq_len = hidden_states.size()[1]
if layer_past: if rotary_pos_emb_list is not None:
# layer past[0] shape: bs * seq_len * head_num * dim cur_len = query.shape[1]
kv_seq_len += layer_past[0].shape[1] if len(rotary_pos_emb_list) == 1:
if ( rotary_pos_emb = rotary_pos_emb_list[0]
self.use_dynamic_ntk rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
and kv_seq_len == hidden_states.size()[1]
and not self.training
):
context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
ntk_alpha = max(ntk_alpha, 1)
self._ntk_cached = ntk_alpha
else:
ntk_alpha = self._ntk_cached
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
hidden_states.device
)
if rotary_pos_emb is not None:
if isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = rotary_pos_emb
else:
rotary_pos_emb = (rotary_pos_emb,) * 2 rotary_pos_emb = (rotary_pos_emb,) * 2
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference # Slice the pos emb for current inference
cur_len = query.shape[1]
q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
query = apply_rotary_pos_emb(query, q_pos_emb) query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb) key = apply_rotary_pos_emb(key, k_pos_emb)
else:
query_list = []
key_list = []
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
bsz, _, n_heads, head_dim = key.size() bsz, _, n_heads, head_dim = key.size()
if layer_past is not None: if layer_past is not None:
kv_seq_len += layer_past[0].shape[1]
# past_key, past_value = layer_past[0], layer_past[1] # past_key, past_value = layer_past[0], layer_past[1]
# key = torch.cat((past_key, key), dim=1) # key = torch.cat((past_key, key), dim=1)
# value = torch.cat((past_value, value), dim=1) # value = torch.cat((past_value, value), dim=1)
@ -137,7 +117,7 @@ def qwen_attention_forward(
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3): if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new # allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz, new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads, # Support GQA self.num_heads,
self.head_dim, self.head_dim,
cache_k.size(2), cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
@ -172,8 +152,10 @@ def qwen_attention_forward(
present = None present = None
if self.use_logn_attn and not self.training: if self.use_logn_attn and not self.training:
if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype: if self.use_cache_quantization:
self.logn_tensor = self.logn_tensor.to(query.device).type_as(query) seq_start = key[0].size(2) - query.size(1)
seq_end = key[0].size(2)
else:
seq_start = key.size(1) - query.size(1) seq_start = key.size(1) - query.size(1)
seq_end = key.size(1) seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :] logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
@ -186,23 +168,33 @@ def qwen_attention_forward(
and query.is_cuda and query.is_cuda
): ):
q, k, v = query, key, value q, k, v = query, key, value
context_layer = self.core_attention_flash(q, k, v) context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
# b s h d -> b s (h d)
context_layer = context_layer.flatten(2, 3).contiguous()
context_layer = rearrange(
context_layer, "b s h d -> b s (h d)"
).contiguous()
else: else:
query = query.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3)
if not self.use_cache_quantization:
key = key.permute(0, 2, 1, 3) key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3) value = value.permute(0, 2, 1, 3)
if (
registered_causal_mask is None
and self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
and not query.is_cuda
):
invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
attn_output, attn_weight = self._attn( attn_output, attn_weight = self._attn(
query, key, value, attention_mask, head_mask query, key, value, registered_causal_mask, attention_mask, head_mask
) )
context_layer = self._merge_heads( context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim attn_output, self.num_heads, self.head_dim
) )
attn_output = self.c_proj(context_layer) attn_output = self.c_proj(context_layer)
outputs = (attn_output, present) outputs = (attn_output, present)
if output_attentions: if output_attentions:
if ( if (
@ -210,7 +202,7 @@ def qwen_attention_forward(
and flash_attn_unpadded_func is not None and flash_attn_unpadded_func is not None
and not self.is_fp32 and not self.is_fp32
): ):
invalidInputError("Cannot output attentions while using flash-attn") invalidInputError(False, "Cannot output attentions while using flash-attn")
else: else:
outputs += (attn_weight,) outputs += (attn_weight,)