LLM: fix kv cache issue of bloom and falcon. (#9029)
parent bf51ec40b2
commit 868511cf02
3 changed files with 17 additions and 10 deletions
@@ -181,7 +181,7 @@ def optimize(model):
         convert_forward(model,
                         module.GPTJAttention,
                         gptj_attention_forward)
-    elif "bloom" in model.config._name_or_path:
+    elif "bloom" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.bloom import bloom_attention_forward
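The switch above is the detection half of the fix: matching on `config.model_type` instead of `config._name_or_path` means a BLOOM checkpoint saved under an arbitrary local directory (whose path contains no "bloom") is still dispatched to the optimized attention forward. A minimal sketch of that idea, independent of the bigdl-llm code and using only the standard `transformers.AutoConfig` API (the local path below is hypothetical):

```python
from transformers import AutoConfig

# Hypothetical local checkpoint: the directory name says nothing about the architecture.
config = AutoConfig.from_pretrained("/models/my-finetuned-checkpoint", trust_remote_code=True)

# _name_or_path just echoes whatever string the model was loaded from,
# so substring checks on it break for renamed or locally saved checkpoints.
print(config._name_or_path)   # "/models/my-finetuned-checkpoint"

# model_type is written into config.json by the architecture itself
# ("bloom", "falcon", "gptj", ...), so it survives renames and local copies.
if "bloom" in config.model_type:
    print("dispatch to the BLOOM-specific attention forward")
```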
@@ -189,17 +189,18 @@ def optimize(model):
                         module.BloomAttention,
                         bloom_attention_forward
                         )
-    elif "falcon" in model.config._name_or_path:
+    elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         if "RWForCausalLM" in model.config.architectures:
             if hasattr(model.config, "multi_query"):
-                # falcon-7b
-                from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
-                convert_forward(model,
-                                module.Attention,
-                                rw_attention_forward_7b
-                                )
+                # falcon-7b need to check performance drop after kv cache support.
+                # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
+                # convert_forward(model,
+                #                 module.Attention,
+                #                 rw_attention_forward_7b
+                #                 )
+                pass
             else:
                 # falcon-40b
                 from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
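For Falcon the same switch also has to accept the legacy naming: older tiiuae checkpoints shipped `model_type` values such as "RefinedWeb" / "RefinedWebModel" rather than "falcon", hence the extra `or "RefinedWeb" in ...` clause. The falcon-7b (multi-query) branch is left as `pass` until the performance of its kv-cache forward is verified, as the new comment notes. The `convert_forward` helper itself is not part of this hunk; as a rough sketch of what such a helper typically does (an assumption about its behaviour, not the actual bigdl-llm implementation), it walks the module tree and rebinds the forward of every instance of the target class:

```python
import types

def convert_forward(model, target_class, new_forward):
    # Sketch only: replace `forward` on every submodule that is an instance of
    # `target_class`, binding `new_forward` as a method of that submodule.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)
```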
@@ -96,6 +96,8 @@ def bloom_attention_forward(
         self.head_dim
     )
     _, _, kv_length = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-1]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
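The two added lines are the whole BLOOM fix: previously `kv_length` only counted the tokens of the current step and ignored everything already held in `layer_past`, so shapes that depend on the total sequence length came out too short during incremental decoding. Following the upstream Hugging Face BLOOM layout, cached keys are stored transposed as `[batch * num_heads, head_dim, past_len]` and cached values as `[batch * num_heads, past_len, head_dim]`, which is why the past length is read from `shape[-1]` of the cached key. A small, self-contained sketch of that bookkeeping with dummy tensors (shapes assumed from that upstream layout, not taken from this repo):

```python
import torch

batch_size, num_heads, head_dim = 1, 16, 64
q_length, past_len = 1, 7  # decoding one new token with 7 cached tokens

# Current-step projections, already split into heads.
key_layer = torch.randn(batch_size * num_heads, head_dim, q_length)
value_layer = torch.randn(batch_size * num_heads, q_length, head_dim)

# BLOOM-style cache: the key is stored transposed, the value is not.
layer_past = (
    torch.randn(batch_size * num_heads, head_dim, past_len),  # past keys
    torch.randn(batch_size * num_heads, past_len, head_dim),  # past values
)

_, _, kv_length = key_layer.shape          # = q_length
if layer_past is not None:
    kv_length += layer_past[0].shape[-1]   # add the cached length -> 8

# Extending the cache along the matching dimensions keeps everything consistent.
key_layer = torch.cat([layer_past[0], key_layer], dim=2)
value_layer = torch.cat([layer_past[1], value_layer], dim=1)
assert key_layer.shape[-1] == value_layer.shape[1] == kv_length
```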
@@ -86,7 +86,8 @@ def rw_attention_forward_7b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
@@ -266,6 +267,8 @@ def rw_attention_forward_40b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
@@ -439,7 +442,8 @@ def falcon_attention_forward(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
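All three Falcon forwards (`rw_attention_forward_7b`, `rw_attention_forward_40b`, `falcon_attention_forward`) get the same two-line fix, but with `shape[-2]` instead of `shape[-1]`: Falcon/RefinedWeb caches both keys and values as `[batch * heads, seq_len, head_dim]`, so the cached length sits in the second-to-last dimension. A small sketch of the same bookkeeping under that assumed layout:

```python
import torch

batch_size, num_heads, head_dim = 1, 8, 64
q_length, past_len = 1, 7  # decoding one new token with 7 cached tokens

# Falcon-style cache: keys and values share the [batch * heads, seq, head_dim] layout.
key_layer = torch.randn(batch_size * num_heads, q_length, head_dim)
value_layer = torch.randn(batch_size * num_heads, q_length, head_dim)
layer_past = (
    torch.randn(batch_size * num_heads, past_len, head_dim),  # past keys
    torch.randn(batch_size * num_heads, past_len, head_dim),  # past values
)

_, kv_length, _ = key_layer.shape
if layer_past is not None:
    kv_length += layer_past[0].shape[-2]   # sequence dim is second-to-last -> 8

key_layer = torch.cat([layer_past[0], key_layer], dim=1)
value_layer = torch.cat([layer_past[1], value_layer], dim=1)
assert key_layer.shape[1] == value_layer.shape[1] == kv_length
```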