update falcon attention forward. (#9796)

Cengguang Zhang 2023-12-28 09:11:59 +08:00 committed by GitHub
parent a5e5c3daec
commit d299f108d0
2 changed files with 16 additions and 15 deletions


@@ -530,7 +530,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         if "RWForCausalLM" in model.config.architectures:
-            if hasattr(model.config, "multi_query"):
+            if model.config.hidden_size == 4544:
                 # falcon-7b need to check performance drop after kv cache support.
                 # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
                 # convert_forward(model,
@@ -546,12 +546,13 @@ def _optimize_post(model, lightweight_bmm=False):
                                 rw_attention_forward_40b
                                 )
         elif "FalconForCausalLM" in model.config.architectures:
-            # falcon-180b
-            from bigdl.llm.transformers.models.falcon import falcon_attention_forward
-            convert_forward(model,
-                            module.FalconAttention,
-                            falcon_attention_forward
-                            )
+            if model.config.hidden_size != 4544:
+                # falcon-180b and new falcon-40b
+                from bigdl.llm.transformers.models.falcon import falcon_attention_forward
+                convert_forward(model,
+                                module.FalconAttention,
+                                falcon_attention_forward
+                                )
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
         if model.config.hidden_size == 4096:
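
Note on the _optimize_post change above: Falcon variant dispatch now keys off config.hidden_size rather than the presence of multi_query, so falcon-7b (hidden_size 4544) is left unpatched while the optimized forward is only installed for the larger RW/Falcon checkpoints. Below is a minimal illustrative sketch of that dispatch, assuming a HuggingFace-style config object; the helper name pick_falcon_forward and the returned strings are hypothetical, not part of the patch.

# Illustrative sketch only: how the hidden_size-based dispatch above behaves
# for different Falcon checkpoints. Names here are made up for illustration.
from types import SimpleNamespace

def pick_falcon_forward(config):
    architectures = config.architectures or []
    if "RWForCausalLM" in architectures:
        if config.hidden_size == 4544:
            # falcon-7b: left on the stock forward for now (see commented-out block above)
            return None
        # older falcon-40b style checkpoints
        return "rw_attention_forward_40b"
    if "FalconForCausalLM" in architectures:
        if config.hidden_size != 4544:
            # falcon-180b and new falcon-40b
            return "falcon_attention_forward"
    return None

cfg = SimpleNamespace(architectures=["FalconForCausalLM"], hidden_size=8192)
print(pick_falcon_forward(cfg))  # any hidden_size != 4544 selects the optimized forward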


@@ -443,18 +443,18 @@ def falcon_attention_forward(
     if layer_past is not None:
         kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
+    key_layer = key_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
+    value_layer = value_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
     device = hidden_states.device
     if layer_past is not None:
         # reuse k, v, self_attention
-        cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
-        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
+        cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
+        cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
-                self.num_heads,
+                num_kv_heads,
                 self.head_dim,
                 cache_k.size(2),
                 kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
@@ -472,7 +472,7 @@ def falcon_attention_forward(
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(
             batch_size,
-            self.num_heads,
+            num_kv_heads,
             self.head_dim,
             kv_length,
             max_cache_length,
@@ -485,8 +485,8 @@ def falcon_attention_forward(
         value_layer = new_value_states

     query_layer = query_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    key_layer = key_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    value_layer = value_layer.view(batch_size * self.num_heads, -1, self.head_dim)
+    key_layer = key_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
+    value_layer = value_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
     _, kv_length, _ = key_layer.shape
     if use_cache:
         present = (key_layer, value_layer)
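
Note on the falcon_attention_forward change above: key/value tensors and the KV cache are now shaped with num_kv_heads instead of self.num_heads, which matters for Falcon checkpoints that use multi-query or grouped-query attention, where the key/value projection produces fewer heads than the query projection. The following is a small self-contained sketch of that shape bookkeeping, with invented head counts, using repeat_interleave to share KV heads across query heads; the patched forward itself may broadcast differently.

# Self-contained sketch (not BigDL code): why the views above use num_kv_heads.
# Shaping the cache with self.num_heads would be wrong whenever the K/V
# projection yields fewer heads than the query projection.
import torch

batch_size, query_length, head_dim = 2, 5, 64
num_heads, num_kv_heads = 8, 2  # 8 query heads share 2 key/value heads

query_layer = torch.randn(batch_size, num_heads, query_length, head_dim)
key_layer = torch.randn(batch_size, num_kv_heads, query_length, head_dim)
value_layer = torch.randn(batch_size, num_kv_heads, query_length, head_dim)

# The cache is allocated per KV head, mirroring extend_kv_cache/init_kv_cache above.
cache_k = torch.empty(batch_size, num_kv_heads, 0, head_dim)
cache_k = torch.cat([cache_k, key_layer], dim=2)

# At score time each KV head serves num_heads // num_kv_heads query heads.
expanded_k = cache_k.repeat_interleave(num_heads // num_kv_heads, dim=1)
scores = query_layer @ expanded_k.transpose(-1, -2) / head_dim ** 0.5
probs = scores.softmax(dim=-1)
expanded_v = value_layer.repeat_interleave(num_heads // num_kv_heads, dim=1)
out = probs @ expanded_v
print(scores.shape)  # torch.Size([2, 8, 5, 5])
print(out.shape)     # torch.Size([2, 8, 5, 64])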