Fix chatglm2 attention and kv cache (#8924)
* fix chatglm2 attention * fix bf16 bug * make model stateless * add utils * cleanup * fix style
This commit is contained in:
parent
b209b8f7b6
commit
25428b22b4
2 changed files with 82 additions and 56 deletions
|
|
@ -20,6 +20,7 @@
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
|
||||||
|
|
||||||
|
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||||
|
|
@ -145,39 +146,38 @@ def chatglm2_attention_forward_8eb45c(
|
||||||
# adjust key and value for inference
|
# adjust key and value for inference
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
cache_k, cache_v = kv_cache
|
cache_k, cache_v = kv_cache
|
||||||
past_length = cache_k.size(0)
|
cache_k = cache_k.permute(1, 2, 0, 3)
|
||||||
|
cache_v = cache_v.permute(1, 2, 0, 3)
|
||||||
|
past_length = cache_k.size(2)
|
||||||
|
|
||||||
if past_length + cur_length > self.max_cache_length:
|
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||||
self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
self.kv_cache = (torch.empty(batch_size,
|
new_cache_k, new_cache_v = create_kv_cache(batch_size,
|
||||||
self.num_attention_heads_per_partition,
|
self.num_attention_heads_per_partition,
|
||||||
self.max_cache_length,
|
|
||||||
self.hidden_size_per_attention_head,
|
self.hidden_size_per_attention_head,
|
||||||
device=device),
|
past_length,
|
||||||
torch.empty(batch_size,
|
max_cache_length,
|
||||||
self.num_attention_heads_per_partition,
|
dtype=query_layer.dtype,
|
||||||
self.max_cache_length,
|
device=device)
|
||||||
self.hidden_size_per_attention_head,
|
new_cache_k[:] = cache_k
|
||||||
device=device))
|
new_cache_v[:] = cache_v
|
||||||
self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
|
cache_k = new_cache_k
|
||||||
self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
|
cache_v = new_cache_v
|
||||||
self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
|
|
||||||
self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
|
|
||||||
|
|
||||||
key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
|
key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
|
||||||
value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
|
|
||||||
|
|
||||||
elif use_cache:
|
elif use_cache:
|
||||||
self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
|
|
||||||
|
max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
|
||||||
+ KV_CACHE_ALLOC_BLOCK_LENGTH
|
+ KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
|
key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
|
||||||
self.max_cache_length, self.hidden_size_per_attention_head,
|
self.hidden_size_per_attention_head, cur_length,
|
||||||
device=device),
|
max_cache_length,
|
||||||
torch.empty(batch_size, self.num_attention_heads_per_partition,
|
dtype=query_layer.dtype, device=device)
|
||||||
self.max_cache_length, self.hidden_size_per_attention_head,
|
key_cache[:] = key_layer
|
||||||
device=device))
|
value_cache[:] = value_layer
|
||||||
self.kv_cache[0][:, :, :cur_length, :] = key_layer
|
key_layer = key_cache
|
||||||
self.kv_cache[1][:, :, :cur_length, :] = value_layer
|
value_layer = value_cache
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
kv_cache = (key_layer, value_layer)
|
kv_cache = (key_layer, value_layer)
|
||||||
|
|
@ -204,22 +204,6 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
|
||||||
if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
|
if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
|
||||||
query_layer = query_layer.permute(1, 2, 0, 3)
|
query_layer = query_layer.permute(1, 2, 0, 3)
|
||||||
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
||||||
|
|
||||||
if torch.is_autocast_cpu_enabled():
|
|
||||||
attention_mask = torch.ones(query_layer.shape[2],
|
|
||||||
key_layer.shape[2],
|
|
||||||
dtype=torch.bool).tril(diagonal=0)
|
|
||||||
attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
|
|
||||||
attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
|
|
||||||
query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
|
|
||||||
key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
|
|
||||||
value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
|
|
||||||
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
|
||||||
key_layer,
|
|
||||||
value_layer,
|
|
||||||
attention_mask,
|
|
||||||
is_causal=False)
|
|
||||||
else:
|
|
||||||
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
||||||
key_layer,
|
key_layer,
|
||||||
value_layer,
|
value_layer,
|
||||||
|
|
@ -227,13 +211,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
|
||||||
is_causal=True)
|
is_causal=True)
|
||||||
else:
|
else:
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
|
attention_mask = ~attention_mask
|
||||||
|
|
||||||
if torch.is_autocast_cpu_enabled():
|
|
||||||
query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
|
|
||||||
key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
|
|
||||||
value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
|
|
||||||
attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
|
|
||||||
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
||||||
key_layer,
|
key_layer,
|
||||||
value_layer,
|
value_layer,
|
||||||
|
|
|
||||||
48
python/llm/src/bigdl/llm/transformers/models/utils.py
Normal file
48
python/llm/src/bigdl/llm/transformers/models/utils.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
|
||||||
|
key_cache_storage = torch.empty(batch_size, num_heads,
|
||||||
|
max_length, head_dim,
|
||||||
|
dtype=dtype, device=device)
|
||||||
|
value_cache_storage = torch.empty(batch_size, num_heads,
|
||||||
|
max_length, head_dim,
|
||||||
|
dtype=dtype, device=device)
|
||||||
|
|
||||||
|
key_cache = key_cache_storage.as_strided((batch_size, num_heads,
|
||||||
|
current_length, head_dim),
|
||||||
|
key_cache_storage.stride(),
|
||||||
|
storage_offset=0)
|
||||||
|
value_cache = value_cache_storage.as_strided((batch_size, num_heads,
|
||||||
|
current_length, head_dim),
|
||||||
|
value_cache_storage.stride(),
|
||||||
|
storage_offset=0)
|
||||||
|
return key_cache, value_cache
|
||||||
|
|
||||||
|
|
||||||
|
def append_kv_cache(cache_k, cache_v, key_states, value_states):
|
||||||
|
new_size = (cache_k.size(0),
|
||||||
|
cache_k.size(1),
|
||||||
|
cache_k.size(2) + key_states.size(2),
|
||||||
|
cache_k.size(3))
|
||||||
|
new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
|
||||||
|
new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
|
||||||
|
new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
|
||||||
|
new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
|
||||||
|
return new_cache_k, new_cache_v
|
||||||
Loading…
Reference in a new issue