LLM: fix kv cache issue of bloom and falcon. (#9029)
parent bf51ec40b2
commit 868511cf02

3 changed files with 17 additions and 10 deletions
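The fix has two parts: model detection in optimize() now keys off config.model_type instead of config._name_or_path, and the BLOOM/Falcon attention forwards now count cached positions from layer_past when computing kv_length. A minimal sketch of that second part (illustrative shapes only, not the library code; it assumes BLOOM caches keys as (batch*heads, head_dim, seq_len) and Falcon as (batch*heads, seq_len, head_dim), which is why the hunks below index shape[-1] and shape[-2] respectively):

import torch

def total_kv_length(key_layer, layer_past, seq_dim):
    # length contributed by the tokens being processed in this call
    kv_length = key_layer.shape[seq_dim]
    # plus whatever is already in the per-layer cache from earlier steps
    if layer_past is not None:
        kv_length += layer_past[0].shape[seq_dim]
    return kv_length

# BLOOM-style cache: keys as (batch*heads, head_dim, seq_len) -> sequence dim is -1
assert total_kv_length(torch.empty(16, 64, 1), (torch.empty(16, 64, 5),), seq_dim=-1) == 6
# Falcon-style cache: keys as (batch*heads, seq_len, head_dim) -> sequence dim is -2
assert total_kv_length(torch.empty(16, 1, 64), (torch.empty(16, 5, 64),), seq_dim=-2) == 6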
				
			
@@ -181,7 +181,7 @@ def optimize(model):
         convert_forward(model,
                         module.GPTJAttention,
                         gptj_attention_forward)
-    elif "bloom" in model.config._name_or_path:
+    elif "bloom" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.bloom import bloom_attention_forward
@@ -189,17 +189,18 @@ def optimize(model):
                         module.BloomAttention,
                         bloom_attention_forward
                         )
-    elif "falcon" in model.config._name_or_path:
+    elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         if "RWForCausalLM" in model.config.architectures:
             if hasattr(model.config, "multi_query"):
-                # falcon-7b
-                from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
-                convert_forward(model,
-                                module.Attention,
-                                rw_attention_forward_7b
-                                )
+                # falcon-7b need to check performance drop after kv cache support.
+                # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
+                # convert_forward(model,
+                #                 module.Attention,
+                #                 rw_attention_forward_7b
+                #                 )
+                pass
             else:
                 # falcon-40b
                 from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
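Two things are worth noting about the optimize() changes above. First, config._name_or_path is just the path or repo id the checkpoint was loaded from, so a local directory such as ./my-bloom-copy would never contain the substring "bloom"; config.model_type is a fixed field of the architecture's config, and the original tiiuae remote-code Falcon checkpoints report a model_type containing "RefinedWeb", which the new substring check also catches. Second, convert_forward rebinds the forward of every attention module of the given class to the optimized implementation; a rough sketch of that pattern (hypothetical helper, the real bigdl.llm convert_forward may differ in details):

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    # Walk all submodules and rebind `forward` on every instance of the
    # target attention class, so later calls hit the optimized code path.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

Importing the class from model.__class__.__module__ (as the real code does via importlib) keeps this working for trust_remote_code models, where the attention class lives in the checkpoint's own modelling file rather than in transformers.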
				
			
@@ -96,6 +96,8 @@ def bloom_attention_forward(
         self.head_dim
     )
     _, _, kv_length = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-1]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
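The reason kv_length must include the past is that during generation each forward only projects the newly fed token(s); the remaining keys and values come from layer_past, and the attention buffers and scores are sized by the combined length. A generic illustration of that cache-append step (Falcon-style (batch*heads, seq_len, head_dim) layout, sequence in dim -2; illustrative only, not the bigdl.llm cache code):

import torch

def append_kv(layer_past, key_layer, value_layer):
    # layer_past is (past_key, past_value) or None; the sequence axis is dim -2 here
    if layer_past is not None:
        past_key, past_value = layer_past
        key_layer = torch.cat([past_key, key_layer], dim=-2)
        value_layer = torch.cat([past_value, value_layer], dim=-2)
    return key_layer, value_layer

# step 1: a 5-token prompt, no cache yet
k, v = append_kv(None, torch.randn(16, 5, 64), torch.randn(16, 5, 64))
# step 2: one new token, the cache already holds the previous 5
k, v = append_kv((k, v), torch.randn(16, 1, 64), torch.randn(16, 1, 64))
assert k.shape[-2] == 6   # the kv_length the attention math must use from now on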
				
			
@@ -86,7 +86,8 @@ def rw_attention_forward_7b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
 
     _, kv_length, _ = key_layer.shape
-
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
				
			
@@ -266,6 +267,8 @@ def rw_attention_forward_40b(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
 
     _, kv_length, _ = key_layer.shape
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
				
			
@@ -439,7 +442,8 @@ def falcon_attention_forward(
     query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
 
     _, kv_length, _ = key_layer.shape
-
+    if layer_past is not None:
+        kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
     value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
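To exercise the fixed path end to end, a generation run with the KV cache enabled (the default for generate) hits these forwards once per produced token after the prompt. A sketch against the bigdl-llm transformers-style API (the model id, prompt, and flags here are placeholders and assumptions, not taken from this patch):

from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM  # assumed bigdl-llm import path

model_path = "bigscience/bloom-560m"  # placeholder; any bloom/falcon checkpoint covered by the patch
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)

inputs = tokenizer("What is AI?", return_tensors="pt")
# use_cache=True means every step after the first passes layer_past into the
# patched attention forwards, which is exactly the branch this commit fixes.
output = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))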