[LLM] Optimize ChatGLM2 kv_cache to support beam_search on ARC (#9579)
* optimize kv_cache to support beam_search on Arc
* correctness test update
* fix query_length issue
* simplify implementation
* only enable the optimization on gpu device
* limit the beam_search support to gpu devices with batch_size > 1
* add comments for beam_search case and revert ut change
* address review comments
* add more comments to describe the difference between multi-cases
This commit is contained in:
parent c64e2248ef
commit 284e7697b1

1 changed file with 81 additions and 24 deletions
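For context, ChatGLM2 uses multi-query attention: it has many query heads (nh) but only a few KV groups (np). Before this change the KV cache was stored fully expanded to nh heads; with beam_search every beam carries its own cache entry, so the expansion multiplies memory usage. The snippet below is a minimal, standalone sketch of the saving, assuming ChatGLM2-6B-like sizes (nh = 32, np = 2, hn = 128, fp16); these numbers are illustrative assumptions, not values taken from the diff.

    import torch

    num_beams = 4                 # beam_search keeps one KV cache entry per beam
    nh, np_, hn = 32, 2, 128      # query heads, KV groups, head dim (assumed, ChatGLM2-6B-like)
    sl = 1024                     # cached sequence length

    # What this commit stores on xpu with batch_size > 1: grouped layout [bs, np, sl, hn].
    grouped = torch.empty(num_beams, np_, sl, hn, dtype=torch.float16)
    # What was stored before: fully expanded layout [bs, nh, sl, hn].
    expanded = torch.empty(num_beams, nh, sl, hn, dtype=torch.float16)

    mb = lambda t: t.numel() * t.element_size() / 1e6
    print(f"grouped cache:  {mb(grouped):.1f} MB per layer")    # ~2.1 MB
    print(f"expanded cache: {mb(expanded):.1f} MB per layer")   # ~33.6 MB
    print(f"ratio: {expanded.numel() // grouped.numel()}x")     # nh / np = 16x here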
@@ -239,21 +239,27 @@ def chatglm2_attention_forward_8eb45c(
         key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)
 
     if self.multi_query_attention:
-        key_length = key_layer.size(0)
-        query_group_size = self.num_attention_heads_per_partition // \
-            self.num_multi_query_groups_per_partition
-        key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3)  # [bs, nh/k, sl, hn]
-        key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
-        key_layer = key_layer.contiguous().view((batch_size,
-                                                 self.num_attention_heads_per_partition,
-                                                 key_length,
-                                                 self.hidden_size_per_attention_head))
-        value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3)
-        value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
-        value_layer = value_layer.contiguous().view((batch_size,
-                                                     self.num_attention_heads_per_partition,
-                                                     key_length,
-                                                     self.hidden_size_per_attention_head))
+        if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
+            # If batch_size > 1 on gpu, permute key/value_layer to [bs, np, sl, hn]
+            # to reduce memory usage. Otherwise, expand key/value_layer to [bs, nh, sl, hn].
+            key_layer = key_layer.permute(1, 2, 0, 3)  # [bs, np, sl, hn]
+            value_layer = value_layer.permute(1, 2, 0, 3)  # [bs, np, sl, hn]
+        else:
+            key_length = key_layer.size(0)
+            query_group_size = self.num_attention_heads_per_partition // \
+                self.num_multi_query_groups_per_partition
+            key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3)  # [bs, nh/k, sl, hn]
+            key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
+            key_layer = key_layer.contiguous().view((batch_size,
+                                                     self.num_attention_heads_per_partition,
+                                                     key_length,
+                                                     self.hidden_size_per_attention_head))
+            value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3)  # [bs, nh/k, sl, hn]
+            value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
+            value_layer = value_layer.contiguous().view((batch_size,
+                                                         self.num_attention_heads_per_partition,
+                                                         key_length,
+                                                         self.hidden_size_per_attention_head))
 
     # adjust key and value for inference
     if kv_cache is not None:
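To see what the two branches in this hunk produce, here is a small standalone sketch with arbitrary illustrative sizes (it is not part of the patch). ChatGLM2 passes key/value in a [sl, bs, np, hn] layout; the beam-search path merely permutes it, while the original path expands every KV group to query_group_size heads before the view to [bs, nh, sl, hn].

    import torch

    sl, bs, np_, hn = 5, 2, 2, 4
    query_group_size = 8                      # nh / np, arbitrary for the example
    key_layer = torch.randn(sl, bs, np_, hn)  # ChatGLM2's [sl, bs, np, hn] layout

    # Beam-search/GPU path: just permute, keeping the grouped layout.
    k_grouped = key_layer.permute(1, 2, 0, 3)                 # [bs, np, sl, hn]

    # Original path: expand each KV group to query_group_size heads, then flatten heads.
    k_expanded = key_layer.permute(1, 2, 0, 3).unsqueeze(-3)  # [bs, np, 1, sl, hn]
    k_expanded = k_expanded.expand(-1, -1, query_group_size, -1, -1)
    k_expanded = k_expanded.contiguous().view(bs, np_ * query_group_size, sl, hn)

    print(k_grouped.shape)    # torch.Size([2, 2, 5, 4])
    print(k_expanded.shape)   # torch.Size([2, 16, 5, 4])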
@@ -264,13 +270,26 @@ def chatglm2_attention_forward_8eb45c(
 
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
-                                                       self.num_attention_heads_per_partition,
-                                                       self.hidden_size_per_attention_head,
-                                                       past_length,
-                                                       max_cache_length,
-                                                       dtype=query_layer.dtype,
-                                                       device=device)
+            if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
+                # If batch_size > 1 on gpu, use init_kv_cache to avoid empty cache for ensuring
+                # generation correctness.
+                # Set the num_heads in init_kv_cache to np, ensuring that the tensors of
+                # new_cache_k/v and key/value_layer have the same size.
+                new_cache_k, new_cache_v = init_kv_cache(batch_size,
+                                                         self.num_multi_query_groups_per_partition,
+                                                         self.hidden_size_per_attention_head,
+                                                         past_length,
+                                                         max_cache_length,
+                                                         dtype=query_layer.dtype,
+                                                         device=device)
+            else:
+                new_cache_k, new_cache_v = extend_kv_cache(batch_size,
+                                                           self.num_attention_heads_per_partition,
+                                                           self.hidden_size_per_attention_head,
+                                                           past_length,
+                                                           max_cache_length,
+                                                           dtype=query_layer.dtype,
+                                                           device=device)
             new_cache_k[:] = cache_k
             new_cache_v[:] = cache_v
             cache_k = new_cache_k
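init_kv_cache and extend_kv_cache are bigdl-llm helpers whose implementations are not part of this diff; only their call sites are visible here. The sketch below is a hedged guess at the general idea behind such a pre-allocating helper (buffer sized for max_length with extra block room, views over the currently used prefix), written from the arguments in the calls above; it is not the library's actual code. The point of the hunk's comments stands regardless: on the xpu beam-search path the cache must be allocated with np heads so that new_cache_k[:] = cache_k is shape-compatible with the grouped [bs, np, sl, hn] tensors.

    import torch

    def init_kv_cache_sketch(batch_size, num_heads, head_dim,
                             current_length, max_length, dtype, device):
        # Hypothetical stand-in, not bigdl.llm's implementation: allocate one buffer
        # with room for max_length tokens, return views over the first current_length
        # tokens so later tokens can be appended without reallocating.
        key_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                  dtype=dtype, device=device)
        value_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                    dtype=dtype, device=device)
        key_cache = key_storage[:, :, :current_length, :]
        value_cache = value_storage[:, :, :current_length, :]
        return key_cache, value_cache

    # Usage mirroring the grouped (np heads) allocation in the xpu + batch_size > 1 branch.
    k, v = init_kv_cache_sketch(batch_size=4, num_heads=2, head_dim=128,
                                current_length=16, max_length=272,
                                dtype=torch.float16, device="cpu")
    print(k.shape, k.stride()[1] > k.size(2) * k.size(3))  # spare room still available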
@@ -279,18 +298,33 @@ def chatglm2_attention_forward_8eb45c(
         key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
 
     elif use_cache:
         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
-                                               self.hidden_size_per_attention_head, cur_length,
-                                               max_cache_length,
-                                               dtype=query_layer.dtype, device=device)
+        if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
+            # Ensure the tensors of key/value_cache and key/value_layer have the same size.
+            nums_per_partition = self.num_multi_query_groups_per_partition
+        else:
+            nums_per_partition = self.num_attention_heads_per_partition
+
+        key_cache, value_cache = init_kv_cache(batch_size,
+                                               nums_per_partition,
+                                               self.hidden_size_per_attention_head,
+                                               cur_length,
+                                               max_cache_length,
+                                               dtype=query_layer.dtype,
+                                               device=device)
         key_cache[:] = key_layer
         value_cache[:] = value_layer
         key_layer = key_cache
         value_layer = value_cache
 
+    # If batch_size > 1, return tensors with shape [bs, np, sl, hn] as past_key_values. This could
+    # reduce memory usage as tensors are not expanded to [bs, nh, sl, hn].
+    # Otherwise, return views of [bs, nh, sl, hn].
+    cache_key_layer = key_layer
+    cache_value_layer = value_layer
+
     if use_cache:
         kv_cache = (key_layer, value_layer)
     else:
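cache_key_layer and cache_value_layer are plain references saved before key_layer/value_layer get rebound to expanded copies in the core-attention hunk below; rebinding a name does not affect a previously saved reference, so the compact tensors survive for the return statement. A tiny standalone sketch with arbitrary sizes:

    import torch

    key_layer = torch.randn(2, 2, 5, 4)   # [bs, np, sl, hn], grouped layout
    cache_key_layer = key_layer            # reference kept for past_key_values

    # Later, key_layer is rebound to an expanded copy for core attention
    # (see the next hunk); the saved reference still points at the grouped tensor.
    key_layer = key_layer.unsqueeze(-3).expand(-1, -1, 8, -1, -1) \
                         .contiguous().view(2, 16, 5, 4)

    print(key_layer.shape)        # torch.Size([2, 16, 5, 4]) -- used for attention only
    print(cache_key_layer.shape)  # torch.Size([2, 2, 5, 4])  -- returned as past_key_values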
@@ -299,6 +333,29 @@ def chatglm2_attention_forward_8eb45c(
     # ==================================
     # core attention computation
     # ==================================
+    if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
+        # If batch_size > 1, expand key/value_layer to [bs, nh, sl, hn] for
+        # core attention computation.
+        # The expanded tensors will not be returned as past_key_values.
+        if self.multi_query_attention:
+            query_group_size = self.num_attention_heads_per_partition // \
+                self.num_multi_query_groups_per_partition
+            key_layer = key_layer.unsqueeze(-3)
+            key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
+            save_length = key_layer.size(3)
+            # [bs, np, sl, hn] --> [bs, nh, sl, hn]
+            key_layer = key_layer.contiguous().view((batch_size,
+                                                     self.num_attention_heads_per_partition,
+                                                     save_length,
+                                                     self.hidden_size_per_attention_head))
+            value_layer = value_layer.unsqueeze(-3)
+            value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
+            # [bs, np, sl, hn] --> [bs, nh, sl, hn]
+            value_layer = value_layer.contiguous().view((batch_size,
+                                                         self.num_attention_heads_per_partition,
+                                                         save_length,
+                                                         self.hidden_size_per_attention_head))
+
     context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
 
     # =================
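A note on the expansion in this hunk: expand() only creates a stride-0 view over the existing storage; the full [bs, nh, sl, hn] buffer is materialized by .contiguous(), is used for core attention, and is never written into the KV cache. A minimal standalone check with arbitrary sizes:

    import torch

    k = torch.randn(2, 2, 5, 4)                       # [bs, np, sl, hn]
    view = k.unsqueeze(-3).expand(-1, -1, 8, -1, -1)  # [bs, np, 8, sl, hn], no new memory
    full = view.contiguous()                          # the real copy happens here

    print(view.data_ptr() == k.data_ptr())   # True  -> expand is just a view
    print(full.data_ptr() == k.data_ptr())   # False -> contiguous() allocated the buffer
    print(view.stride())                     # note the 0 stride on the expanded group dim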
@@ -307,7 +364,7 @@ def chatglm2_attention_forward_8eb45c(
 
     output = self.dense(context_layer)
 
-    return output, (key_layer.permute(2, 0, 1, 3), value_layer.permute(2, 0, 1, 3))
+    return output, (cache_key_layer.permute(2, 0, 1, 3), cache_value_layer.permute(2, 0, 1, 3))
 
 
 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
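The return statement's permute(2, 0, 1, 3) maps the cached [bs, np, sl, hn] tensors back to the sequence-first [sl, bs, np, hn] layout used for past_key_values. A one-line sketch with arbitrary sizes:

    import torch

    cache_key_layer = torch.randn(2, 2, 5, 4)       # [bs, np, sl, hn]
    past_k = cache_key_layer.permute(2, 0, 1, 3)    # [sl, bs, np, hn]
    print(past_k.shape)                             # torch.Size([5, 2, 2, 4])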