diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index ba9474e7..6f05cb59 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -181,7 +181,7 @@ def optimize(model):
         convert_forward(model,
                         module.GPTJAttention,
                         gptj_attention_forward)
-    elif "bloom" in model.config._name_or_path:
+    elif "bloom" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.bloom import bloom_attention_forward
@@ -189,17 +189,18 @@ def optimize(model):
                         module.BloomAttention,
                         bloom_attention_forward
                         )
-    elif "falcon" in model.config._name_or_path:
+    elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         if "RWForCausalLM" in model.config.architectures:
             if hasattr(model.config, "multi_query"):
-                # falcon-7b
-                from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
-                convert_forward(model,
-                                module.Attention,
-                                rw_attention_forward_7b
-                                )
+                # falcon-7b need to check performance drop after kv cache support.
+                # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
+                # convert_forward(model,
+                #                 module.Attention,
+                #                 rw_attention_forward_7b
+                #                 )
+                pass
             else:
                 # falcon-40b
                 from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
diff --git a/python/llm/src/bigdl/llm/transformers/models/bloom.py b/python/llm/src/bigdl/llm/transformers/models/bloom.py
index f3e08cba..d06f784a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/bloom.py
+++ b/python/llm/src/bigdl/llm/transformers/models/bloom.py
@@ -96,6 +96,8 @@ def bloom_attention_forward(
         self.head_dim
     )
     _, _, kv_length = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-1]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
diff --git a/python/llm/src/bigdl/llm/transformers/models/falcon.py b/python/llm/src/bigdl/llm/transformers/models/falcon.py
index 0b8ef9c4..dc66fed3 100644
--- a/python/llm/src/bigdl/llm/transformers/models/falcon.py
+++ b/python/llm/src/bigdl/llm/transformers/models/falcon.py
@@ -86,7 +86,8 @@ def rw_attention_forward_7b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
 
     _, kv_length, _ = key_layer.shape
-
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
@@ -266,6 +267,8 @@ def rw_attention_forward_40b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
 
     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
@@ -439,7 +442,8 @@ def falcon_attention_forward(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
 
     _, kv_length, _ = key_layer.shape
-
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
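Note on the recurring hunk: in bloom.py and falcon.py the same fix repeats, kv_length must count the tokens already held in layer_past, not only the newly computed keys. The snippet below is a minimal sketch with hypothetical shapes (not the actual BigDL forward) showing the Bloom-style cache layout, where the past length sits in shape[-1]; Falcon keeps the sequence dimension second-to-last, which is why its hunks use shape[-2].

import torch

# Hypothetical shapes for illustration only.
# Bloom-style past key: [batch * num_heads, head_dim, past_len]  -> length is shape[-1]
# Falcon-style past key: [batch * num_kv, past_len, head_dim]    -> length is shape[-2]
batch_size, num_heads, head_dim = 1, 4, 8
q_length, past_len = 1, 16          # decoding one new token with 16 cached tokens

# current-step keys and a Bloom-style (past_key, past_value) cache entry
key_layer = torch.randn(batch_size * num_heads, head_dim, q_length)
layer_past = (torch.randn(batch_size * num_heads, head_dim, past_len),
              torch.randn(batch_size * num_heads, past_len, head_dim))

_, _, kv_length = key_layer.shape            # counts only the new token(s)
if layer_past is not None:
    kv_length += layer_past[0].shape[-1]     # add the cached length (Bloom layout)

assert kv_length == q_length + past_len
# Without this correction, tensors sized from kv_length (attention mask, alibi)
# would not match the keys/values after the cache is concatenated in.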