update falcon attention forward. (#9796)
parent a5e5c3daec
commit d299f108d0

2 changed files with 16 additions and 15 deletions
@@ -530,7 +530,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         if "RWForCausalLM" in model.config.architectures:
-            if hasattr(model.config, "multi_query"):
+            if model.config.hidden_size == 4544:
                 # falcon-7b need to check performance drop after kv cache support.
                 # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
                 # convert_forward(model,
@@ -546,12 +546,13 @@ def _optimize_post(model, lightweight_bmm=False):
                                 rw_attention_forward_40b
                                 )
         elif "FalconForCausalLM" in model.config.architectures:
-            # falcon-180b
-            from bigdl.llm.transformers.models.falcon import falcon_attention_forward
-            convert_forward(model,
-                            module.FalconAttention,
-                            falcon_attention_forward
-                            )
+            if model.config.hidden_size != 4544:
+                # falcon-180b and new falcon-40b
+                from bigdl.llm.transformers.models.falcon import falcon_attention_forward
+                convert_forward(model,
+                                module.FalconAttention,
+                                falcon_attention_forward
+                                )
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
         if model.config.hidden_size == 4096:
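Note: convert_forward itself is not part of this diff; it is only visible through its call sites above. As background, a forward-replacement helper of this kind is typically a small monkey patch that rebinds the new function as the forward method of every matching submodule. The sketch below is a hypothetical illustration of that pattern (patch_forward and DummyAttention are made-up names, not BigDL's actual implementation):

import types
import torch
import torch.nn as nn


def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    """Rebind new_forward as the forward method of every target_cls
    submodule in model (hypothetical stand-in for convert_forward)."""
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)


# Toy attention module and replacement forward, just to exercise the patch.
class DummyAttention(nn.Module):
    def forward(self, x):
        return x


def optimized_forward(self, x):
    # An optimized implementation would go here.
    return x * 2


model = nn.Sequential(DummyAttention(), nn.Linear(4, 4))
patch_forward(model, DummyAttention, optimized_forward)
print(model[0](torch.ones(1, 4)))  # tensor of 2s: the patched forward ran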
@@ -443,18 +443,18 @@ def falcon_attention_forward(
     if layer_past is not None:
         kv_length += layer_past[0].shape[-2]
     query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
-    value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
+    key_layer = key_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
+    value_layer = value_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
     device = hidden_states.device
     if layer_past is not None:
         # reuse k, v, self_attention
-        cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
-        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
+        cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
+        cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
-                self.num_heads,
+                num_kv_heads,
                 self.head_dim,
                 cache_k.size(2),
                 kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
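Note: the stride check and extend_kv_cache call above manage a key/value cache that is over-allocated by KV_CACHE_ALLOC_BLOCK_LENGTH extra slots, so most decode steps append in place and reallocation only happens when the buffer runs out. The init_kv_cache / extend_kv_cache signatures are only visible here through their call sites, so the sketch below uses a simplified stand-alone cache (all names hypothetical) to show the allocate-then-append pattern rather than BigDL's real API:

import torch

KV_ALLOC_BLOCK = 256  # hypothetical counterpart of KV_CACHE_ALLOC_BLOCK_LENGTH


class SimpleKVCache:
    """Pre-allocate [batch, kv_heads, capacity, head_dim] and grow in blocks,
    so most appends are a cheap in-place copy (hypothetical sketch)."""

    def __init__(self, batch, kv_heads, head_dim, dtype=torch.float16):
        self.len = 0
        self.buf_k = torch.empty(batch, kv_heads, KV_ALLOC_BLOCK, head_dim, dtype=dtype)
        self.buf_v = torch.empty_like(self.buf_k)

    def append(self, k, v):
        # k, v: [batch, kv_heads, new_tokens, head_dim]
        needed = self.len + k.size(2)
        if needed > self.buf_k.size(2):
            # Out of room: allocate a larger buffer and copy the old cache over.
            new_cap = needed + KV_ALLOC_BLOCK
            new_k = torch.empty(self.buf_k.size(0), self.buf_k.size(1), new_cap,
                                self.buf_k.size(3), dtype=self.buf_k.dtype)
            new_v = torch.empty_like(new_k)
            new_k[:, :, :self.len] = self.buf_k[:, :, :self.len]
            new_v[:, :, :self.len] = self.buf_v[:, :, :self.len]
            self.buf_k, self.buf_v = new_k, new_v
        self.buf_k[:, :, self.len:needed] = k
        self.buf_v[:, :, self.len:needed] = v
        self.len = needed
        return self.buf_k[:, :, :self.len], self.buf_v[:, :, :self.len]


cache = SimpleKVCache(batch=1, kv_heads=8, head_dim=64)
k = v = torch.zeros(1, 8, 1, 64, dtype=torch.float16)
for _ in range(3):
    key_states, value_states = cache.append(k, v)
print(key_states.shape)  # torch.Size([1, 8, 3, 64])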
@@ -472,7 +472,7 @@ def falcon_attention_forward(
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
         new_key_states, new_value_states = init_kv_cache(
             batch_size,
-            self.num_heads,
+            num_kv_heads,
             self.head_dim,
             kv_length,
             max_cache_length,
@@ -485,8 +485,8 @@ def falcon_attention_forward(
         value_layer = new_value_states
 
     query_layer = query_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    key_layer = key_layer.view(batch_size * self.num_heads, -1, self.head_dim)
-    value_layer = value_layer.view(batch_size * self.num_heads, -1, self.head_dim)
+    key_layer = key_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
+    value_layer = value_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
     _, kv_length, _ = key_layer.shape
     if use_cache:
         present = (key_layer, value_layer)
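Note: the substantive change in falcon_attention_forward is that key/value tensors and the cache are now shaped with num_kv_heads instead of self.num_heads. Newer Falcon checkpoints use multi-query / grouped-query attention, where the number of key/value heads is smaller than the number of query heads and each group of query heads shares one K/V head. A minimal shape illustration (toy sizes, not Falcon's real config):

import torch

batch, q_heads, kv_heads, seq, head_dim = 1, 8, 2, 5, 16  # toy sizes

q = torch.randn(batch, q_heads, seq, head_dim)
k = torch.randn(batch, kv_heads, seq, head_dim)  # fewer KV heads (GQA)
v = torch.randn(batch, kv_heads, seq, head_dim)

# Each group of q_heads // kv_heads query heads shares one K/V head.
k = k.repeat_interleave(q_heads // kv_heads, dim=1)
v = v.repeat_interleave(q_heads // kv_heads, dim=1)

attn = torch.softmax(q @ k.transpose(-1, -2) / head_dim ** 0.5, dim=-1)
out = attn @ v
print(out.shape)  # torch.Size([1, 8, 5, 16]) -- same as full multi-head output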