LLM: Support GQA on llama kvcache (#8938)

* support GQA
Zhao Changmin 2023-09-12 12:18:40 +08:00 committed by GitHub
parent 2d81521019
commit dcaa4dc130

@@ -131,7 +131,7 @@ def llama_attention_forward_4_31(
     if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
         # allocate new
         new_cache_k, new_cache_v = create_kv_cache(bsz,
-                                                    self.num_heads,
+                                                    self.num_key_value_heads,  # Support GQA
                                                     self.head_dim,
                                                     cache_k.size(2),
                                                     kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
@@ -147,7 +147,7 @@ def llama_attention_forward_4_31(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = create_kv_cache(bsz,
-                                                            self.num_heads,
+                                                            self.num_key_value_heads,
                                                             self.head_dim,
                                                             kv_seq_len,
                                                             max_cache_length,
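
Background on the change: with grouped-query attention (GQA), the key/value projections produce only num_key_value_heads heads, which is smaller than the num_heads query heads, so allocating the KV cache with num_heads over-allocates memory and misshapes the cache. The cache should hold num_key_value_heads heads, and those heads are then repeated across each query-head group at attention time. The sketch below illustrates this idea only; allocate_kv_cache, repeat_kv, and the block-length constant are illustrative assumptions, not the repository's actual create_kv_cache implementation.

# Minimal sketch of GQA-aware KV cache allocation (assumed helpers, not the real API)
import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed pre-allocation block size

def allocate_kv_cache(bsz, num_kv_heads, head_dim, current_len, max_len,
                      dtype=torch.float32, device="cpu"):
    # Pre-allocate storage for max_len tokens, expose a view of current_len
    # so new tokens can be appended in place until the block is full.
    cache_k = torch.empty(bsz, num_kv_heads, max_len, head_dim,
                          dtype=dtype, device=device)
    cache_v = torch.empty(bsz, num_kv_heads, max_len, head_dim,
                          dtype=dtype, device=device)
    return cache_k[:, :, :current_len, :], cache_v[:, :, :current_len, :]

def repeat_kv(hidden_states, n_rep):
    # Expand (bsz, num_kv_heads, seq_len, head_dim) to
    # (bsz, num_kv_heads * n_rep, seq_len, head_dim) so every query head
    # in a group attends to its shared K/V head.
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

# Example: Llama-2-70B-style GQA with 64 query heads sharing 8 KV heads.
bsz, num_heads, num_kv_heads, head_dim, kv_seq_len = 1, 64, 8, 128, 32
k_cache, v_cache = allocate_kv_cache(
    bsz, num_kv_heads, head_dim, kv_seq_len,
    kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH)
k_for_attention = repeat_kv(k_cache, num_heads // num_kv_heads)
print(k_cache.shape, k_for_attention.shape)
# torch.Size([1, 8, 32, 128]) torch.Size([1, 64, 32, 128])

Passing num_key_value_heads to the allocator, as the patch does, keeps the cached tensors 8x smaller in this example; only the transient repeat_kv expansion (or an equivalent broadcast inside the attention kernel) sees the full num_heads dimension.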