Fix chatglm2 attention and kv cache (#8924)

* fix chatglm2 attention

* fix bf16 bug

* make model stateless

* add utils

* cleanup

* fix style
This commit is contained in:
Yang Wang 2023-09-08 09:54:29 +08:00 committed by GitHub
parent b209b8f7b6
commit 25428b22b4
2 changed files with 82 additions and 56 deletions

View file

@ -20,6 +20,7 @@
import torch
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
import torch.nn.functional as F
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -145,39 +146,38 @@ def chatglm2_attention_forward_8eb45c(
# adjust key and value for inference
if kv_cache is not None:
cache_k, cache_v = kv_cache
past_length = cache_k.size(0)
cache_k = cache_k.permute(1, 2, 0, 3)
cache_v = cache_v.permute(1, 2, 0, 3)
past_length = cache_k.size(2)
if past_length + cur_length > self.max_cache_length:
self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
self.kv_cache = (torch.empty(batch_size,
self.num_attention_heads_per_partition,
self.max_cache_length,
self.hidden_size_per_attention_head,
device=device),
torch.empty(batch_size,
self.num_attention_heads_per_partition,
self.max_cache_length,
self.hidden_size_per_attention_head,
device=device))
self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
new_cache_k, new_cache_v = create_kv_cache(batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
past_length,
max_cache_length,
dtype=query_layer.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
elif use_cache:
self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+ KV_CACHE_ALLOC_BLOCK_LENGTH
self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
self.max_cache_length, self.hidden_size_per_attention_head,
device=device),
torch.empty(batch_size, self.num_attention_heads_per_partition,
self.max_cache_length, self.hidden_size_per_attention_head,
device=device))
self.kv_cache[0][:, :, :cur_length, :] = key_layer
self.kv_cache[1][:, :, :cur_length, :] = value_layer
key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, cur_length,
max_cache_length,
dtype=query_layer.dtype, device=device)
key_cache[:] = key_layer
value_cache[:] = value_layer
key_layer = key_cache
value_layer = value_cache
if use_cache:
kv_cache = (key_layer, value_layer)
@ -204,36 +204,14 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
query_layer = query_layer.permute(1, 2, 0, 3)
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
if torch.is_autocast_cpu_enabled():
attention_mask = torch.ones(query_layer.shape[2],
key_layer.shape[2],
dtype=torch.bool).tril(diagonal=0)
attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,
attention_mask,
is_causal=False)
else:
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,
attention_mask,
is_causal=True)
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,
attention_mask,
is_causal=True)
else:
if attention_mask is not None:
attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
if torch.is_autocast_cpu_enabled():
query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
attention_mask = ~attention_mask
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,

View file

@ -0,0 +1,48 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
key_cache_storage = torch.empty(batch_size, num_heads,
max_length, head_dim,
dtype=dtype, device=device)
value_cache_storage = torch.empty(batch_size, num_heads,
max_length, head_dim,
dtype=dtype, device=device)
key_cache = key_cache_storage.as_strided((batch_size, num_heads,
current_length, head_dim),
key_cache_storage.stride(),
storage_offset=0)
value_cache = value_cache_storage.as_strided((batch_size, num_heads,
current_length, head_dim),
value_cache_storage.stride(),
storage_offset=0)
return key_cache, value_cache
def append_kv_cache(cache_k, cache_v, key_states, value_states):
new_size = (cache_k.size(0),
cache_k.size(1),
cache_k.size(2) + key_states.size(2),
cache_k.size(3))
new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
return new_cache_k, new_cache_v