diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 7222d344..d43452cb 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -20,6 +20,7 @@
 import torch
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -145,39 +146,38 @@ def chatglm2_attention_forward_8eb45c(
     # adjust key and value for inference
     if kv_cache is not None:
         cache_k, cache_v = kv_cache
-        past_length = cache_k.size(0)
+        cache_k = cache_k.permute(1, 2, 0, 3)
+        cache_v = cache_v.permute(1, 2, 0, 3)
+        past_length = cache_k.size(2)
 
-        if past_length + cur_length > self.max_cache_length:
-            self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            self.kv_cache = (torch.empty(batch_size,
-                                         self.num_attention_heads_per_partition,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,
-                                         device=device),
-                             torch.empty(batch_size,
-                                         self.num_attention_heads_per_partition,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,
-                                         device=device))
-            self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
-            self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
-        self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
-        self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+                                                       self.num_attention_heads_per_partition,
+                                                       self.hidden_size_per_attention_head,
+                                                       past_length,
+                                                       max_cache_length,
+                                                       dtype=query_layer.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
 
-        key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
-        value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
+        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
     elif use_cache:
-        self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+        max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,
-                                     device=device),
-                         torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,
-                                     device=device))
-        self.kv_cache[0][:, :, :cur_length, :] = key_layer
-        self.kv_cache[1][:, :, :cur_length, :] = value_layer
+        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                                 self.hidden_size_per_attention_head, cur_length,
+                                                 max_cache_length,
+                                                 dtype=query_layer.dtype, device=device)
+        key_cache[:] = key_layer
+        value_cache[:] = value_layer
+        key_layer = key_cache
+        value_layer = value_cache
 
     if use_cache:
         kv_cache = (key_layer, value_layer)
@@ -204,36 +204,14 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-
-            if torch.is_autocast_cpu_enabled():
-                attention_mask = torch.ones(query_layer.shape[2],
-                                            key_layer.shape[2],
-                                            dtype=torch.bool).tril(diagonal=0)
-                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
-                attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
-                query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
-                key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
-                value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                                 key_layer,
-                                                                                 value_layer,
-                                                                                 attention_mask,
-                                                                                 is_causal=False)
-            else:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                                 key_layer,
-                                                                                 value_layer,
-                                                                                 attention_mask,
-                                                                                 is_causal=True)
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                             key_layer,
+                                                                             value_layer,
+                                                                             attention_mask,
+                                                                             is_causal=True)
         else:
             if attention_mask is not None:
-                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
-
-                if torch.is_autocast_cpu_enabled():
-                    query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
-                    key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
-                    value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
-                    attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
+                attention_mask = ~attention_mask
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
new file mode 100644
index 00000000..58765e2a
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -0,0 +1,48 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+
+
+def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    key_cache_storage = torch.empty(batch_size, num_heads,
+                                    max_length, head_dim,
+                                    dtype=dtype, device=device)
+    value_cache_storage = torch.empty(batch_size, num_heads,
+                                      max_length, head_dim,
+                                      dtype=dtype, device=device)
+
+    key_cache = key_cache_storage.as_strided((batch_size, num_heads,
+                                              current_length, head_dim),
+                                             key_cache_storage.stride(),
+                                             storage_offset=0)
+    value_cache = value_cache_storage.as_strided((batch_size, num_heads,
+                                                  current_length, head_dim),
+                                                 value_cache_storage.stride(),
+                                                 storage_offset=0)
+    return key_cache, value_cache
+
+
+def append_kv_cache(cache_k, cache_v, key_states, value_states):
+    new_size = (cache_k.size(0),
+                cache_k.size(1),
+                cache_k.size(2) + key_states.size(2),
+                cache_k.size(3))
+    new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
+    new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
+    new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
+    new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + value_states.size(2), :] = value_states
+    return new_cache_k, new_cache_v
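
Note on the new helpers: `create_kv_cache` over-allocates a `(batch, heads, max_length, head_dim)` buffer and returns `as_strided` views that expose only the first `current_length` positions; `append_kv_cache` then widens those views token by token, writing in place instead of reallocating and copying the whole cache on every decode step. A standalone usage sketch (sizes are made up for illustration; the import path is the module this diff adds):

```python
import torch
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache

# Hypothetical sizes; the cache layout is (batch, num_heads, seq_len, head_dim).
batch, heads, head_dim = 1, 32, 128

# Reserve room for 256 extra tokens up front; only the first 3 positions are
# exposed. The returned views are uninitialized, so the caller fills them.
cache_k, cache_v = create_kv_cache(batch, heads, head_dim,
                                   current_length=3, max_length=3 + 256,
                                   dtype=torch.float32, device='cpu')
cache_k[:] = torch.randn(batch, heads, 3, head_dim)
cache_v[:] = torch.randn(batch, heads, 3, head_dim)

# Appending one decoded token just grows the views over the same storage:
# no new allocation, no copy of the 3 already-cached positions.
k_new = torch.randn(batch, heads, 1, head_dim)
v_new = torch.randn(batch, heads, 1, head_dim)
cache_k, cache_v = append_kv_cache(cache_k, cache_v, k_new, v_new)
assert cache_k.shape == (batch, heads, 4, head_dim)
```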
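Why `cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)` in chatglm2.py detects a cache with no reserved room: a view returned by `create_kv_cache` keeps the stride of its over-allocated storage, so `stride(1)` equals `max_length * head_dim` and stays strictly greater than `size(2) * size(3)` while free rows remain. A plain contiguous tensor (for example, a cache produced without these helpers, or a view that has grown to fill its buffer) satisfies the `<=` test and is migrated into a fresh, larger buffer. A minimal demonstration with plain torch and toy sizes:

```python
import torch

storage = torch.empty(1, 32, 259, 128)        # room reserved for 259 tokens
view = storage.as_strided((1, 32, 3, 128), storage.stride())
# stride(1) = 259 * 128 > 3 * 128: reserved rows remain, append in place.
assert view.stride()[1] > view.size(2) * view.size(3)

contiguous = torch.empty(1, 32, 3, 128)       # no reserved room
# stride(1) = 3 * 128 == size(2) * size(3): a reallocation is needed.
assert contiguous.stride()[1] <= contiguous.size(2) * contiguous.size(3)
```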
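On the mask change in `core_attn_forward_8eb45c`: the boolean `attn_mask` of `torch.nn.functional.scaled_dot_product_attention` treats `True` as "this position may be attended", while the ChatGLM2 mask marks positions to ignore with `True`, so the single `~attention_mask` flip (mirroring the upstream ChatGLM2 modeling code) replaces the removed `masked_fill`/autocast handling. A small check of the convention, with toy tensors:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)

# ChatGLM2-style mask: True marks positions that must NOT be attended
# (here, the strict upper triangle, i.e. future tokens).
glm_mask = torch.ones(4, 4, dtype=torch.bool).triu(diagonal=1)

# Flipping it gives SDPA's convention (True = take part in attention),
# which for this mask is exactly the causal fast path.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=~glm_mask)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out, ref, atol=1e-5)
```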