Fix chatglm2 attention and kv cache (#8924)
* fix chatglm2 attention
* fix bf16 bug
* make model stateless
* add utils
* cleanup
* fix style
parent b209b8f7b6
commit 25428b22b4

2 changed files with 82 additions and 56 deletions
@@ -20,6 +20,7 @@
 import torch
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -145,39 +146,38 @@ def chatglm2_attention_forward_8eb45c(
     # adjust key and value for inference
     if kv_cache is not None:
         cache_k, cache_v = kv_cache
-        past_length = cache_k.size(0)
+        cache_k = cache_k.permute(1, 2, 0, 3)
+        cache_v = cache_v.permute(1, 2, 0, 3)
+        past_length = cache_k.size(2)

-        if past_length + cur_length > self.max_cache_length:
-            self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            self.kv_cache = (torch.empty(batch_size,
-                                         self.num_attention_heads_per_partition,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,
-                                         device=device),
-                             torch.empty(batch_size,
-                                         self.num_attention_heads_per_partition,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,
-                                         device=device))
-            self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
-            self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
-        self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
-        self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
-
-        key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
-        value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+                                                       self.num_attention_heads_per_partition,
+                                                       self.hidden_size_per_attention_head,
+                                                       past_length,
+                                                       max_cache_length,
+                                                       dtype=query_layer.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+
+        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

     elif use_cache:
-        self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+        max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,
-                                     device=device),
-                         torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,
-                                     device=device))
-        self.kv_cache[0][:, :, :cur_length, :] = key_layer
-        self.kv_cache[1][:, :, :cur_length, :] = value_layer
+        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                                 self.hidden_size_per_attention_head, cur_length,
+                                                 max_cache_length,
+                                                 dtype=query_layer.dtype, device=device)
+        key_cache[:] = key_layer
+        value_cache[:] = value_layer
+        key_layer = key_cache
+        value_layer = value_cache

     if use_cache:
         kv_cache = (key_layer, value_layer)
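Note on the new stride check, cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3): a cache returned by create_kv_cache is a narrow view over a larger pre-allocated buffer, so its stride along the head dimension stays at max_length * head_dim while reserved room remains; a plain contiguous tensor, or a cache whose reserved block is used up, satisfies the inequality and triggers allocation of a fresh, larger buffer. A minimal standalone sketch, with illustrative shapes that are not taken from chatglm2:

import torch

# Illustrative shapes only: (batch, heads, seq, head_dim) with room reserved for 8 tokens.
batch, heads, head_dim, max_len = 1, 2, 4, 8

storage = torch.empty(batch, heads, max_len, head_dim)
# A view exposing only the first 3 tokens, backed by the larger storage
# (this is what create_kv_cache returns).
cache_k = storage.as_strided((batch, heads, 3, head_dim), storage.stride())

# While reserved room remains, stride(1) == max_len * head_dim > size(2) * size(3).
assert cache_k.stride()[1] > cache_k.size(2) * cache_k.size(3)

# A plain contiguous tensor (e.g. a cache handed in by code that did not pre-allocate)
# satisfies the "<=" condition, so the forward pass above re-allocates and copies.
plain_k = torch.empty(batch, heads, 3, head_dim)
assert plain_k.stride()[1] <= plain_k.size(2) * plain_k.size(3)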
@@ -204,22 +204,6 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-
-            if torch.is_autocast_cpu_enabled():
-                attention_mask = torch.ones(query_layer.shape[2],
-                                            key_layer.shape[2],
-                                            dtype=torch.bool).tril(diagonal=0)
-                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
-                attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
-                query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
-                key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
-                value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                                 key_layer,
-                                                                                 value_layer,
-                                                                                 attention_mask,
-                                                                                 is_causal=False)
-            else:
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
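The block removed here special-cased CPU autocast, building a float mask filled with -inf and manually casting the query/key/value and the mask to the autocast dtype. A quick sketch of the autocast queries that branch relied on; illustrative only, not part of the patch:

import torch

# Under CPU autocast the default fast dtype is bfloat16, which is what the deleted
# branch cast q/k/v and the attention mask to by hand.
with torch.cpu.amp.autocast():
    print(torch.is_autocast_cpu_enabled())   # True
    print(torch.get_autocast_cpu_dtype())    # torch.bfloat16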
@@ -227,13 +211,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
                                                                              is_causal=True)
         else:
             if attention_mask is not None:
-                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
-
-            if torch.is_autocast_cpu_enabled():
-                query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
-                key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
-                value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
-                attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
+                attention_mask = ~attention_mask
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
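With the autocast workaround gone, the mask (when present) is passed to scaled_dot_product_attention as a boolean tensor in which True means "may attend", via ~attention_mask. For a lower-triangular mask this matches is_causal=True, which is what the other branch uses. A small equivalence sketch with illustrative shapes:

import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim); sizes are illustrative.
q = torch.randn(1, 2, 5, 4)
k = torch.randn(1, 2, 5, 4)
v = torch.randn(1, 2, 5, 4)

# Boolean mask, True = position may be attended to.
causal_bool = torch.ones(5, 5, dtype=torch.bool).tril(diagonal=0)

out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bool)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_mask, out_causal, atol=1e-5)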
python/llm/src/bigdl/llm/transformers/models/utils.py (new file, 48 additions)
@@ -0,0 +1,48 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+
+
+def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    key_cache_storage = torch.empty(batch_size, num_heads,
+                                    max_length, head_dim,
+                                    dtype=dtype, device=device)
+    value_cache_storage = torch.empty(batch_size, num_heads,
+                                      max_length, head_dim,
+                                      dtype=dtype, device=device)
+
+    key_cache = key_cache_storage.as_strided((batch_size, num_heads,
+                                              current_length, head_dim),
+                                             key_cache_storage.stride(),
+                                             storage_offset=0)
+    value_cache = value_cache_storage.as_strided((batch_size, num_heads,
+                                                  current_length, head_dim),
+                                                 value_cache_storage.stride(),
+                                                 storage_offset=0)
+    return key_cache, value_cache
+
+
+def append_kv_cache(cache_k, cache_v, key_states, value_states):
+    new_size = (cache_k.size(0),
+                cache_k.size(1),
+                cache_k.size(2) + key_states.size(2),
+                cache_k.size(3))
+    new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
+    new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
+    new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
+    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
+    return new_cache_k, new_cache_v
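A quick usage sketch for the two helpers above, mirroring how the attention forward uses them (shapes are illustrative; the import assumes bigdl-llm is installed):

import torch
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache

batch, heads, head_dim = 1, 2, 4

# Prompt of 3 tokens: reserve room for 16 tokens, but expose only the first 3.
k_cache, v_cache = create_kv_cache(batch, heads, head_dim,
                                   current_length=3, max_length=16,
                                   dtype=torch.float32, device='cpu')
k_cache[:] = torch.randn(batch, heads, 3, head_dim)
v_cache[:] = torch.randn(batch, heads, 3, head_dim)

# Decode one more token: the new key/value land in the reserved storage in place,
# and the returned views simply cover one more position (no copy of the old cache).
new_k = torch.randn(batch, heads, 1, head_dim)
new_v = torch.randn(batch, heads, 1, head_dim)
k_cache, v_cache = append_kv_cache(k_cache, v_cache, new_k, new_v)
print(k_cache.shape)   # torch.Size([1, 2, 4, 4])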