Fix chatglm2 multi-turn streamchat (#8867)

Yang Wang authored on 2023-09-01 13:13:49 +08:00 (committed by GitHub)
parent c06f1ca93e
commit 242c9d6036


@@ -145,7 +145,7 @@ def chatglm2_attention_forward_8eb45c(
     # adjust key and value for inference
     if kv_cache is not None:
         cache_k, cache_v = kv_cache
-        past_length = cache_k.size(2)
+        past_length = cache_k.size(0)
         if past_length + cur_length > self.max_cache_length:
             self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
@@ -159,8 +159,8 @@ def chatglm2_attention_forward_8eb45c(
                 self.max_cache_length,
                 self.hidden_size_per_attention_head,
                 device=device))
-        self.kv_cache[0][:, :, :past_length, :] = cache_k
-        self.kv_cache[1][:, :, :past_length, :] = cache_v
+        self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
+        self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
         self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
         self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
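For reference, the permute in this hunk lines up two different cache layouts: the kv_cache handed back to the caller (assumed here to be ChatGLM2's seq-first (seq_len, batch, num_heads, head_dim) layout, which is why past_length is now read from cache_k.size(0)), and the pre-allocated self.kv_cache buffer (assumed to be (batch, num_heads, seq_len, head_dim)). The sketch below illustrates that mapping; the sizes and variable names are hypothetical, not taken from the repo.

import torch

# Hypothetical sizes, for illustration only.
seq_len, batch, num_heads, head_dim = 5, 1, 32, 128
max_cache_length = 16

# Cache as returned to the caller after the previous turn: sequence first,
# so its length is read with size(0), matching the patched past_length line.
cache_k = torch.randn(seq_len, batch, num_heads, head_dim)
past_length = cache_k.size(0)

# Pre-allocated internal buffer: batch/heads first, sequence on dim 2.
kv_cache_k = torch.empty(batch, num_heads, max_cache_length, head_dim)

# permute(1, 2, 0, 3) maps (seq, batch, heads, head_dim) to
# (batch, heads, seq, head_dim), so the copy lines up with the buffer.
kv_cache_k[:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)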
@@ -196,7 +196,7 @@ def chatglm2_attention_forward_8eb45c(
     output = self.dense(context_layer)
-    return output, kv_cache
+    return output, (key_layer.permute(2, 0, 1, 3), value_layer.permute(2, 0, 1, 3))
 
 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
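The changed return value is the other half of the same fix: key_layer and value_layer live in the internal (batch, num_heads, seq_len, head_dim) layout, and permute(2, 0, 1, 3) converts them back to the seq-first layout before they are handed to the caller, so that feeding them in again on the next turn round-trips cleanly. A small self-contained check of that round trip, under the same layout assumptions as above (a sketch, not code from the repo):

import torch

batch, num_heads, seq_len, head_dim = 1, 32, 7, 128

# Internal layout used by key_layer/value_layer and the pre-allocated cache.
key_layer = torch.randn(batch, num_heads, seq_len, head_dim)

# Layout returned to the caller: sequence first.
returned_k = key_layer.permute(2, 0, 1, 3)      # (seq, batch, heads, head_dim)
assert returned_k.size(0) == seq_len            # past_length = cache_k.size(0)

# On the next turn, permute(1, 2, 0, 3) undoes permute(2, 0, 1, 3),
# restoring the internal layout before the copy into self.kv_cache.
assert torch.equal(returned_k.permute(1, 2, 0, 3), key_layer)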
@@ -228,6 +228,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
         else:
             if attention_mask is not None:
                 attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
+            if torch.is_autocast_cpu_enabled():
                 query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
                 key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
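The last hunk touches the CPU autocast path in core_attn_forward_8eb45c, where the query/key/value tensors are cast to the autocast dtype via torch.get_autocast_cpu_dtype(). As a minimal, standalone illustration of how that guard behaves (not code from the repo):

import torch

x = torch.randn(2, 2)

# Outside an autocast region the guard is False, so no conversion happens.
assert not torch.is_autocast_cpu_enabled()

with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    # Inside a CPU autocast region the guard is True and the target dtype
    # is the autocast dtype, so inputs can be converted explicitly.
    if torch.is_autocast_cpu_enabled():
        x = x.to(torch.get_autocast_cpu_dtype())

assert x.dtype == torch.bfloat16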