Make llama attention stateless (#8928)
* Make llama attention stateless
* fix style
* fix chatglm
* fix chatglm xpu
parent e62eda74b8
commit 16761c58be

2 changed files with 60 additions and 54 deletions
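The diff below replaces the per-module self.kv_cache / self.max_cache_length bookkeeping with two helpers imported from bigdl.llm.transformers.models.utils: create_kv_cache and append_kv_cache. Their implementations are not part of this commit; the following is only a rough sketch of their apparent semantics, inferred from the call sites in the diff, and the argument names are assumptions rather than the library's actual code.

import torch

# Hypothetical sketch only; the real helpers live in
# bigdl.llm.transformers.models.utils and may differ in detail.

def create_kv_cache(batch_size, num_heads, head_dim, current_length,
                    max_length, dtype, device):
    # Allocate room for max_length positions, but return views that cover only
    # the first current_length positions; the slack lets later tokens be
    # appended in place instead of reallocating (or torch.cat-ing) every step.
    size = (batch_size, num_heads, max_length, head_dim)
    k_storage = torch.empty(size, dtype=dtype, device=device)
    v_storage = torch.empty(size, dtype=dtype, device=device)
    view = (batch_size, num_heads, current_length, head_dim)
    return (k_storage.as_strided(view, k_storage.stride()),
            v_storage.as_strided(view, v_storage.stride()))

def append_kv_cache(cache_k, cache_v, key_states, value_states):
    # Widen the views over the same storage and copy the new tokens into the
    # spare room behind the current contents.
    old_len = cache_k.size(2)
    new_size = (cache_k.size(0), cache_k.size(1),
                old_len + key_states.size(2), cache_k.size(3))
    new_k = cache_k.as_strided(new_size, cache_k.stride())
    new_v = cache_v.as_strided(new_size, cache_v.stride())
    new_k[:, :, old_len:, :] = key_states
    new_v[:, :, old_len:, :] = value_states
    return new_k, new_v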
Changed file 1 of 2: ChatGLM attention (attention_fn)

@@ -22,6 +22,7 @@ import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 
 
 def rotate_half(x):
@@ -58,43 +59,43 @@ def attention_fn(
     # query_layer = query_layer.permute(1, 2, 0, 3)
     cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
+    device = query_layer.device
 
     if layer_past is not None:
-        past_key, past_value = layer_past[0], layer_past[1]
-        past_length = past_key.size(2)
-        if past_length + cur_length > self.max_cache_length:
-            self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            self.kv_cache = (torch.empty(batch_size,
-                                         self.num_attention_heads,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,),
-                             torch.empty(batch_size,
-                                         self.num_attention_heads,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,))
-            self.kv_cache[0][:, :, :past_length, :] = past_key
-            self.kv_cache[1][:, :, :past_length, :] = past_value
-
-        self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
-        self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
-        key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
-        value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
-
+        cache_k, cache_v = layer_past[0], layer_past[1]
+        cache_k = cache_k.permute(1, 2, 0, 3)
+        cache_v = cache_v.permute(1, 2, 0, 3)
+        past_length = cache_k.size(2)
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+                                                       self.num_attention_heads_per_partition,
+                                                       self.hidden_size_per_attention_head,
+                                                       past_length,
+                                                       max_cache_length,
+                                                       dtype=query_layer.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
     elif use_cache:
-        self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+        max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        self.kv_cache = (torch.empty(batch_size, self.num_attention_heads,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,),
-                         torch.empty(batch_size, self.num_attention_heads,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,))
-        self.kv_cache[0][:, :, :cur_length, :] = key_layer
-        self.kv_cache[1][:, :, :cur_length, :] = value_layer
+        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                                 self.hidden_size_per_attention_head, cur_length,
+                                                 max_cache_length,
+                                                 dtype=query_layer.dtype, device=device)
+        key_cache[:] = key_layer
+        value_cache[:] = value_layer
+        key_layer = key_cache
+        value_layer = value_cache
 
     # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
     b, nh, seq_len, hidden_size = key_layer.shape
 
     if use_cache:
-        present = (key_layer, value_layer)
+        present = (key_layer.permute(2, 0, 1, 3), value_layer.permute(2, 0, 1, 3))
     else:
         present = None
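In both branches above, the new code decides whether the incoming cache still has spare storage by comparing strides instead of tracking a self.max_cache_length attribute. A minimal sketch of that test, assuming the cache came from a create_kv_cache-style helper as sketched earlier (cache_is_full is an illustrative name, not a function in the repository):

import torch

def cache_is_full(cache_k: torch.Tensor) -> bool:
    # A view narrowed out of a larger buffer keeps the buffer's stride on the
    # heads dimension (max_length * head_dim), which is strictly greater than
    # size(2) * size(3). Equality (or an ordinary contiguous tensor, e.g. one
    # built with torch.cat) means there is no room left to append in place.
    return cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)

# Example: 16 cached positions inside storage for 272 -> not full yet.
k, v = create_kv_cache(1, 32, 128, 16, 272, dtype=torch.float32, device="cpu")
assert not cache_is_full(k)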
@@ -168,6 +169,7 @@ def attention_fn(
     matmul_result = torch.empty(
         output_size[0] * output_size[1],
         output_size[2], output_size[3], dtype=query_layer.dtype,
+        device=query_layer.device
     )
 
     torch.baddbmm(
@@ -217,7 +219,8 @@ def attention_fn(
     # matmul: [b * np, sq, hn]
     context_layer = torch.empty(
         output_size[0] * output_size[1],
-        output_size[2], value_layer.size(-1), dtype=value_layer.dtype,)
+        output_size[2], value_layer.size(-1), dtype=value_layer.dtype,
+        device=query_layer.device)
     torch.bmm(attention_probs, value_layer, out=context_layer)
 
     # change view [b, np, sq, hn]
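The two small hunks above only add a device= argument when preallocating the out= buffers for torch.baddbmm and torch.bmm; torch.empty otherwise defaults to the CPU, which breaks once the operands live on an accelerator such as XPU (plausibly the "fix chatglm xpu" item in the commit message, though the diff itself does not say so). A minimal illustration with assumed shapes:

import torch

q = torch.randn(4, 16, 64)   # on an XPU/GPU run these would be device tensors
k = torch.randn(4, 64, 16)

# Allocate the output where the operands are, then write into it in place.
scores = torch.empty(4, 16, 16, dtype=q.dtype, device=q.device)
torch.bmm(q, k, out=scores)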
Changed file 2 of 2: LLaMA attention (llama_attention_forward_4_31)

@@ -37,6 +37,7 @@ from typing import Optional, Tuple
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 
 
 def rotate_half(x):
@@ -125,35 +126,37 @@ def llama_attention_forward_4_31(
 
     if past_key_value is not None:
         # reuse k, v, self_attention
-        # key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        # value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        if kv_seq_len > self.max_cache_length:
-            new_cache_key = torch.empty(bsz, self.num_heads,
-                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim,
-                                        device=device)
-            new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
-
-            new_cache_value = torch.empty(bsz, self.num_heads,
-                                          kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim,
-                                          device=device)
-            new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
-            self.kv_cache = (new_cache_key, new_cache_value)
-            self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-
-        self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states
-        self.kv_cache[1][:, :, kv_seq_len-1:kv_seq_len, :] = value_states
-        key_states = self.kv_cache[0][:, :, :kv_seq_len, :]
-        value_states = self.kv_cache[1][:, :, :kv_seq_len, :]
+        cache_k = past_key_value[0]
+        cache_v = past_key_value[1]
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            # allocate new
+            new_cache_k, new_cache_v = create_kv_cache(bsz,
+                                                       self.num_heads,
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+
+        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
     elif use_cache:
-        # first token case
-        self.max_cache_length = max(min(self.max_position_embeddings, 2 * kv_seq_len),
-                                    kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH)
-        self.kv_cache = (torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim,
-                                     dtype=key_states.dtype, device=device),
-                         torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim,
-                                     dtype=key_states.dtype, device=device))
-        self.kv_cache[0][:, :, :kv_seq_len, :] = key_states
-        self.kv_cache[1][:, :, :kv_seq_len, :] = value_states
+        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+        new_key_states, new_value_states = create_kv_cache(bsz,
+                                                           self.num_heads,
+                                                           self.head_dim,
+                                                           kv_seq_len,
+                                                           max_cache_length,
+                                                           dtype=key_states.dtype,
+                                                           device=device)
+        new_key_states[:] = key_states
+        new_value_states[:] = value_states
+        key_states = new_key_states
+        value_states = new_value_states
 
     past_key_value = (key_states, value_states) if use_cache else None
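Taken together, this hunk is what makes the attention path "stateless": the preallocated cache now travels inside past_key_value instead of being stored on the module as self.kv_cache and self.max_cache_length. A rough sketch of that flow, reusing the helper sketch from the top of this page (the function name, block length, and helper signatures are illustrative assumptions, not the repository's code):

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed value, for illustration only

def decode_step(key_states, value_states, past_key_value):
    # key_states / value_states: [bsz, num_heads, new_tokens, head_dim]
    bsz, num_heads, _, head_dim = key_states.shape
    if past_key_value is None:
        # first call: over-allocate so later steps can append in place
        kv_seq_len = key_states.size(2)
        k_cache, v_cache = create_kv_cache(bsz, num_heads, head_dim, kv_seq_len,
                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                           dtype=key_states.dtype,
                                           device=key_states.device)
        k_cache[:] = key_states
        v_cache[:] = value_states
        key_states, value_states = k_cache, v_cache
    else:
        cache_k, cache_v = past_key_value
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # the cache is full: move it into a bigger buffer before appending
            new_len = cache_k.size(2) + key_states.size(2)
            new_k, new_v = create_kv_cache(bsz, num_heads, head_dim, cache_k.size(2),
                                           new_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                           dtype=cache_k.dtype, device=cache_k.device)
            new_k[:] = cache_k
            new_v[:] = cache_v
            cache_k, cache_v = new_k, new_v
        key_states, value_states = append_kv_cache(cache_k, cache_v,
                                                   key_states, value_states)
    # no module attribute is mutated; the caller threads the tuple forward
    return key_states, value_states, (key_states, value_states)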