LLM: fix kv cache issue of bloom and falcon. (#9029)

Cengguang Zhang 2023-09-21 18:12:20 +08:00 committed by GitHub
parent bf51ec40b2
commit 868511cf02
3 changed files with 17 additions and 10 deletions


@@ -181,7 +181,7 @@ def optimize(model):
         convert_forward(model,
                         module.GPTJAttention,
                         gptj_attention_forward)
-    elif "bloom" in model.config._name_or_path:
+    elif "bloom" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.bloom import bloom_attention_forward
@@ -189,17 +189,18 @@ def optimize(model):
                         module.BloomAttention,
                         bloom_attention_forward
                         )
-    elif "falcon" in model.config._name_or_path:
+    elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         if "RWForCausalLM" in model.config.architectures:
             if hasattr(model.config, "multi_query"):
-                # falcon-7b
-                from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
-                convert_forward(model,
-                                module.Attention,
-                                rw_attention_forward_7b
-                                )
+                # falcon-7b need to check performance drop after kv cache support.
+                # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
+                # convert_forward(model,
+                #                 module.Attention,
+                #                 rw_attention_forward_7b
+                #                 )
+                pass
             else:
                 # falcon-40b
                 from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
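
Note on the optimize() hunks above: matching on config.model_type (plus "RefinedWeb" for older Falcon checkpoints whose configs still report a RefinedWeb model_type) instead of config._name_or_path makes detection independent of where the checkpoint lives on disk or what its repo is called. A minimal sketch of the difference, using a hypothetical local path that is not part of the commit:

    # Sketch only: _name_or_path records how the model was loaded, while
    # model_type is fixed by the config class, so substring checks on
    # model_type cannot be broken by renaming or moving a checkpoint.
    from transformers import BloomConfig

    config = BloomConfig()
    config._name_or_path = "/data/ckpt/my-finetuned-llm"  # hypothetical local checkpoint path
    print("bloom" in config._name_or_path)  # False -- the old check would skip the optimization
    print("bloom" in config.model_type)     # True  -- "bloom", wherever it was loaded from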


@@ -96,6 +96,8 @@ def bloom_attention_forward(
         self.head_dim
     )
     _, _, kv_length = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-1]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
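
Why the bloom hunk is needed: during incremental decoding, key_layer holds only the newly generated token(s), while the previously computed keys live in layer_past, so kv_length (which drives the attention-mask/alibi handling downstream) must count both. A minimal sketch of that length bookkeeping, with shapes assumed from the upstream BLOOM cache layout ([batch * num_heads, head_dim, past_len] for cached keys) rather than taken from the commit:

    import torch

    batch_size, num_heads, head_dim, past_len = 1, 2, 4, 5
    key_layer = torch.randn(batch_size * num_heads, head_dim, 1)            # current decode step only
    layer_past = (torch.randn(batch_size * num_heads, head_dim, past_len),  # cached keys
                  torch.randn(batch_size * num_heads, past_len, head_dim))  # cached values

    _, _, kv_length = key_layer.shape         # 1 without the fix
    if layer_past is not None:
        kv_length += layer_past[0].shape[-1]  # 1 + 5 = 6 keys the new query attends to
    print(kv_length)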


@@ -86,7 +86,8 @@ def rw_attention_forward_7b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
@@ -266,6 +267,8 @@ def rw_attention_forward_40b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
@@ -439,7 +442,8 @@ def falcon_attention_forward(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
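
The three falcon hunks apply the same fix, but read the cached length from shape[-2]: these forwards keep the sequence axis second-to-last in the cached key tensor, unlike the transposed key layout in the bloom hunk. A minimal sketch of that bookkeeping, with shapes assumed for illustration rather than taken from the commit:

    import torch

    batch_size, num_kv, head_dim, past_len = 1, 1, 4, 7
    key_layer = torch.randn(batch_size * num_kv, 1, head_dim)            # current decode step only
    layer_past = (torch.randn(batch_size * num_kv, past_len, head_dim),  # cached keys
                  torch.randn(batch_size * num_kv, past_len, head_dim))  # cached values

    _, kv_length, _ = key_layer.shape         # 1 for the new token
    if layer_past is not None:
        kv_length += layer_past[0].shape[-2]  # 1 + 7 = 8 total keys visible to attention
    print(kv_length)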