diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index d4c3e326..e99cea5a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -239,21 +239,27 @@ def chatglm2_attention_forward_8eb45c(
         key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)
 
     if self.multi_query_attention:
-        key_length = key_layer.size(0)
-        query_group_size = self.num_attention_heads_per_partition // \
-            self.num_multi_query_groups_per_partition
-        key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3)  # [bs, nh/k, sl, hn]
-        key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
-        key_layer = key_layer.contiguous().view((batch_size,
-                                                 self.num_attention_heads_per_partition,
-                                                 key_length,
-                                                 self.hidden_size_per_attention_head))
-        value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3)
-        value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
-        value_layer = value_layer.contiguous().view((batch_size,
+        if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
+            # If batch_size > 1 on gpu, permute key/value_layer to [bs, np, sl, hn]
+            # to reduce memory usage. Otherwise, expand key/value_layer to [bs, nh, sl, hn].
+            key_layer = key_layer.permute(1, 2, 0, 3)  # [bs, np, sl, hn]
+            value_layer = value_layer.permute(1, 2, 0, 3)  # [bs, np, sl, hn]
+        else:
+            key_length = key_layer.size(0)
+            query_group_size = self.num_attention_heads_per_partition // \
+                self.num_multi_query_groups_per_partition
+            key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3)  # [bs, nh/k, sl, hn]
+            key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
+            key_layer = key_layer.contiguous().view((batch_size,
                                                      self.num_attention_heads_per_partition,
                                                      key_length,
                                                      self.hidden_size_per_attention_head))
+            value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3)  # [bs, nh/k, sl, hn]
+            value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
+            value_layer = value_layer.contiguous().view((batch_size,
+                                                         self.num_attention_heads_per_partition,
+                                                         key_length,
+                                                         self.hidden_size_per_attention_head))
 
     # adjust key and value for inference
     if kv_cache is not None:
@@ -264,13 +270,26 @@ def chatglm2_attention_forward_8eb45c(
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
-                                                       self.num_attention_heads_per_partition,
-                                                       self.hidden_size_per_attention_head,
-                                                       past_length,
-                                                       max_cache_length,
-                                                       dtype=query_layer.dtype,
-                                                       device=device)
+            if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
+                # If batch_size > 1 on gpu, use init_kv_cache to avoid an empty cache and
+                # ensure generation correctness.
+                # Set the num_heads in init_kv_cache to np, ensuring that the tensors of
+                # new_cache_k/v and key/value_layer have the same size.
+                new_cache_k, new_cache_v = init_kv_cache(batch_size,
+                                                         self.num_multi_query_groups_per_partition,
+                                                         self.hidden_size_per_attention_head,
+                                                         past_length,
+                                                         max_cache_length,
+                                                         dtype=query_layer.dtype,
+                                                         device=device)
+            else:
+                new_cache_k, new_cache_v = extend_kv_cache(batch_size,
+                                                           self.num_attention_heads_per_partition,
+                                                           self.hidden_size_per_attention_head,
+                                                           past_length,
+                                                           max_cache_length,
+                                                           dtype=query_layer.dtype,
+                                                           device=device)
             new_cache_k[:] = cache_k
             new_cache_v[:] = cache_v
             cache_k = new_cache_k
             cache_v = new_cache_v
@@ -279,18 +298,33 @@ def chatglm2_attention_forward_8eb45c(
         key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
     elif use_cache:
         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
-                                               self.hidden_size_per_attention_head, cur_length,
+
+        if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
+            # Ensure the tensors of key/value_cache and key/value_layer have the same size.
+            nums_per_partition = self.num_multi_query_groups_per_partition
+        else:
+            nums_per_partition = self.num_attention_heads_per_partition
+
+        key_cache, value_cache = init_kv_cache(batch_size,
+                                               nums_per_partition,
+                                               self.hidden_size_per_attention_head,
+                                               cur_length,
                                                max_cache_length,
-                                               dtype=query_layer.dtype, device=device)
+                                               dtype=query_layer.dtype,
+                                               device=device)
         key_cache[:] = key_layer
         value_cache[:] = value_layer
         key_layer = key_cache
         value_layer = value_cache
 
+    # If batch_size > 1, return tensors with shape [bs, np, sl, hn] as past_key_values. This could
+    # reduce memory usage as tensors are not expanded to [bs, nh, sl, hn].
+    # Otherwise, return views of [bs, nh, sl, hn].
+    cache_key_layer = key_layer
+    cache_value_layer = value_layer
+
     if use_cache:
         kv_cache = (key_layer, value_layer)
     else:
@@ -299,6 +333,29 @@ def chatglm2_attention_forward_8eb45c(
     # ==================================
     # core attention computation
     # ==================================
+    if device.type == "xpu" and batch_size > 1:  # use beam_search for generation.
+        # If batch_size > 1, expand key/value_layer to [bs, nh, sl, hn] for
+        # core attention computation.
+        # The expanded tensors will not be returned as past_key_values.
+        if self.multi_query_attention:
+            query_group_size = self.num_attention_heads_per_partition // \
+                self.num_multi_query_groups_per_partition
+            key_layer = key_layer.unsqueeze(-3)
+            key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
+            save_length = key_layer.size(3)
+            # [bs, np, sl, hn] --> [bs, nh, sl, hn]
+            key_layer = key_layer.contiguous().view((batch_size,
+                                                     self.num_attention_heads_per_partition,
+                                                     save_length,
+                                                     self.hidden_size_per_attention_head))
+            value_layer = value_layer.unsqueeze(-3)
+            value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
+            # [bs, np, sl, hn] --> [bs, nh, sl, hn]
+            value_layer = value_layer.contiguous().view((batch_size,
+                                                         self.num_attention_heads_per_partition,
+                                                         save_length,
+                                                         self.hidden_size_per_attention_head))
+
     context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
 
     # =================
@@ -307,7 +364,7 @@ def chatglm2_attention_forward_8eb45c(
 
     output = self.dense(context_layer)
 
-    return output, (key_layer.permute(2, 0, 1, 3), value_layer.permute(2, 0, 1, 3))
+    return output, (cache_key_layer.permute(2, 0, 1, 3), cache_value_layer.permute(2, 0, 1, 3))
 
 
 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
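
Note (not part of the patch): the standalone sketch below illustrates the multi-query-attention expansion this diff performs before core attention when batch_size > 1 on XPU. The grouped key/value tensors of shape [bs, np, sl, hn] stay compact in the KV cache and are only broadcast to [bs, nh, sl, hn] for the attention computation. All names and sizes in the sketch are illustrative assumptions, not values taken from the repository.

import torch

batch_size = 2                   # bs
num_attention_heads = 32         # nh (num_attention_heads_per_partition)
num_multi_query_groups = 2       # np (num_multi_query_groups_per_partition)
head_dim = 128                   # hn (hidden_size_per_attention_head)
seq_len = 16                     # sl

# Compact grouped key cache as stored when batch_size > 1 on xpu: [bs, np, sl, hn]
key_layer = torch.randn(batch_size, num_multi_query_groups, seq_len, head_dim)

query_group_size = num_attention_heads // num_multi_query_groups  # query heads per KV group
key_layer = key_layer.unsqueeze(-3)                                # [bs, np, 1, sl, hn]
key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)     # [bs, np, nh/np, sl, hn]
key_layer = key_layer.contiguous().view(batch_size,                # [bs, nh, sl, hn]
                                        num_attention_heads,
                                        seq_len,
                                        head_dim)

assert key_layer.shape == (batch_size, num_attention_heads, seq_len, head_dim)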