From dcaa4dc130ec978cd7ab597c2f78b54dadf45320 Mon Sep 17 00:00:00 2001
From: Zhao Changmin
Date: Tue, 12 Sep 2023 12:18:40 +0800
Subject: [PATCH] LLM: Support GQA on llama kvcache (#8938)

* support GQA
---
 python/llm/src/bigdl/llm/transformers/models/llama.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index de9a2c4c..b9b94af3 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -131,7 +131,7 @@ def llama_attention_forward_4_31(
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
-                                                       self.num_heads,
+                                                       self.num_key_value_heads,  # Support GQA
                                                        self.head_dim,
                                                        cache_k.size(2),
                                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
@@ -147,7 +147,7 @@ def llama_attention_forward_4_31(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
+                                                           self.num_key_value_heads,
                                                            self.head_dim,
                                                            kv_seq_len,
                                                            max_cache_length,
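
Note (not part of the patch): a minimal, hypothetical sketch of why the KV cache is allocated with num_key_value_heads under grouped-query attention (GQA). Keys and values have fewer heads than queries, so sizing the cache by num_heads would over-allocate; the KV tensors are only expanded to the query head count at attention time with a repeat_kv-style helper. The shapes and helper below are illustrative assumptions, not the bigdl-llm create_kv_cache implementation.

    import torch

    bsz, seq_len, head_dim = 1, 8, 64
    num_heads = 32             # query heads
    num_key_value_heads = 4    # KV heads under GQA; equals num_heads for plain MHA
    n_rep = num_heads // num_key_value_heads

    # The cache is shaped by the number of KV heads, so it is n_rep times
    # smaller than a cache sized with num_heads.
    cache_k = torch.zeros(bsz, num_key_value_heads, seq_len, head_dim)
    cache_v = torch.zeros(bsz, num_key_value_heads, seq_len, head_dim)

    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand (bsz, num_kv_heads, seq_len, head_dim) to
        # (bsz, num_kv_heads * n_rep, seq_len, head_dim) so each group of
        # query heads attends over the same shared K/V head.
        b, kv_heads, s, d = x.shape
        if n_rep == 1:
            return x
        x = x[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d)
        return x.reshape(b, kv_heads * n_rep, s, d)

    key_states = repeat_kv(cache_k, n_rep)      # (1, 32, 8, 64), matches query heads
    value_states = repeat_kv(cache_v, n_rep)
    print(key_states.shape, value_states.shape)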