LLM: fix kv cache issue of bloom and falcon. (#9029)
parent bf51ec40b2
commit 868511cf02

3 changed files with 17 additions and 10 deletions
```diff
@@ -181,7 +181,7 @@ def optimize(model):
         convert_forward(model,
                         module.GPTJAttention,
                         gptj_attention_forward)
-    elif "bloom" in model.config._name_or_path:
+    elif "bloom" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.bloom import bloom_attention_forward
```
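The hunk above keys the BLOOM dispatch on `config.model_type` instead of `config._name_or_path`. As a rough illustration (not part of the commit; the checkpoint id below is just an example), `model_type` is a fixed field of the architecture's config class, while `_name_or_path` merely echoes whatever path or repo id the model was loaded from, so it may not contain the family name for a locally saved checkpoint:

```python
# Illustrative only: compare the two config fields the dispatch could key on.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("bigscience/bloom-560m")  # any BLOOM checkpoint
print(cfg.model_type)     # "bloom" -- stable family identifier
print(cfg._name_or_path)  # "bigscience/bloom-560m" here; a bare directory path
                          # if the model had been saved and re-loaded from disk
```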
```diff
@@ -189,17 +189,18 @@ def optimize(model):
                         module.BloomAttention,
                         bloom_attention_forward
                         )
-    elif "falcon" in model.config._name_or_path:
+    elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        if "RWForCausalLM" in model.config.architectures:
+        if hasattr(model.config, "multi_query"):
             # falcon-7b
-            from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
-            convert_forward(model,
-                            module.Attention,
-                            rw_attention_forward_7b
-                            )
+            # falcon-7b need to check performance drop after kv cache support.
+            # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
+            # convert_forward(model,
+            #                 module.Attention,
+            #                 rw_attention_forward_7b
+            #                 )
+            pass
         else:
             # falcon-40b
             from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
```
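In this hunk, Falcon checkpoints are now recognized by `model_type` (the substring check also catches the original "RefinedWeb*" model types used by the custom-code releases), the 7B variant is told apart from the 40B variant by whether its config exposes a `multi_query` field, and the 7B forward replacement is commented out pending the performance check noted in the new comment. For context, a `convert_forward`-style helper typically just walks the module tree and rebinds `forward` on every matching attention module; the sketch below is a generic illustration with a hypothetical name, not the BigDL implementation:

```python
# Generic sketch of what a convert_forward-style helper does: rebind `forward`
# on every module of the target attention class so the replacement is used.
import types
import torch.nn as nn

def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            # Bind as a method so `self` inside new_forward is this sub-module.
            sub_module.forward = types.MethodType(new_forward, sub_module)
```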
```diff
@@ -96,6 +96,8 @@ def bloom_attention_forward(
         self.head_dim
     )
     _, _, kv_length = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-1]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
```
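This is the core of the BLOOM fix: with a populated cache, the attention width is the cached length plus the current length, so `kv_length` must include `layer_past`. A minimal shape sketch with illustrative numbers, assuming the usual BLOOM cache layout where cached keys are stored as `(batch * num_heads, head_dim, past_length)`, which is why the past length is read from `shape[-1]`:

```python
# Illustrative shapes for the kv_length accounting during incremental decoding.
import torch

batch_size, num_heads, head_dim = 1, 32, 128
q_length, past_length = 1, 16          # decode one new token with 16 cached ones

key_layer = torch.randn(batch_size * num_heads, head_dim, q_length)    # fresh keys
past_key = torch.randn(batch_size * num_heads, head_dim, past_length)  # layer_past[0]

_, _, kv_length = key_layer.shape      # 1: only the new token
kv_length += past_key.shape[-1]        # 17: new token + cached tokens
assert kv_length == q_length + past_length
```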
```diff
@@ -86,7 +86,8 @@ def rw_attention_forward_7b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
```
```diff
@@ -266,6 +267,8 @@ def rw_attention_forward_40b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
```
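The same accounting is applied to both RefinedWeb forwards above (and to `falcon_attention_forward` below), but note the axis: the past length is read from `shape[-2]` rather than `shape[-1]` as in BLOOM, because the RW-style cache keeps keys as `(batch * kv_heads, seq_length, head_dim)` with the sequence dimension second-to-last. A small sketch of that layout (illustrative numbers, assuming this cache layout):

```python
# Illustrative RW/Falcon-style cache layout: sequence length sits on dim -2.
import torch

batch_size, num_kv, head_dim = 1, 1, 64   # multi-query: a single KV head
q_length, past_length = 1, 16

key_layer = torch.randn(batch_size * num_kv, q_length, head_dim)     # fresh keys
past_key = torch.randn(batch_size * num_kv, past_length, head_dim)   # layer_past[0]

_, kv_length, _ = key_layer.shape
kv_length += past_key.shape[-2]
assert kv_length == q_length + past_length
```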
```diff
@@ -439,7 +442,8 @@ def falcon_attention_forward(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
```
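A quick way to exercise the patched path is incremental generation with the cache enabled, since every step after the first goes through the `layer_past` branch. A smoke-test sketch; the checkpoint id and generation settings are placeholders, not from this commit:

```python
# Smoke test: cached greedy decoding hits the layer_past branch on every step
# after the first. Checkpoint id and settings are illustrative only.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_id = "bigscience/bloom-560m"   # small BLOOM checkpoint, example only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)

inputs = tokenizer("The KV cache stores", return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```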