Fix chatglm2 attention and kv cache (#8924)
* fix chatglm2 attention
* fix bf16 bug
* make model stateless
* add utils
* cleanup
* fix style
parent b209b8f7b6
commit 25428b22b4

2 changed files with 82 additions and 56 deletions

@@ -20,6 +20,7 @@
 import torch
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -145,39 +146,38 @@ def chatglm2_attention_forward_8eb45c(
     # adjust key and value for inference
     if kv_cache is not None:
         cache_k, cache_v = kv_cache
-        past_length = cache_k.size(0)
+        cache_k = cache_k.permute(1, 2, 0, 3)
+        cache_v = cache_v.permute(1, 2, 0, 3)
+        past_length = cache_k.size(2)

-        if past_length + cur_length > self.max_cache_length:
-            self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            self.kv_cache = (torch.empty(batch_size,
-                                         self.num_attention_heads_per_partition,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,
-                                         device=device),
-                             torch.empty(batch_size,
-                                         self.num_attention_heads_per_partition,
-                                         self.max_cache_length,
-                                         self.hidden_size_per_attention_head,
-                                         device=device))
-            self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
-            self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
-        self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
-        self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+                                                       self.num_attention_heads_per_partition,
+                                                       self.hidden_size_per_attention_head,
+                                                       past_length,
+                                                       max_cache_length,
+                                                       dtype=query_layer.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v

-        key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
-        value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
+        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

     elif use_cache:
-        self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+
+        max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,
-                                     device=device),
-                         torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                     self.max_cache_length, self.hidden_size_per_attention_head,
-                                     device=device))
-        self.kv_cache[0][:, :, :cur_length, :] = key_layer
-        self.kv_cache[1][:, :, :cur_length, :] = value_layer
+        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                                 self.hidden_size_per_attention_head, cur_length,
+                                                 max_cache_length,
+                                                 dtype=query_layer.dtype, device=device)
+        key_cache[:] = key_layer
+        value_cache[:] = value_layer
+        key_layer = key_cache
+        value_layer = value_cache

     if use_cache:
         kv_cache = (key_layer, value_layer)
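Note on the new reallocation test above: cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3) asks whether the cache view handed back by the caller still has spare room along the sequence dimension; once the per-head stride has shrunk to seq_len * head_dim the view is packed tight and a larger buffer must be created before appending. A small standalone sketch of that arithmetic (shapes invented for illustration, not taken from the model):

import torch

# Packed cache: stride(1) == size(2) * size(3), so there is no slack
# along the sequence axis and a bigger buffer has to be allocated.
full = torch.empty(1, 2, 4, 8)

# View carved out of a larger pre-allocated buffer (12 spare positions):
# stride(1) still reflects the 16-slot storage, so appends fit in place.
roomy = torch.empty(1, 2, 16, 8)[:, :, :4, :]

for cache in (full, roomy):
    must_grow = cache.stride()[1] <= cache.size(2) * cache.size(3)
    print(cache.stride()[1], cache.size(2) * cache.size(3), must_grow)
# prints "32 32 True" for full and "128 32 False" for roomy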
@@ -204,36 +204,14 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-
-            if torch.is_autocast_cpu_enabled():
-                attention_mask = torch.ones(query_layer.shape[2],
-                                            key_layer.shape[2],
-                                            dtype=torch.bool).tril(diagonal=0)
-                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
-                attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
-                query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
-                key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
-                value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                                 key_layer,
-                                                                                 value_layer,
-                                                                                 attention_mask,
-                                                                                 is_causal=False)
-            else:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                                 key_layer,
-                                                                                 value_layer,
-                                                                                 attention_mask,
-                                                                                 is_causal=True)
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                             key_layer,
+                                                                             value_layer,
+                                                                             attention_mask,
+                                                                             is_causal=True)
         else:
             if attention_mask is not None:
-                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
-
-            if torch.is_autocast_cpu_enabled():
-                query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
-                key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
-                value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
-                attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
+                attention_mask = ~attention_mask
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
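Note on the simplified attention path above: the CPU-autocast branches and the float -inf masks are gone, leaving two scaled_dot_product_attention patterns: a causal call when no mask is supplied, and an inverted boolean mask otherwise (SDPA treats True as "may attend", while the mask chatglm2 passes in appears to mark positions to drop, hence ~attention_mask). A minimal sketch of the two call patterns with toy shapes; this is illustrative code, not part of the patch:

import torch
import torch.nn.functional as F

# Toy tensors shaped (batch, heads, seq, head_dim), as after the permute above.
q = torch.randn(1, 2, 5, 8)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

# Prefill with no padding mask: let SDPA build the causal mask itself.
out_causal = F.scaled_dot_product_attention(q, k, v, None, is_causal=True)

# The same computation spelled out with an explicit boolean mask,
# where True means "this position may be attended to".
allowed = torch.ones(5, 5, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, allowed)

torch.testing.assert_close(out_causal, out_masked, rtol=1e-4, atol=1e-4)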
python/llm/src/bigdl/llm/transformers/models/utils.py  (new file, 48 additions)

@@ -0,0 +1,48 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+
+
+def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    key_cache_storage = torch.empty(batch_size, num_heads,
+                                    max_length, head_dim,
+                                    dtype=dtype, device=device)
+    value_cache_storage = torch.empty(batch_size, num_heads,
+                                      max_length, head_dim,
+                                      dtype=dtype, device=device)
+
+    key_cache = key_cache_storage.as_strided((batch_size, num_heads,
+                                              current_length, head_dim),
+                                             key_cache_storage.stride(),
+                                             storage_offset=0)
+    value_cache = value_cache_storage.as_strided((batch_size, num_heads,
+                                                  current_length, head_dim),
+                                                 value_cache_storage.stride(),
+                                                 storage_offset=0)
+    return key_cache, value_cache
+
+
+def append_kv_cache(cache_k, cache_v, key_states, value_states):
+    new_size = (cache_k.size(0),
+                cache_k.size(1),
+                cache_k.size(2) + key_states.size(2),
+                cache_k.size(3))
+    new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
+    new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
+    new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
+    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
+    return new_cache_k, new_cache_v
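For context, a rough sketch of how these two helpers are meant to compose during decoding, assuming they are imported from bigdl.llm.transformers.models.utils as in the first hunk: create_kv_cache reserves a buffer larger than the current sequence and hands back strided views over it, and append_kv_cache widens those views by the new length and writes the incoming keys/values in place, so per-token generation does not reallocate or copy the whole cache. The sizes below are made up for illustration:

import torch

from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache

batch_size, num_heads, head_dim = 1, 32, 128
prompt_len = 16
block = 256  # mirrors KV_CACHE_ALLOC_BLOCK_LENGTH in the model file

# First pass: allocate spare capacity and fill the prompt positions.
key_cache, value_cache = create_kv_cache(batch_size, num_heads, head_dim,
                                         current_length=prompt_len,
                                         max_length=prompt_len + block,
                                         dtype=torch.float32, device='cpu')
key_cache[:] = torch.randn(batch_size, num_heads, prompt_len, head_dim)
value_cache[:] = torch.randn(batch_size, num_heads, prompt_len, head_dim)

# Each decoding step appends one token in place; the returned tensors are
# views over the same storage, now one position longer.
new_k = torch.randn(batch_size, num_heads, 1, head_dim)
new_v = torch.randn(batch_size, num_heads, 1, head_dim)
key_cache, value_cache = append_kv_cache(key_cache, value_cache, new_k, new_v)
assert key_cache.size(2) == prompt_len + 1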