diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm.py b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
index 4503e0d2..6c1a0a8a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
@@ -22,6 +22,7 @@ import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 
 
 def rotate_half(x):
@@ -58,43 +59,43 @@ def attention_fn(
     # query_layer = query_layer.permute(1, 2, 0, 3)
     cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
 
+    device = query_layer.device
     if layer_past is not None:
-        past_key, past_value = layer_past[0], layer_past[1]
-        past_length = past_key.size(2)
-        if past_length + cur_length > self.max_cache_length:
-            self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            self.kv_cache = (torch.empty(batch_size,
-                                         self.num_attention_heads,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,),
-                             torch.empty(batch_size,
-                                         self.num_attention_heads,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,))
-            self.kv_cache[0][:, :, :past_length, :] = past_key
-            self.kv_cache[1][:, :, :past_length, :] = past_value
-
-        self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
-        self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
-        key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
-        value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
+        cache_k, cache_v = layer_past[0], layer_past[1]
+        cache_k = cache_k.permute(1, 2, 0, 3)
+        cache_v = cache_v.permute(1, 2, 0, 3)
+        past_length = cache_k.size(2)
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+                                                       self.num_attention_heads_per_partition,
+                                                       self.hidden_size_per_attention_head,
+                                                       past_length,
+                                                       max_cache_length,
+                                                       dtype=query_layer.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
     elif use_cache:
-        self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+        max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        self.kv_cache = (torch.empty(batch_size, self.num_attention_heads,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,),
-                         torch.empty(batch_size, self.num_attention_heads,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,))
-        self.kv_cache[0][:, :, :cur_length, :] = key_layer
-        self.kv_cache[1][:, :, :cur_length, :] = value_layer
+        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                                 self.hidden_size_per_attention_head, cur_length,
+                                                 max_cache_length,
+                                                 dtype=query_layer.dtype, device=device)
+        key_cache[:] = key_layer
+        value_cache[:] = value_layer
+        key_layer = key_cache
+        value_layer = value_cache
 
     # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
     b, nh, seq_len, hidden_size = key_layer.shape
 
     if use_cache:
-        present = (key_layer, value_layer)
+        present = (key_layer.permute(2, 0, 1, 3), value_layer.permute(2, 0, 1, 3))
     else:
         present = None
@@ -168,6 +169,7 @@ def attention_fn(
     matmul_result = torch.empty(
         output_size[0] * output_size[1],
         output_size[2], output_size[3], dtype=query_layer.dtype,
+        device=query_layer.device
     )
 
     torch.baddbmm(
@@ -217,7 +219,8 @@ def attention_fn(
     # matmul: [b * np, sq, hn]
     context_layer = torch.empty(
         output_size[0] * output_size[1],
-        output_size[2], value_layer.size(-1), dtype=value_layer.dtype,)
+        output_size[2], value_layer.size(-1), dtype=value_layer.dtype,
+        device=query_layer.device)
     torch.bmm(attention_probs, value_layer, out=context_layer)
 
     # change view [b, np, sq, hn]
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 68d7c63f..de9a2c4c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -37,6 +37,7 @@ from typing import Optional, Tuple
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 
 
 def rotate_half(x):
@@ -125,35 +126,37 @@ def llama_attention_forward_4_31(
     if past_key_value is not None:
         # reuse k, v, self_attention
-        # key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        # value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        if kv_seq_len > self.max_cache_length:
-            new_cache_key = torch.empty(bsz, self.num_heads,
-                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim,
-                                        device=device)
-            new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]
+        cache_k = past_key_value[0]
+        cache_v = past_key_value[1]
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            # allocate new
+            new_cache_k, new_cache_v = create_kv_cache(bsz,
+                                                       self.num_heads,
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
 
-            new_cache_value = torch.empty(bsz, self.num_heads,
-                                          kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim,
-                                          device=device)
-            new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
-            self.kv_cache = (new_cache_key, new_cache_value)
-            self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 
-        self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states
-        self.kv_cache[1][:, :, kv_seq_len-1:kv_seq_len, :] = value_states
-        key_states = self.kv_cache[0][:, :, :kv_seq_len, :]
-        value_states = self.kv_cache[1][:, :, :kv_seq_len, :]
     elif use_cache:
-        # first token case
-        self.max_cache_length = max(min(self.max_position_embeddings, 2 * kv_seq_len),
-                                    kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH)
-        self.kv_cache = (torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim,
-                                     dtype=key_states.dtype, device=device),
-                         torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim,
-                                     dtype=key_states.dtype, device=device))
-        self.kv_cache[0][:, :, :kv_seq_len, :] = key_states
-        self.kv_cache[1][:, :, :kv_seq_len, :] = value_states
+        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+        new_key_states, new_value_states = create_kv_cache(bsz,
+                                                           self.num_heads,
+                                                           self.head_dim,
+                                                           kv_seq_len,
+                                                           max_cache_length,
+                                                           dtype=key_states.dtype,
+                                                           device=device)
+        new_key_states[:] = key_states
+        new_value_states[:] = value_states
+        key_states = new_key_states
+        value_states = new_value_states
 
     past_key_value = (key_states, value_states) if use_cache else None
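
Both files now delegate cache management to two helpers, create_kv_cache and append_kv_cache, imported from bigdl.llm.transformers.models.utils but not included in this diff. The sketch below shows one way these helpers could work, inferred only from the call sites above (argument order of batch, heads, head dim, current length, max length, plus the in-place append semantics); the real utils module may differ in detail.

import torch


# NOTE: hypothetical sketch of the helpers used in the patch above; not the
# actual bigdl.llm.transformers.models.utils implementation.
def create_kv_cache(batch_size, num_heads, head_dim, current_length,
                    max_length, dtype, device):
    # Pre-allocate a (batch, heads, max_length, head_dim) buffer, then return
    # strided views covering only the first current_length positions, so new
    # tokens can later be written into the same storage without reallocation.
    key_buf = torch.empty(batch_size, num_heads, max_length, head_dim,
                          dtype=dtype, device=device)
    value_buf = torch.empty(batch_size, num_heads, max_length, head_dim,
                            dtype=dtype, device=device)
    view_shape = (batch_size, num_heads, current_length, head_dim)
    key_cache = key_buf.as_strided(view_shape, key_buf.stride())
    value_cache = value_buf.as_strided(view_shape, value_buf.stride())
    return key_cache, value_cache


def append_kv_cache(cache_k, cache_v, key_states, value_states):
    # Grow the views along the sequence dimension and copy the new key/value
    # states into the tail of the underlying buffer; already-cached tokens are
    # never copied again.
    new_size = (cache_k.size(0), cache_k.size(1),
                cache_k.size(2) + key_states.size(2), cache_k.size(3))
    new_cache_k = cache_k.as_strided(new_size, cache_k.stride())
    new_cache_k[:, :, cache_k.size(2):, :] = key_states
    new_cache_v = cache_v.as_strided(new_size, cache_v.stride())
    new_cache_v[:, :, cache_v.size(2):, :] = value_states
    return new_cache_k, new_cache_v

Under this layout the cached tensors keep the stride of the full pre-allocated buffer, which is what lets the patch drop the self.max_cache_length / self.kv_cache attributes and detect capacity directly from the cache tensors passed in past_key_value.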
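
The guard used in both files, cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3), is how a full buffer is detected: for a view produced by the sketched helpers, stride()[1] equals max_length * head_dim while size(2) * size(3) equals current_length * head_dim, so the condition only becomes true once current_length reaches max_length. A small check of that invariant, assuming the sketched helpers above:

# 1 sequence, 2 heads, head_dim 4, room for 8 tokens, 3 already cached
k, v = create_kv_cache(1, 2, 4, current_length=3, max_length=8,
                       dtype=torch.float32, device="cpu")
print(k.stride()[1], k.size(2) * k.size(3))     # 32 vs 12 -> spare capacity left
k, v = append_kv_cache(k, v, torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4))
print(k.stride()[1] <= k.size(2) * k.size(3))   # True -> next step must allocate a larger cache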