From d299f108d0bc54fbcfb56c62eed6d8ad25cd32fa Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Thu, 28 Dec 2023 09:11:59 +0800
Subject: [PATCH] update falcon attention forward. (#9796)

---
 python/llm/src/bigdl/llm/transformers/convert.py  | 15 ++++++++-------
 .../src/bigdl/llm/transformers/models/falcon.py   | 16 ++++++++--------
 2 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 2f2b502d..44f98379 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -530,7 +530,7 @@ def _optimize_post(model, lightweight_bmm=False):
     modeling_module_name = model.__class__.__module__
     module = importlib.import_module(modeling_module_name)
     if "RWForCausalLM" in model.config.architectures:
-        if hasattr(model.config, "multi_query"):
+        if model.config.hidden_size == 4544:
             # falcon-7b need to check performance drop after kv cache support.
             # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
             # convert_forward(model,
@@ -546,12 +546,13 @@ def _optimize_post(model, lightweight_bmm=False):
                             rw_attention_forward_40b
                             )
     elif "FalconForCausalLM" in model.config.architectures:
-        # falcon-180b
-        from bigdl.llm.transformers.models.falcon import falcon_attention_forward
-        convert_forward(model,
-                        module.FalconAttention,
-                        falcon_attention_forward
-                        )
+        if model.config.hidden_size != 4544:
+            # falcon-180b and new falcon-40b
+            from bigdl.llm.transformers.models.falcon import falcon_attention_forward
+            convert_forward(model,
+                            module.FalconAttention,
+                            falcon_attention_forward
+                            )
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
         if model.config.hidden_size == 4096:
diff --git a/python/llm/src/bigdl/llm/transformers/models/falcon.py b/python/llm/src/bigdl/llm/transformers/models/falcon.py
index 3a2c565d..d5fb455c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/falcon.py
+++ b/python/llm/src/bigdl/llm/transformers/models/falcon.py
@@ -443,18 +443,18 @@ def falcon_attention_forward(
     if layer_past is not None:
         kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
+    key_layer = key_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
+    value_layer = value_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
     device = hidden_states.device
     if layer_past is not None:
         # reuse k, v, self_attention
-        cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
-        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
+        cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
+        cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
-                self.num_heads,
+                num_kv_heads,
                 self.head_dim,
                 cache_k.size(2),
                 kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
@@ -472,7 +472,7 @@
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(
             batch_size,
-            self.num_heads,
+            num_kv_heads,
             self.head_dim,
             kv_length,
             max_cache_length,
@@ -485,8 +485,8 @@ def falcon_attention_forward(
         value_layer = new_value_states
 
     query_layer = query_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    key_layer = key_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    value_layer = value_layer.view(batch_size * self.num_heads, -1, self.head_dim)
+    key_layer = key_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
+    value_layer = value_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
     _, kv_length, _ = key_layer.shape
     if use_cache:
         present = (key_layer, value_layer)
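For context on why the patch replaces self.num_heads with num_kv_heads in every key/value reshape and cache allocation: the Falcon checkpoints this path now targets (falcon-180b and the new falcon-40b, per the comment added in convert.py) use grouped-query attention, so the key/value projections carry far fewer heads than the query projection and the KV cache has to be shaped per KV head. The sketch below is a minimal, standalone PyTorch illustration of that shape relationship; the head counts are invented for the example, and BigDL's init_kv_cache/extend_kv_cache helpers are stood in for by plain tensor allocation.

import torch

# Illustrative grouped-query shapes only (made-up head counts, not a real
# Falcon config): 8 query heads share 2 key/value heads.
batch_size, num_heads, num_kv_heads, head_dim = 1, 8, 2, 64
query_length, past_length = 4, 12

# The query keeps the full head count; key/value carry only num_kv_heads heads.
query_layer = torch.randn(batch_size, num_heads, query_length, head_dim)
key_layer = torch.randn(batch_size, num_kv_heads, query_length, head_dim)
value_layer = torch.randn(batch_size, num_kv_heads, query_length, head_dim)

# So the KV cache must also be allocated per KV head; a cache shaped with
# num_heads would not concatenate with the new key/value states below.
cache_k = torch.zeros(batch_size, num_kv_heads, past_length, head_dim)
cache_v = torch.zeros(batch_size, num_kv_heads, past_length, head_dim)
key_layer = torch.cat([cache_k, key_layer], dim=2)      # (1, 2, 16, 64)
value_layer = torch.cat([cache_v, value_layer], dim=2)  # (1, 2, 16, 64)

# At attention time, each KV head serves num_heads // num_kv_heads query heads.
groups = num_heads // num_kv_heads
scores = query_layer @ key_layer.repeat_interleave(groups, dim=1).transpose(-2, -1)
attn = torch.softmax(scores / head_dim ** 0.5, dim=-1)
context = attn @ value_layer.repeat_interleave(groups, dim=1)
print(scores.shape, context.shape)  # torch.Size([1, 8, 4, 16]) torch.Size([1, 8, 4, 64])

If the cache were sized with the query-head count instead, the torch.cat above would fail with a dim-1 mismatch for any grouped-query checkpoint; shaping everything by num_kv_heads, as the patch does throughout falcon_attention_forward, keeps the cache consistent with what the model actually produces.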