diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index e26405d5..7222d344 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -145,7 +145,7 @@ def chatglm2_attention_forward_8eb45c(
     # adjust key and value for inference
     if kv_cache is not None:
         cache_k, cache_v = kv_cache
-        past_length = cache_k.size(2)
+        past_length = cache_k.size(0)
 
         if past_length + cur_length > self.max_cache_length:
             self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
@@ -159,8 +159,8 @@ def chatglm2_attention_forward_8eb45c(
                                          self.max_cache_length,
                                          self.hidden_size_per_attention_head,
                                          device=device))
-        self.kv_cache[0][:, :, :past_length, :] = cache_k
-        self.kv_cache[1][:, :, :past_length, :] = cache_v
+        self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
+        self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
         self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
         self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
 
@@ -196,7 +196,7 @@ def chatglm2_attention_forward_8eb45c(
 
     output = self.dense(context_layer)
 
-    return output, kv_cache
+    return output, (key_layer.permute(2, 0, 1, 3), value_layer.permute(2, 0, 1, 3))
 
 
 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
@@ -228,6 +228,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
         else:
             if attention_mask is not None:
                 attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
+
            if torch.is_autocast_cpu_enabled():
                 query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
                 key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
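
Note (illustrative sketch, not part of the patch): the new permute calls convert between two KV-cache layouts. Judging from the indexing in this diff, the kv_cache tuple passed into and returned by chatglm2_attention_forward_8eb45c uses ChatGLM2's native (seq_len, batch, num_kv_heads, head_dim) layout, while the pre-allocated self.kv_cache buffers use (batch, num_kv_heads, seq_len, head_dim); past_length is therefore read from dim 0 of the native layout. The shape names and sizes below are assumptions for illustration only:

import torch

# Assumed sizes for illustration only.
seq_len, batch, num_kv_heads, head_dim = 5, 1, 2, 4

# Cache in ChatGLM2's native layout: (seq_len, batch, num_kv_heads, head_dim).
cache_k = torch.randn(seq_len, batch, num_kv_heads, head_dim)
assert cache_k.size(0) == seq_len  # past_length = cache_k.size(0)

# permute(1, 2, 0, 3): native layout -> pre-allocated buffer layout (batch, heads, seq, dim).
cache_k_buf = cache_k.permute(1, 2, 0, 3)
assert cache_k_buf.shape == (batch, num_kv_heads, seq_len, head_dim)

# permute(2, 0, 1, 3): buffer layout -> native layout, as in the new return value.
cache_k_out = cache_k_buf.permute(2, 0, 1, 3)
assert torch.equal(cache_k_out, cache_k)  # round-trip recovers the original tensor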