LLM: Fix Qwen kv_cache optimization (#9148)
* first commit
* UT pass
* accelerate rotate_half by using the common util function
* fix style
This commit is contained in:
parent 69942d3826
commit b8aee7bb1b

1 changed file with 59 additions and 67 deletions
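The "accelerate rotate_half" item in the commit message refers to dropping the file-local, einops-based _rotate_half and reusing the shared rotate_half helper from bigdl.llm.transformers.models.utils. For reference, a minimal self-contained sketch of the usual half-rotation used by rotary embeddings (an assumption about that helper's behaviour, not code copied from the repository):

import torch

def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and swap the halves, negating the
    # second one: (x1, x2) -> (-x2, x1), which is what rotary embeddings need.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)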
@@ -14,7 +14,7 @@
# limitations under the License.
#
# Some parts of this file is adapted from
# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py
# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/faf3ff60438d724a7eb78ebed7e2f7c7330c6bd8/modeling_qwen.py
#
# Copyright (c) Alibaba Cloud.
#
@@ -37,6 +37,7 @@ except ImportError:
    rearrange = None

from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half
from bigdl.llm.utils.common import invalidInputError

apply_rotary_emb_func = None
@@ -48,34 +49,22 @@ logger = logging.get_logger(__name__)
KV_CACHE_ALLOC_BLOCK_LENGTH = 256


def _rotate_half(x):
    from einops import rearrange

    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    if apply_rotary_emb_func is not None:
        t_ = t.float()
        freqs = freqs.squeeze(0).squeeze(1)
        cos = freqs[:, : freqs.shape[-1] // 2].cos()
        sin = freqs[:, : freqs.shape[-1] // 2].sin()
        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
        return output
    else:
        rot_dim = freqs.shape[-1]
        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
        t_ = t_.float()
        t_pass_ = t_pass_.float()
        t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
        return torch.cat((t_, t_pass_), dim=-1).type_as(t)
    cos, sin = freqs
    rot_dim = freqs[0].shape[-1]
    cos, sin = freqs
    t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
    t_ = t_.float()
    t_pass_ = t_pass_.float()
    t_ = (t_ * cos) + (rotate_half(t_) * sin)
    return torch.cat((t_, t_pass_), dim=-1).type_as(t)


def qwen_attention_forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
    registered_causal_mask: Optional[torch.Tensor] = None,
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
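After this hunk, apply_rotary_pos_emb takes freqs as a (cos, sin) pair and rotates only the first rot_dim channels of t, passing the remainder through unchanged. A small self-contained sketch of that behaviour, with illustrative shapes that are assumptions rather than values from the repository:

import torch

def rope_sketch(t, cos, sin):
    # Rotate the first cos.shape[-1] channels of t; leave the tail untouched.
    rot_dim = cos.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    half = rot_dim // 2
    rotated = torch.cat((-t_rot[..., half:], t_rot[..., :half]), dim=-1)
    return torch.cat((t_rot * cos + rotated * sin, t_pass), dim=-1)

q = torch.randn(1, 4, 2, 8)     # [batch, seq, heads, head_dim]
cos = torch.randn(1, 4, 1, 8)   # broadcast over batch and heads
sin = torch.randn(1, 4, 1, 8)
out = rope_sketch(q, cos, sin)  # same shape as q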
@@ -93,42 +82,33 @@ def qwen_attention_forward(
    kv_seq_len = hidden_states.size()[1]

    if layer_past:
        # layer past[0] shape: bs * seq_len * head_num * dim
        kv_seq_len += layer_past[0].shape[1]
    if (
        self.use_dynamic_ntk
        and kv_seq_len == hidden_states.size()[1]
        and not self.training
    ):
        context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
        ntk_alpha = 2 ** math.ceil(context_value) - 1
        ntk_alpha = max(ntk_alpha, 1)
        self._ntk_cached = ntk_alpha
    else:
        ntk_alpha = self._ntk_cached
    rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
        hidden_states.device
    )

    if rotary_pos_emb is not None:
        if isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = rotary_pos_emb
        else:
            rotary_pos_emb = (rotary_pos_emb,) * 2

    if rotary_pos_emb is not None:
        q_pos_emb, k_pos_emb = rotary_pos_emb
        # Slice the pos emb for current inference
    if rotary_pos_emb_list is not None:
        cur_len = query.shape[1]
        q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
        k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
        query = apply_rotary_pos_emb(query, q_pos_emb)
        key = apply_rotary_pos_emb(key, k_pos_emb)
        if len(rotary_pos_emb_list) == 1:
            rotary_pos_emb = rotary_pos_emb_list[0]
            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
            rotary_pos_emb = (rotary_pos_emb,) * 2
            q_pos_emb, k_pos_emb = rotary_pos_emb
            # Slice the pos emb for current inference
            query = apply_rotary_pos_emb(query, q_pos_emb)
            key = apply_rotary_pos_emb(key, k_pos_emb)
        else:
            query_list = []
            key_list = []
            for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
                rotary_pos_emb = (rotary_pos_emb,) * 2
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # Slice the pos emb for current inference
                query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
                key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
            query = torch.cat(query_list, dim=0)
            key = torch.cat(key_list, dim=0)

    bsz, _, n_heads, head_dim = key.size()

    if layer_past is not None:
        kv_seq_len += layer_past[0].shape[1]
        # past_key, past_value = layer_past[0], layer_past[1]
        # key = torch.cat((past_key, key), dim=1)
        # value = torch.cat((past_value, value), dim=1)
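The rewritten block consumes rotary_pos_emb_list (one (cos, sin) pair per request) instead of a single rotary_pos_emb, and the kv_seq_len bookkeeping now adds the cached length from layer_past to the tokens of the current step. A trivial worked example of that bookkeeping (numbers are illustrative):

# 16 prompt tokens already sit in the KV cache; one new token arrives this step.
hidden_len = 1                       # hidden_states.size()[1]
past_len = 16                        # layer_past[0].shape[1]
kv_seq_len = hidden_len + past_len   # 17 cache slots / rotary positions are needed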
@@ -137,7 +117,7 @@ def qwen_attention_forward(
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                       self.num_heads, # Support GQA
                                                       self.num_heads,
                                                       self.head_dim,
                                                       cache_k.size(2),
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
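This is the heart of the kv_cache optimization: instead of concatenating the cache every step, the cache tensor is pre-allocated with KV_CACHE_ALLOC_BLOCK_LENGTH spare positions and only re-allocated when it runs out. A minimal sketch of that pattern, assuming a [batch, heads, seq, head_dim] layout; it is an illustration, not the extend_kv_cache/append_kv_cache implementation:

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

def ensure_capacity(cache: torch.Tensor, kv_seq_len: int) -> torch.Tensor:
    # Grow the cache in fixed blocks so most decoding steps can append in place.
    if cache.size(2) < kv_seq_len:
        new_len = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
        bigger = torch.empty(cache.size(0), cache.size(1), new_len, cache.size(3),
                             dtype=cache.dtype, device=cache.device)
        bigger[:, :, :cache.size(2)] = cache
        return bigger
    return cache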
@@ -172,10 +152,12 @@ def qwen_attention_forward(
        present = None

    if self.use_logn_attn and not self.training:
        if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
            self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
        seq_start = key.size(1) - query.size(1)
        seq_end = key.size(1)
        if self.use_cache_quantization:
            seq_start = key[0].size(2) - query.size(1)
            seq_end = key[0].size(2)
        else:
            seq_start = key.size(1) - query.size(1)
            seq_end = key.size(1)
        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
        query = query * logn_tensor.expand_as(query)
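With a quantized cache, key is a tuple and its length lives in key[0].size(2), so the slice bounds for the logn scaling are computed per branch. A plain-number example of those bounds (values are illustrative):

key_len, query_len = 20, 1       # cached length vs. tokens in this step
seq_start = key_len - query_len  # 19
seq_end = key_len                # 20
# logn_tensor[:, 19:20, :, :] scales only the newly generated query position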
@@ -186,23 +168,33 @@ def qwen_attention_forward(
        and query.is_cuda
    ):
        q, k, v = query, key, value
        context_layer = self.core_attention_flash(q, k, v)
        context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)

        # b s h d -> b s (h d)
        context_layer = context_layer.flatten(2, 3).contiguous()

        context_layer = rearrange(
            context_layer, "b s h d -> b s (h d)"
        ).contiguous()
    else:
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)
        if not self.use_cache_quantization:
            key = key.permute(0, 2, 1, 3)
            value = value.permute(0, 2, 1, 3)
        if (
            registered_causal_mask is None
            and self.use_flash_attn
            and flash_attn_unpadded_func is not None
            and not self.is_fp32
            and not query.is_cuda
        ):
            invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
        attn_output, attn_weight = self._attn(
            query, key, value, attention_mask, head_mask
            query, key, value, registered_causal_mask, attention_mask, head_mask
        )
        context_layer = self._merge_heads(
            attn_output, self.num_heads, self.head_dim
        )

    attn_output = self.c_proj(context_layer)

    outputs = (attn_output, present)
    if output_attentions:
        if (
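The einops rearrange "b s h d -> b s (h d)" is replaced by a plain flatten of the head and head_dim axes, which avoids the einops dependency on this path. A tiny check of the equivalence (shapes are illustrative):

import torch

x = torch.randn(2, 5, 8, 64)   # [batch, seq, heads, head_dim]
merged = x.flatten(2, 3)       # -> [2, 5, 8 * 64], same result as rearrange(x, "b s h d -> b s (h d)")
assert merged.shape == (2, 5, 512)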
@@ -210,7 +202,7 @@ def qwen_attention_forward
            and flash_attn_unpadded_func is not None
            and not self.is_fp32
        ):
            invalidInputError("Cannot output attentions while using flash-attn")
            invalidInputError(False, "Cannot output attentions while using flash-attn")
        else:
            outputs += (attn_weight,)
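The last change passes False as the first argument to invalidInputError. The helper in bigdl.llm.utils.common appears to follow the check-style convention where the first argument is a condition and the message comes second, so the old one-argument call did not match that convention; passing False makes the error fire unconditionally on this path. A stand-in sketch of that assumed convention (not the real helper):

def invalid_input_error_sketch(condition: bool, msg: str) -> None:
    # Raise only when the condition does not hold.
    if not condition:
        raise ValueError(msg)

invalid_input_error_sketch(False, "Cannot output attentions while using flash-attn")  # always raises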