LLM: Fix Qwen kv_cache optimization (#9148)
* first commit
* ut pass
* accelerate rotate half by using common util function
* fix style
parent 69942d3826
commit b8aee7bb1b
1 changed file with 59 additions and 67 deletions
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py
+# https://huggingface.co/Qwen/Qwen-7B-Chat/blob/faf3ff60438d724a7eb78ebed7e2f7c7330c6bd8/modeling_qwen.py
 #
 # Copyright (c) Alibaba Cloud.
 #
@@ -37,6 +37,7 @@ except ImportError:
     rearrange = None

 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half
 from bigdl.llm.utils.common import invalidInputError

 apply_rotary_emb_func = None
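The cache helpers imported above (init_kv_cache, extend_kv_cache, append_kv_cache) are what the commit title refers to: rather than concatenating a fresh key/value tensor on every decoding step, the cache appears to be preallocated with KV_CACHE_ALLOC_BLOCK_LENGTH tokens of headroom and new entries are written into the spare capacity. The sketch below is a simplified stand-in under that reading; the helper names and the (batch, heads, seq, head_dim) layout are assumptions for illustration, not the BigDL implementation.

import torch

# Hypothetical stand-ins for the imported cache helpers (illustration only).
# Assumed cache layout: (batch, num_heads, seq_len, head_dim).
def init_cache(bsz, n_head, head_dim, cur_len, max_len, dtype, device):
    # Reserve storage for max_len tokens up front; expose only the first cur_len.
    buf_k = torch.empty(bsz, n_head, max_len, head_dim, dtype=dtype, device=device)
    buf_v = torch.empty(bsz, n_head, max_len, head_dim, dtype=dtype, device=device)
    return buf_k.narrow(2, 0, cur_len), buf_v.narrow(2, 0, cur_len)

def append_cache(cache_k, cache_v, new_k, new_v):
    # Grow the view over the same storage and copy the new tokens in place,
    # instead of torch.cat-ing a brand-new tensor every step.
    new_len = cache_k.size(2) + new_k.size(2)
    cache_k = cache_k.as_strided((cache_k.size(0), cache_k.size(1), new_len,
                                  cache_k.size(3)), cache_k.stride())
    cache_v = cache_v.as_strided((cache_v.size(0), cache_v.size(1), new_len,
                                  cache_v.size(3)), cache_v.stride())
    cache_k[:, :, -new_k.size(2):, :] = new_k
    cache_v[:, :, -new_v.size(2):, :] = new_v
    return cache_k, cache_v

Once the exposed view catches up with the reserved length, a larger buffer must be allocated and the old contents copied over, which is what the extend_kv_cache call further down appears to handle.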
@@ -48,34 +49,22 @@ logger = logging.get_logger(__name__)
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256


-def _rotate_half(x):
-    from einops import rearrange
-
-    x = rearrange(x, "... (j d) -> ... j d", j=2)
-    x1, x2 = x.unbind(dim=-2)
-    return torch.cat((-x2, x1), dim=-1)
-
-
 def apply_rotary_pos_emb(t, freqs):
-    if apply_rotary_emb_func is not None:
-        t_ = t.float()
-        freqs = freqs.squeeze(0).squeeze(1)
-        cos = freqs[:, : freqs.shape[-1] // 2].cos()
-        sin = freqs[:, : freqs.shape[-1] // 2].sin()
-        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
-        return output
-    else:
-        rot_dim = freqs.shape[-1]
-        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
-        t_ = t_.float()
-        t_pass_ = t_pass_.float()
-        t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
-        return torch.cat((t_, t_pass_), dim=-1).type_as(t)
+    cos, sin = freqs
+    rot_dim = freqs[0].shape[-1]
+    cos, sin = freqs
+    t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
+    t_ = t_.float()
+    t_pass_ = t_pass_.float()
+    t_ = (t_ * cos) + (rotate_half(t_) * sin)
+    return torch.cat((t_, t_pass_), dim=-1).type_as(t)


 def qwen_attention_forward(
     self,
     hidden_states: Optional[Tuple[torch.FloatTensor]],
+    rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+    registered_causal_mask: Optional[torch.Tensor] = None,
     layer_past: Optional[Tuple[torch.Tensor]] = None,
     attention_mask: Optional[torch.FloatTensor] = None,
     head_mask: Optional[torch.FloatTensor] = None,
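Per the commit message, the einops-based _rotate_half above is replaced by the shared rotate_half utility imported earlier. That utility is presumably the usual slicing-based half rotation; the snippet below is a reference reconstruction of the equivalence (illustration only, not the BigDL source, and the cross-check requires einops).

import torch
from einops import rearrange  # only needed for the cross-check

def rotate_half_ref(x):
    # Split the last dimension into two halves and swap them, negating the
    # second half: the same result as rearrange(..., "(j d) -> j d", j=2)
    # followed by unbind and cat.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def _rotate_half_einops(x):
    # The version removed by this commit.
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)

t = torch.randn(1, 5, 2, 8)
assert torch.equal(rotate_half_ref(t), _rotate_half_einops(t))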
@@ -93,42 +82,33 @@ def qwen_attention_forward(

     kv_seq_len = hidden_states.size()[1]

-    if layer_past:
-        # layer past[0] shape: bs * seq_len * head_num * dim
-        kv_seq_len += layer_past[0].shape[1]
-    if (
-        self.use_dynamic_ntk
-        and kv_seq_len == hidden_states.size()[1]
-        and not self.training
-    ):
-        context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
-        ntk_alpha = 2 ** math.ceil(context_value) - 1
-        ntk_alpha = max(ntk_alpha, 1)
-        self._ntk_cached = ntk_alpha
-    else:
-        ntk_alpha = self._ntk_cached
-    rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
-        hidden_states.device
-    )
-
-    if rotary_pos_emb is not None:
-        if isinstance(rotary_pos_emb, tuple):
-            rotary_pos_emb = rotary_pos_emb
-        else:
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-
-    if rotary_pos_emb is not None:
-        q_pos_emb, k_pos_emb = rotary_pos_emb
-        # Slice the pos emb for current inference
-        cur_len = query.shape[1]
-        q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
-        k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
-        query = apply_rotary_pos_emb(query, q_pos_emb)
-        key = apply_rotary_pos_emb(key, k_pos_emb)
+    if rotary_pos_emb_list is not None:
+        cur_len = query.shape[1]
+        if len(rotary_pos_emb_list) == 1:
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            q_pos_emb, k_pos_emb = rotary_pos_emb
+            # Slice the pos emb for current inference
+            query = apply_rotary_pos_emb(query, q_pos_emb)
+            key = apply_rotary_pos_emb(key, k_pos_emb)
+        else:
+            query_list = []
+            key_list = []
+            for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
+                key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+            query = torch.cat(query_list, dim=0)
+            key = torch.cat(key_list, dim=0)

     bsz, _, n_heads, head_dim = key.size()

     if layer_past is not None:
+        kv_seq_len += layer_past[0].shape[1]
         # past_key, past_value = layer_past[0], layer_past[1]
         # key = torch.cat((past_key, key), dim=1)
         # value = torch.cat((past_value, value), dim=1)
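In the rewritten branch, rotary_pos_emb_list appears to hold one [cos, sin] pair shared by the whole batch in the common case, and one pair per sample when dynamic-NTK scaling differs across the batch (hence the per-sample loop and torch.cat in the else path). The [:, -cur_len:, :, :] slice keeps only the positions of the tokens processed in this call, which during incremental decoding is just the newest one. A toy illustration of that slicing, with tensor shapes assumed from the code above:

import torch

seq_len, head_dim = 16, 64
# Assumed layout: cos/sin of shape (1, seq_len, 1, head_dim), broadcastable
# against query/key of shape (batch, seq_len, n_head, head_dim).
cos = torch.randn(1, seq_len, 1, head_dim)
sin = torch.randn(1, seq_len, 1, head_dim)
rotary_pos_emb_list = [[cos, sin]]

# Prefill: cur_len == seq_len, so the slice is a no-op.
cur_len = seq_len
prefill = [i[:, -cur_len:, :, :] for i in rotary_pos_emb_list[0]]
assert prefill[0].shape == (1, 16, 1, 64)

# Decoding one new token: cur_len == 1, only the newest position's cos/sin survive.
cur_len = 1
decode = [i[:, -cur_len:, :, :] for i in rotary_pos_emb_list[0]]
assert decode[0].shape == (1, 1, 1, 64)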
@@ -137,7 +117,7 @@ def qwen_attention_forward(
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                       self.num_heads,  # Support GQA
+                                                       self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
                                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
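The surrounding condition, cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3), reads as a capacity check on the preallocated cache: assuming a (batch, heads, seq, head_dim) buffer, the stride along the heads dimension equals the allocated sequence capacity times head_dim, so the test only fires once the visible view has filled the whole reserved block and extend_kv_cache must allocate a bigger one. A small demonstration under that assumption:

import torch

bsz, n_head, head_dim = 1, 2, 4
max_len, cur_len = 8, 3
buf = torch.empty(bsz, n_head, max_len, head_dim)
cache_k = buf.narrow(2, 0, cur_len)           # visible: (1, 2, 3, 4); reserved: 8 tokens
print(cache_k.stride()[1])                    # 32 == max_len * head_dim (capacity)
print(cache_k.size(2) * cache_k.size(3))      # 12 == cur_len * head_dim (in use)
# stride()[1] <= size(2) * size(3) becomes true only when cur_len reaches max_len.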
@@ -172,8 +152,10 @@ def qwen_attention_forward(
     present = None

     if self.use_logn_attn and not self.training:
-        if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
-            self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
-        seq_start = key.size(1) - query.size(1)
-        seq_end = key.size(1)
+        if self.use_cache_quantization:
+            seq_start = key[0].size(2) - query.size(1)
+            seq_end = key[0].size(2)
+        else:
+            seq_start = key.size(1) - query.size(1)
+            seq_end = key.size(1)
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
@@ -186,23 +168,33 @@ def qwen_attention_forward(
         and query.is_cuda
     ):
         q, k, v = query, key, value
-        context_layer = self.core_attention_flash(q, k, v)
-        context_layer = rearrange(
-            context_layer, "b s h d -> b s (h d)"
-        ).contiguous()
+        context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
+
+        # b s h d -> b s (h d)
+        context_layer = context_layer.flatten(2, 3).contiguous()
+
     else:
         query = query.permute(0, 2, 1, 3)
-        key = key.permute(0, 2, 1, 3)
-        value = value.permute(0, 2, 1, 3)
+        if not self.use_cache_quantization:
+            key = key.permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 1, 3)
+        if (
+            registered_causal_mask is None
+            and self.use_flash_attn
+            and flash_attn_unpadded_func is not None
+            and not self.is_fp32
+            and not query.is_cuda
+        ):
+            invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
         attn_output, attn_weight = self._attn(
-            query, key, value, attention_mask, head_mask
+            query, key, value, registered_causal_mask, attention_mask, head_mask
         )
         context_layer = self._merge_heads(
             attn_output, self.num_heads, self.head_dim
         )

     attn_output = self.c_proj(context_layer)

     outputs = (attn_output, present)
     if output_attentions:
         if (
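The flash-attention branch also drops its einops dependency: for a (b, s, h, d) tensor, context_layer.flatten(2, 3) produces the same result as rearrange(context_layer, "b s h d -> b s (h d)"). For example:

import torch
from einops import rearrange  # only needed for the cross-check

x = torch.randn(2, 5, 8, 16)  # (b, s, h, d)
assert torch.equal(x.flatten(2, 3), rearrange(x, "b s h d -> b s (h d)"))
assert x.flatten(2, 3).shape == (2, 5, 8 * 16)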
@@ -210,7 +202,7 @@ def qwen_attention_forward(
             and flash_attn_unpadded_func is not None
             and not self.is_fp32
         ):
-            invalidInputError("Cannot output attentions while using flash-attn")
+            invalidInputError(False, "Cannot output attentions while using flash-attn")
         else:
             outputs += (attn_weight,)

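The last hunk is a correctness fix as much as a style one: judging from the call sites in this diff, the first positional argument of invalidInputError is the condition to check, so the old single-argument call handed it a non-empty string, which is truthy in Python and therefore never triggers the error path; passing False makes the failure explicit. For illustration only, with a stand-in guard rather than the BigDL helper:

def _check(condition, err_msg):
    # Hypothetical guard with the same calling convention.
    if not condition:
        raise RuntimeError(err_msg)

_check(False, "Cannot output attentions while using flash-attn")   # raises, as intended
_check("Cannot output attentions while using flash-attn", "...")   # truthy string: silently passes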