update falcon attention forward. (#9796)
parent a5e5c3daec
commit d299f108d0

2 changed files with 16 additions and 15 deletions
@@ -530,7 +530,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         if "RWForCausalLM" in model.config.architectures:
             if hasattr(model.config, "multi_query"):
                 if model.config.hidden_size == 4544:
                     # falcon-7b need to check performance drop after kv cache support.
                     # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
                     # convert_forward(model,
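For readers skimming the diff: convert_forward, used throughout this commit, swaps the forward method of a given Hugging Face module class for BigDL's optimized implementation. The helper below is only a minimal monkey-patching sketch of that idea, not BigDL's actual convert_forward; the name convert_forward_sketch is made up for illustration.

import torch


def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    """Illustrative stand-in for convert_forward: rebind `forward` on every
    instance of `target_cls` inside `model` so calls dispatch to `new_forward`."""
    for module in model.modules():
        if isinstance(module, target_cls):
            # Bind the plain function as a method of this particular instance.
            module.forward = new_forward.__get__(module, target_cls)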
@@ -546,12 +546,13 @@ def _optimize_post(model, lightweight_bmm=False):
                                     rw_attention_forward_40b
                                     )
         elif "FalconForCausalLM" in model.config.architectures:
-            # falcon-180b
-            from bigdl.llm.transformers.models.falcon import falcon_attention_forward
-            convert_forward(model,
-                            module.FalconAttention,
-                            falcon_attention_forward
-                            )
+            if model.config.hidden_size != 4544:
+                # falcon-180b and new falcon-40b
+                from bigdl.llm.transformers.models.falcon import falcon_attention_forward
+                convert_forward(model,
+                                module.FalconAttention,
+                                falcon_attention_forward
+                                )
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
         if model.config.hidden_size == 4096:
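The added guard uses hidden_size to tell the checkpoints apart: falcon-7b reports hidden_size 4544, while falcon-40b and falcon-180b report larger values, so only the bigger models are routed to the new falcon_attention_forward. A quick, hedged way to see which branch a given checkpoint would take (the repo id is an example; the reported architectures value depends on the checkpoint revision):

from transformers import AutoConfig

# Inspect the config fields the dispatch above keys on.
cfg = AutoConfig.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
print(cfg.architectures)  # e.g. ['RWForCausalLM'] or ['FalconForCausalLM'], depending on revision
print(cfg.hidden_size)    # 4544 for falcon-7b; larger for falcon-40b/180b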
@@ -443,18 +443,18 @@ def falcon_attention_forward(
     if layer_past is not None:
         kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
+    key_layer = key_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
+    value_layer = value_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
     device = hidden_states.device
     if layer_past is not None:
         # reuse k, v, self_attention
-        cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
-        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
+        cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
+        cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
-                self.num_heads,
+                num_kv_heads,
                 self.head_dim,
                 cache_k.size(2),
                 kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
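The remaining hunks modify falcon_attention_forward itself. The substance of this one is that key/value tensors are now shaped with num_kv_heads instead of self.num_heads: the newer Falcon checkpoints use grouped-query attention, where many query heads share a much smaller set of key/value heads. A standalone sketch of those shapes (the head counts below are illustrative, not taken from a specific Falcon config):

import torch

batch_size, q_len, head_dim = 1, 16, 64
num_heads, num_kv_heads = 32, 8  # queries keep more heads than K/V

query = torch.randn(batch_size, num_heads, q_len, head_dim)
key = torch.randn(batch_size, num_kv_heads, q_len, head_dim)
value = torch.randn(batch_size, num_kv_heads, q_len, head_dim)

# Each group of num_heads // num_kv_heads query heads shares one K/V head,
# so K/V can be expanded before a standard attention matmul.
key_expanded = key.repeat_interleave(num_heads // num_kv_heads, dim=1)
value_expanded = value.repeat_interleave(num_heads // num_kv_heads, dim=1)

attn_scores = query @ key_expanded.transpose(-1, -2) / head_dim ** 0.5
attn_output = torch.softmax(attn_scores, dim=-1) @ value_expanded
print(attn_scores.shape)  # torch.Size([1, 32, 16, 16])
print(attn_output.shape)  # torch.Size([1, 32, 16, 64])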
@@ -472,7 +472,7 @@ def falcon_attention_forward(
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(
             batch_size,
-            self.num_heads,
+            num_kv_heads,
             self.head_dim,
             kv_length,
             max_cache_length,
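init_kv_cache and extend_kv_cache are BigDL helpers; the sketch below only illustrates the underlying idea of allocating the cache with num_kv_heads rows plus KV_CACHE_ALLOC_BLOCK_LENGTH tokens of headroom, so decode steps can grow into spare capacity instead of reallocating every step. It is not the library's implementation, and the block length used here is an arbitrary placeholder.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # placeholder value for illustration


def init_kv_cache_sketch(batch_size, num_kv_heads, head_dim, kv_length,
                         max_cache_length, dtype=torch.float16, device="cpu"):
    # Allocate [batch, num_kv_heads, max_cache_length, head_dim] buffers up front...
    cache_k = torch.empty(batch_size, num_kv_heads, max_cache_length, head_dim,
                          dtype=dtype, device=device)
    cache_v = torch.empty_like(cache_k)
    # ...and hand back views over the currently valid kv_length prefix; the
    # remaining positions are spare capacity for later steps.
    return cache_k[:, :, :kv_length, :], cache_v[:, :, :kv_length, :]


key_states, value_states = init_kv_cache_sketch(
    batch_size=1, num_kv_heads=8, head_dim=64,
    kv_length=16, max_cache_length=16 + KV_CACHE_ALLOC_BLOCK_LENGTH)
print(key_states.shape)  # torch.Size([1, 8, 16, 64])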
@@ -485,8 +485,8 @@ def falcon_attention_forward(
         value_layer = new_value_states

     query_layer = query_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    key_layer = key_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    value_layer = value_layer.view(batch_size * self.num_heads, -1, self.head_dim)
+    key_layer = key_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
+    value_layer = value_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
     _, kv_length, _ = key_layer.shape
     if use_cache:
         present = (key_layer, value_layer)
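Finally, when use_cache is set, the flattened key/value tensors are returned as present and come back in as layer_past on the next step. The loop below is a simplified, concatenation-based illustration of that round trip; the real path above instead grows a pre-allocated cache in place, but the shapes and the present/layer_past handoff are the same.

import torch

batch_size, num_kv_heads, head_dim = 1, 8, 64

layer_past = None
for step in range(3):
    # One new token's worth of K/V per step, already flattened to
    # [batch * num_kv_heads, seq, head_dim] as in the hunk above.
    new_k = torch.randn(batch_size * num_kv_heads, 1, head_dim)
    new_v = torch.randn(batch_size * num_kv_heads, 1, head_dim)
    if layer_past is not None:
        new_k = torch.cat([layer_past[0], new_k], dim=1)
        new_v = torch.cat([layer_past[1], new_v], dim=1)
    present = (new_k, new_v)
    layer_past = present

print(layer_past[0].shape)  # torch.Size([8, 3, 64])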