LLM: Support GQA on llama kvcache (#8938)

* support GQA
Zhao Changmin, 2023-09-12 12:18:40 +08:00 (committed by GitHub)
parent 2d81521019
commit dcaa4dc130

@@ -131,7 +131,7 @@ def llama_attention_forward_4_31(
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
-                                                       self.num_heads,
+                                                       self.num_key_value_heads,  # Support GQA
                                                        self.head_dim,
                                                        cache_k.size(2),
                                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
@@ -147,7 +147,7 @@ def llama_attention_forward_4_31(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
+                                                           self.num_key_value_heads,
                                                            self.head_dim,
                                                            kv_seq_len,
                                                            max_cache_length,
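
Both hunks make the same fix: the KV cache is allocated with self.num_key_value_heads instead of self.num_heads, so that models using grouped-query attention (GQA), where the key/value projections have fewer heads than the query projection (e.g. 8 KV heads vs. 64 query heads in Llama-2-70B), get a cache of the correct shape. The sketch below is not the repo's create_kv_cache; allocate_kv_cache, the head counts, and the block length used in the usage example are illustrative assumptions, while repeat_kv follows the standard GQA expansion pattern from Hugging Face's modeling_llama.

import torch

def allocate_kv_cache(bsz, num_kv_heads, head_dim, current_len, max_len,
                      dtype=torch.float16, device="cpu"):
    # Hypothetical stand-in for the patched allocation: under GQA the K/V
    # tensors only carry num_kv_heads heads, so the cache is allocated with
    # that (smaller) head count, not the query head count.
    key_cache = torch.empty(bsz, num_kv_heads, max_len, head_dim,
                            dtype=dtype, device=device)
    value_cache = torch.empty(bsz, num_kv_heads, max_len, head_dim,
                              dtype=dtype, device=device)
    # Return views over the currently filled portion; the extra rows up to
    # max_len are headroom for future tokens.
    return key_cache[:, :, :current_len, :], value_cache[:, :, :current_len, :]

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Standard GQA expansion: each KV head is shared by
    # n_rep = num_heads // num_kv_heads query heads at attention time.
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

# Usage (illustrative numbers): allocate the cache with the KV head count,
# then expand only when computing attention scores.
bsz, num_heads, num_kv_heads, head_dim = 1, 32, 8, 128
k_cache, v_cache = allocate_kv_cache(bsz, num_kv_heads, head_dim,
                                     current_len=16, max_len=16 + 256)
k_for_attn = repeat_kv(k_cache, num_heads // num_kv_heads)  # (1, 32, 16, 128)

The point of the patch is that only the allocation changes: the cache stays at num_key_value_heads, and the existing attention path (which already expands K/V to the query head count) keeps working, while allocating with num_heads would have produced a cache with the wrong head dimension for GQA checkpoints.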