parent 2d81521019
commit dcaa4dc130
1 changed file with 2 additions and 2 deletions
@@ -131,7 +131,7 @@ def llama_attention_forward_4_31(
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
-                                                       self.num_heads,
+                                                       self.num_key_value_heads,  # Support GQA
                                                        self.head_dim,
                                                        cache_k.size(2),
                                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
@@ -147,7 +147,7 @@ def llama_attention_forward_4_31(
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
+                                                           self.num_key_value_heads,
                                                            self.head_dim,
                                                            kv_seq_len,
                                                            max_cache_length,
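
For context on the fix: under grouped-query attention (GQA), a Llama checkpoint's k_proj/v_proj produce num_key_value_heads heads rather than num_heads, so the pre-allocated KV cache must be sized with num_key_value_heads or its shape will not match the key/value states written into it, while also over-allocating memory. Below is a minimal sketch of the shape math, assuming the (bsz, n_kv_heads, seq_len, head_dim) layout used by Hugging Face's Llama implementation; allocate_kv_cache here is a hypothetical stand-in, not the repository's create_kv_cache:

    import torch

    def allocate_kv_cache(bsz, num_kv_heads, head_dim, current_len, max_len,
                          dtype=torch.float32, device="cpu"):
        # Hypothetical helper, illustrative only. Reserve max_len slots up
        # front so later decode steps can append in place; return views
        # trimmed to the tokens cached so far.
        cache_k = torch.empty(bsz, num_kv_heads, max_len, head_dim,
                              dtype=dtype, device=device)
        cache_v = torch.empty(bsz, num_kv_heads, max_len, head_dim,
                              dtype=dtype, device=device)
        return cache_k[:, :, :current_len, :], cache_v[:, :, :current_len, :]

    # Llama-2-70B-style GQA: 64 query heads share 8 key/value heads, so
    # sizing the cache with num_heads (64) instead of num_key_value_heads (8)
    # would allocate 8x too much memory and mismatch k_proj/v_proj output.
    k, v = allocate_kv_cache(bsz=1, num_kv_heads=8, head_dim=128,
                             current_len=16, max_len=16 + 256)
    assert k.shape == (1, 8, 16, 128) and v.shape == (1, 8, 16, 128)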