update falcon attention forward. (#9796)
parent a5e5c3daec
commit d299f108d0
2 changed files with 16 additions and 15 deletions
@@ -530,7 +530,7 @@ def _optimize_post(model, lightweight_bmm=False):
     modeling_module_name = model.__class__.__module__
     module = importlib.import_module(modeling_module_name)
     if "RWForCausalLM" in model.config.architectures:
-        if hasattr(model.config, "multi_query"):
+        if model.config.hidden_size == 4544:
             # falcon-7b need to check performance drop after kv cache support.
             # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
             # convert_forward(model,
@@ -546,7 +546,8 @@ def _optimize_post(model, lightweight_bmm=False):
                             rw_attention_forward_40b
                             )
     elif "FalconForCausalLM" in model.config.architectures:
-        # falcon-180b
-        from bigdl.llm.transformers.models.falcon import falcon_attention_forward
-        convert_forward(model,
-                        module.FalconAttention,
+        if model.config.hidden_size != 4544:
+            # falcon-180b and new falcon-40b
+            from bigdl.llm.transformers.models.falcon import falcon_attention_forward
+            convert_forward(model,
+                            module.FalconAttention,
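The hunk above routes the newer Falcon checkpoints through falcon_attention_forward via convert_forward. As a minimal sketch of what such a helper is assumed to do (the actual BigDL implementation may differ), it walks the model and rebinds the forward of every matching attention module:

import types

def convert_forward(model, target_class, new_forward):
    # Assumes model is a torch.nn.Module; rebind `forward` on every instance
    # of target_class (e.g. FalconAttention) so the optimized function is used
    # in place of the stock implementation.
    for submodule in model.modules():
        if isinstance(submodule, target_class):
            submodule.forward = types.MethodType(new_forward, submodule)

Binding on the instance rather than the class keeps other loaded models untouched; that is a design choice of this sketch, not a claim about the library.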
@@ -443,18 +443,18 @@ def falcon_attention_forward(
     if layer_past is not None:
         kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
+    key_layer = key_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
+    value_layer = value_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
     device = hidden_states.device
     if layer_past is not None:
         # reuse k, v, self_attention
-        cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
-        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
+        cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
+        cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
-                self.num_heads,
+                num_kv_heads,
                 self.head_dim,
                 cache_k.size(2),
                 kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
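Note on the num_kv_heads change in the hunk above, with an illustrative snippet (shapes are assumptions, not taken from this file): the newer Falcon variants use multi-query / grouped-query attention, so key and value tensors carry num_kv_heads heads rather than self.num_heads, and a view sized with the query head count would not match the tensor's element count.

import torch

batch_size, query_length, head_dim = 2, 16, 64
num_heads, num_kv_heads = 32, 8   # grouped-query attention: fewer KV heads

key_layer = torch.randn(batch_size * num_kv_heads, query_length, head_dim)

key_layer.view(batch_size, num_kv_heads, query_length, head_dim)    # matches element count
try:
    key_layer.view(batch_size, num_heads, query_length, head_dim)   # too many heads
except RuntimeError as err:
    print(err)  # shape '[2, 32, 16, 64]' is invalid for input of size 16384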
@@ -472,7 +472,7 @@ def falcon_attention_forward(
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(
             batch_size,
-            self.num_heads,
+            num_kv_heads,
             self.head_dim,
             kv_length,
             max_cache_length,
@@ -485,8 +485,8 @@ def falcon_attention_forward(
         value_layer = new_value_states

     query_layer = query_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    key_layer = key_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    value_layer = value_layer.view(batch_size * self.num_heads, -1, self.head_dim)
+    key_layer = key_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
+    value_layer = value_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
     _, kv_length, _ = key_layer.shape
     if use_cache:
         present = (key_layer, value_layer)
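After the flattening in the hunk above, query_layer has batch_size * self.num_heads rows while key_layer and value_layer have only batch_size * num_kv_heads rows. A common grouped-query-attention pattern (shown here as a general sketch, not as what this file necessarily does downstream) is to broadcast each KV head across its group of query heads before the batched matmul:

import torch

batch_size, q_len, kv_len, head_dim = 2, 1, 16, 64
num_heads, num_kv_heads = 32, 8
group = num_heads // num_kv_heads  # query heads served by each KV head

query = torch.randn(batch_size * num_heads, q_len, head_dim)
key = torch.randn(batch_size * num_kv_heads, kv_len, head_dim)

# Expand each KV head to cover its group of query heads, then flatten back.
key_expanded = (key.view(batch_size, num_kv_heads, 1, kv_len, head_dim)
                   .expand(batch_size, num_kv_heads, group, kv_len, head_dim)
                   .reshape(batch_size * num_heads, kv_len, head_dim))

scores = torch.bmm(query, key_expanded.transpose(1, 2))
print(scores.shape)  # (batch_size * num_heads, q_len, kv_len)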