Fix chatglm2 multi-turn streamchat (#8867)
parent c06f1ca93e
commit 242c9d6036
1 changed file with 5 additions and 4 deletions
@@ -145,7 +145,7 @@ def chatglm2_attention_forward_8eb45c(
     # adjust key and value for inference
     if kv_cache is not None:
         cache_k, cache_v = kv_cache
-        past_length = cache_k.size(2)
+        past_length = cache_k.size(0)
 
         if past_length + cur_length > self.max_cache_length:
             self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
@@ -159,8 +159,8 @@ def chatglm2_attention_forward_8eb45c(
                                          self.max_cache_length,
                                          self.hidden_size_per_attention_head,
                                          device=device))
-        self.kv_cache[0][:, :, :past_length, :] = cache_k
-        self.kv_cache[1][:, :, :past_length, :] = cache_v
+        self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
+        self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
         self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
         self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
 
@@ -196,7 +196,7 @@ def chatglm2_attention_forward_8eb45c(
 
     output = self.dense(context_layer)
 
-    return output, kv_cache
+    return output, (key_layer.permute(2, 0, 1, 3), value_layer.permute(2, 0, 1, 3))
 
 
 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
@@ -228,6 +228,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     else:
         if attention_mask is not None:
             attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
 
         if torch.is_autocast_cpu_enabled():
            query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
            key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
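Note on the change (layouts inferred from the permutes in the diff above, not taken from the library source, and variable names below are illustrative): the kv_cache tuple that ChatGLM2 hands back between stream_chat turns keeps the sequence dimension first, i.e. (seq_len, batch, num_heads, head_dim), while the preallocated self.kv_cache buffer stores (batch, num_heads, max_cache_length, head_dim). That is why past_length now reads size(0), why the cached tensors are written with permute(1, 2, 0, 3), and why the return value permutes back with permute(2, 0, 1, 3). A minimal, self-contained sketch of that round trip:

    import torch

    # Assumed layouts, matching the permutes in the diff:
    #   model-facing cache: (seq_len, batch, num_heads, head_dim)
    #   internal buffer:    (batch, num_heads, max_cache_length, head_dim)
    seq_len, batch, num_heads, head_dim, max_len = 5, 1, 2, 4, 16
    cache_k = torch.randn(seq_len, batch, num_heads, head_dim)

    past_length = cache_k.size(0)                        # seq dim is dim 0 in this layout
    buf_k = torch.empty(batch, num_heads, max_len, head_dim)
    buf_k[:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)    # into buffer layout

    key_layer = buf_k[:, :, :past_length, :]             # read back from the buffer
    assert torch.equal(key_layer.permute(2, 0, 1, 3), cache_k)    # back to model layout

Returning the permuted key/value instead of the raw buffer keeps the layout consistent from one turn to the next, which appears to be what the multi-turn stream_chat fix relies on.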