vLLM: Apply attention optimizations for selective batching (#9758)
* fuse_rope for prefil * apply kv_cache optimizations * apply fast_decoding_path * Re-enable kv_cache optimizations for prefill * reduce KV_CACHE_ALLOC_BLOCK for selective_batching
This commit is contained in:
parent
ed8ed76d4f
commit
daf536fb2d
2 changed files with 126 additions and 85 deletions
|
|
@ -317,6 +317,8 @@ def llama_attention_selective_batching_forward_4_31(
|
||||||
padding_mask: Optional[torch.LongTensor] = None,
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
# Minimize this value to reduce memory allocation.
|
||||||
|
KV_CACHE_ALLOC_BLOCK_LENGTH = 64
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
# for flash attention
|
# for flash attention
|
||||||
|
|
@ -334,35 +336,48 @@ def llama_attention_selective_batching_forward_4_31(
|
||||||
attention_dtype = original_dtype
|
attention_dtype = original_dtype
|
||||||
|
|
||||||
# TODO: decoding fast path
|
# TODO: decoding fast path
|
||||||
# use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
# enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
|
enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
|
||||||
# is_q4_0 = self.q_proj.qtype == SYM_INT4
|
is_q4_0 = self.q_proj.qtype == SYM_INT4
|
||||||
# no_tp = not self.config.pretraining_tp > 1
|
no_tp = not self.config.pretraining_tp > 1
|
||||||
# decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
|
decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
|
||||||
# enough_kv_room and bsz * q_len == 1)
|
bsz * q_len == 1)
|
||||||
|
|
||||||
|
updated_past_key_values = []
|
||||||
# single batch decoding fast path
|
# single batch decoding fast path
|
||||||
# forward_qkv takes will perform QKV projection, rotary position embedding
|
# forward_qkv takes will perform QKV projection, rotary position embedding
|
||||||
# and save the key/value states to cache, then return query states and the
|
# and save the key/value states to cache, then return query states and the
|
||||||
# extended key/value cache
|
# extended key/value cache
|
||||||
# if decoding_fast_path:
|
if decoding_fast_path:
|
||||||
# hidden_states = hidden_states.view(1, -1)
|
past_k = past_key_value[0][0]
|
||||||
# kv_seq_len = past_key_value[0].shape[-2]
|
past_v = past_key_value[0][1]
|
||||||
# cache_k = past_key_value[0]
|
kv_seq_len = past_k.shape[-2]
|
||||||
# cache_v = past_key_value[1]
|
if not enough_kv_room:
|
||||||
# import linear_q4_0
|
new_cache_k, new_cache_v = extend_kv_cache(1,
|
||||||
# query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
|
self.num_key_value_heads, # Support GQA
|
||||||
# self.q_proj.weight,
|
self.head_dim,
|
||||||
# self.k_proj.weight,
|
kv_seq_len,
|
||||||
# self.v_proj.weight,
|
kv_seq_len +
|
||||||
# position_ids,
|
KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||||
# cache_k, cache_v,
|
dtype=past_k.dtype,
|
||||||
# self.q_proj.weight.qtype,
|
device=device)
|
||||||
# kv_seq_len,
|
new_cache_k[:] = past_k
|
||||||
# self.head_dim)
|
new_cache_v[:] = past_v
|
||||||
# kv_seq_len += 1
|
past_k = new_cache_k
|
||||||
|
past_v = new_cache_v
|
||||||
# else:
|
hidden_states = hidden_states.view(1, -1)
|
||||||
|
import linear_q4_0
|
||||||
|
query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
|
||||||
|
self.q_proj.weight,
|
||||||
|
self.k_proj.weight,
|
||||||
|
self.v_proj.weight,
|
||||||
|
position_ids,
|
||||||
|
past_k, past_v,
|
||||||
|
self.q_proj.weight.qtype,
|
||||||
|
kv_seq_len,
|
||||||
|
self.head_dim)
|
||||||
|
kv_seq_len += 1
|
||||||
|
else:
|
||||||
if self.config.pretraining_tp > 1:
|
if self.config.pretraining_tp > 1:
|
||||||
invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
|
invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
|
||||||
else:
|
else:
|
||||||
|
|
@ -381,24 +396,44 @@ def llama_attention_selective_batching_forward_4_31(
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
|
kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
|
||||||
|
|
||||||
# TODO: fuse_rope
|
if use_fuse_rope:
|
||||||
|
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||||
|
key_states,
|
||||||
|
position_ids,
|
||||||
|
"llama")
|
||||||
|
else:
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
cos, sin, position_ids, "llama")
|
cos, sin, position_ids, "llama")
|
||||||
|
|
||||||
updated_past_key_values = []
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
batched_attention_output = []
|
batched_attention_output = []
|
||||||
# print(f"type of attention_mask is {type(attention_mask)}")
|
# print(f"type of attention_mask is {type(attention_mask)}")
|
||||||
for batch in range(bsz):
|
for batch in range(bsz):
|
||||||
|
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value[batch])
|
||||||
past_k, past_v = past_key_value[batch]
|
past_k, past_v = past_key_value[batch]
|
||||||
current_kv_len = past_k.shape[-2] + 1
|
current_kv_len = past_k.shape[-2] + 1
|
||||||
|
if not enough_kv_room:
|
||||||
|
# allocate new
|
||||||
|
new_cache_k, new_cache_v = extend_kv_cache(1,
|
||||||
|
self.num_key_value_heads,
|
||||||
|
self.head_dim,
|
||||||
|
past_k.size(2),
|
||||||
|
current_kv_len +
|
||||||
|
KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||||
|
dtype=past_k.dtype,
|
||||||
|
device=device)
|
||||||
|
new_cache_k[:] = past_k
|
||||||
|
new_cache_v[:] = past_v
|
||||||
|
past_k = new_cache_k
|
||||||
|
past_v = new_cache_v
|
||||||
|
|
||||||
current_key_states = torch.cat([past_k,
|
current_key_states = key_states[batch: batch + 1, :, :, :]
|
||||||
key_states[batch: batch + 1, :, :, :]], dim=2)
|
current_value_states = value_states[batch: batch + 1, :, :, :]
|
||||||
current_value_states = torch.cat([past_v,
|
current_key_states, current_value_states = append_kv_cache(past_k,
|
||||||
value_states[batch: batch + 1, :, :, :]], dim=2)
|
past_v,
|
||||||
|
current_key_states,
|
||||||
|
current_value_states)
|
||||||
updated_past_key_values.append((current_key_states, current_value_states))
|
updated_past_key_values.append((current_key_states, current_value_states))
|
||||||
|
|
||||||
current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
|
current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
|
||||||
|
|
@ -434,8 +469,8 @@ def llama_attention_selective_batching_forward_4_31(
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
return attn_output, None, updated_past_key_values
|
return attn_output, None, updated_past_key_values
|
||||||
|
|
||||||
# TODO: Assume always use_cache
|
# Assume always use_cache
|
||||||
# print(f"prefill with batch size {bsz}")
|
# prefill or decoding fast path
|
||||||
for batch in range(bsz):
|
for batch in range(bsz):
|
||||||
updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
|
updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
|
||||||
value_states[batch: batch+1, :, :, :]))
|
value_states[batch: batch+1, :, :, :]))
|
||||||
|
|
@ -445,6 +480,10 @@ def llama_attention_selective_batching_forward_4_31(
|
||||||
dtype=attention_dtype)
|
dtype=attention_dtype)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||||
dtype=attention_dtype)
|
dtype=attention_dtype)
|
||||||
|
# Can also happens for decoding fast path
|
||||||
|
if isinstance(attention_mask, list):
|
||||||
|
# For decoding fast path
|
||||||
|
attention_mask = attention_mask[0]
|
||||||
attn_output, attn_weights = native_sdp(query_states,
|
attn_output, attn_weights = native_sdp(query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
|
|
|
||||||
|
|
@ -196,7 +196,8 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
if enable_vllm_se_batching:
|
if enable_vllm_se_batching:
|
||||||
attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
|
attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
|
||||||
for x in decoding_attention_mask_list]
|
for x in decoding_attention_mask_list]
|
||||||
position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
|
position_ids = torch.tensor(decoding_position_ids, device=self.device).long()
|
||||||
|
position_ids = position_ids.unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
|
attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
|
||||||
position_ids = None
|
position_ids = None
|
||||||
|
|
@ -214,6 +215,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
if enable_vllm_se_batching:
|
if enable_vllm_se_batching:
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
position_ids.to(self.device)
|
||||||
else:
|
else:
|
||||||
position_ids = None
|
position_ids = None
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue