diff --git a/python/llm/example/gpu/hf-transformers-models/falcon/generate.py b/python/llm/example/gpu/hf-transformers-models/falcon/generate.py
index 41113d46..0edeb47c 100644
--- a/python/llm/example/gpu/hf-transformers-models/falcon/generate.py
+++ b/python/llm/example/gpu/hf-transformers-models/falcon/generate.py
@@ -44,7 +44,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')
diff --git a/python/llm/example/gpu/hf-transformers-models/gpt-j/generate.py b/python/llm/example/gpu/hf-transformers-models/gpt-j/generate.py
index 28c385dd..fb937216 100644
--- a/python/llm/example/gpu/hf-transformers-models/gpt-j/generate.py
+++ b/python/llm/example/gpu/hf-transformers-models/gpt-j/generate.py
@@ -42,7 +42,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
index 71a4e9de..298654f2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
@@ -26,7 +26,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -70,10 +70,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -89,13 +87,13 @@
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
@@ -170,10 +168,8 @@ def baichuan_attention_forward_13b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -189,13 +185,13 @@
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 64dc2532..08d392e8 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -26,7 +26,7 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from transformers.utils import logging, ContextManagers
 logger = logging.get_logger(__name__)
@@ -82,10 +82,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -101,13 +99,13 @@
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
@@ -182,7 +180,7 @@ def baichuan_attention_forward_13b(
             if device.type == 'xpu':
                 torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -198,13 +196,13 @@
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
diff --git a/python/llm/src/bigdl/llm/transformers/models/bloom.py b/python/llm/src/bigdl/llm/transformers/models/bloom.py
index d06f784a..e44a26c8 100644
--- a/python/llm/src/bigdl/llm/transformers/models/bloom.py
+++ b/python/llm/src/bigdl/llm/transformers/models/bloom.py
@@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -107,10 +107,8 @@ def bloom_attention_forward(
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -128,7 +126,7 @@ def bloom_attention_forward(
     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm.py b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
index 89525697..4f773772 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
@@ -22,7 +22,7 @@ import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

 def rotate_half(x):
@@ -67,10 +67,8 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
                                                        self.hidden_size_per_attention_head,
                                                        past_length,
@@ -84,10 +82,10 @@
     elif use_cache:
         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
-                                                 self.hidden_size_per_attention_head, cur_length,
-                                                 max_cache_length,
-                                                 dtype=query_layer.dtype, device=device)
+        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                               self.hidden_size_per_attention_head, cur_length,
+                                               max_cache_length,
+                                               dtype=query_layer.dtype, device=device)
         key_cache[:] = key_layer
         value_cache[:] = value_layer
         key_layer = key_cache
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 5de558e9..7dc90f86 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -20,7 +20,7 @@
 import torch
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -151,10 +151,8 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
                                                        self.hidden_size_per_attention_head,
                                                        past_length,
@@ -172,10 +170,10 @@ def chatglm2_attention_forward_8eb45c(
         max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
             + KV_CACHE_ALLOC_BLOCK_LENGTH
-        key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
-                                                 self.hidden_size_per_attention_head, cur_length,
-                                                 max_cache_length,
-                                                 dtype=query_layer.dtype, device=device)
+        key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
+                                               self.hidden_size_per_attention_head, cur_length,
+                                               max_cache_length,
+                                               dtype=query_layer.dtype, device=device)
         key_cache[:] = key_layer
         value_cache[:] = value_layer
         key_layer = key_cache
diff --git a/python/llm/src/bigdl/llm/transformers/models/falcon.py b/python/llm/src/bigdl/llm/transformers/models/falcon.py
index dc66fed3..3a2c565d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/falcon.py
+++ b/python/llm/src/bigdl/llm/transformers/models/falcon.py
@@ -38,7 +38,7 @@ from typing import Optional, Tuple
 import torch
 from torch.nn import functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -98,10 +98,8 @@ def rw_attention_forward_7b(
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_kv,
                 self.head_dim,
@@ -119,7 +117,7 @@
     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_kv,
             self.head_dim,
@@ -280,7 +278,7 @@
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -298,7 +296,7 @@
     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
@@ -454,7 +452,7 @@
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(
+            new_cache_k, new_cache_v = extend_kv_cache(
                 batch_size,
                 self.num_heads,
                 self.head_dim,
@@ -472,7 +470,7 @@
     elif use_cache:
         max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(
+        new_key_states, new_value_states = init_kv_cache(
             batch_size,
             self.num_heads,
             self.head_dim,
diff --git a/python/llm/src/bigdl/llm/transformers/models/gptj.py b/python/llm/src/bigdl/llm/transformers/models/gptj.py
index 8e390fca..e904a520 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptj.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptj.py
@@ -19,8 +19,8 @@
 import torch
 from typing import Optional, Tuple, Union
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache, \
-    apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+    apply_rotary_pos_emb, append_kv_cache
 from transformers.utils.import_utils import is_torch_fx_proxy
@@ -144,9 +144,7 @@ def gptj_attention_forward(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
-            new_cache_k, new_cache_v = create_kv_cache(batch_size,
+            new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,
                                                        past_length,
@@ -160,13 +158,13 @@
         key, value = append_kv_cache(cache_k, cache_v, key, value)
     elif use_cache:
-        key_cache, value_cache = create_kv_cache(batch_size,
-                                                 self.num_attention_heads,
-                                                 self.head_dim,
-                                                 kv_seq_len,
-                                                 kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                 dtype=key.dtype,
-                                                 device=device)
+        key_cache, value_cache = init_kv_cache(batch_size,
+                                               self.num_attention_heads,
+                                               self.head_dim,
+                                               kv_seq_len,
+                                               kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                               dtype=key.dtype,
+                                               device=device)
         key_cache[:] = key
         value_cache[:] = value
         key = key_cache
diff --git a/python/llm/src/bigdl/llm/transformers/models/gptneox.py b/python/llm/src/bigdl/llm/transformers/models/gptneox.py
index 0d0c16c6..8e31a14a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptneox.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptneox.py
@@ -34,7 +34,7 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -90,10 +90,8 @@ def gptneox_attention_forward(
         past_key = layer_past[0]
         past_value = layer_past[1]
         if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_past_key, new_past_value = create_kv_cache(bsz,
+            new_past_key, new_past_value = extend_kv_cache(bsz,
                                                            self.num_attention_heads,
                                                            self.head_size,
                                                            past_key.size(2),
@@ -108,13 +106,13 @@
         key, value = append_kv_cache(past_key, past_value, key, value)
     elif use_cache:
         max_cache_length = seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key, new_value = create_kv_cache(bsz,
-                                             self.num_attention_heads,
-                                             self.head_size,
-                                             seq_len,
-                                             max_cache_length,
-                                             dtype=key.dtype,
-                                             device=device)
+        new_key, new_value = init_kv_cache(bsz,
+                                           self.num_attention_heads,
+                                           self.head_size,
+                                           seq_len,
+                                           max_cache_length,
+                                           dtype=key.dtype,
+                                           device=device)
         new_key[:] = key
         new_value[:] = value
         key = new_key
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index c8b07f63..51ddb2ee 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
@@ -112,10 +112,8 @@
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
-            new_cache_k, new_cache_v = create_kv_cache(bsz,
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA
                                                        self.head_dim,
                                                        cache_k.size(2),
@@ -131,13 +129,13 @@
     elif use_cache:
         max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = create_kv_cache(bsz,
-                                                           self.num_key_value_heads,
-                                                           self.head_dim,
-                                                           kv_seq_len,
-                                                           max_cache_length,
-                                                           dtype=key_states.dtype,
-                                                           device=device)
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_key_value_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
         new_key_states[:] = key_states
         new_value_states[:] = value_states
         key_states = new_key_states
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 8d85db74..b47ad8e7 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -18,7 +18,7 @@
 import torch
 from bigdl.llm.utils.common import invalidInputError


-def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                     dtype=dtype, device=device)
@@ -27,7 +27,7 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
                                       dtype=dtype, device=device)
     key_cache = key_cache_storage.as_strided((batch_size, num_heads,
-                                              current_length, head_dim),
+                                               current_length, head_dim),
                                               key_cache_storage.stride(),
                                               storage_offset=0)
     value_cache = value_cache_storage.as_strided((batch_size, num_heads,
@@ -37,6 +37,13 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
     return key_cache, value_cache


+def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    # empty cache to reduce gpu memory
+    if device.type == 'xpu':
+        torch.xpu.empty_cache()
+    return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)
+
+
 def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_size = (cache_k.size(0), cache_k.size(1),
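The patch splits the old create_kv_cache helper into init_kv_cache (first allocation) and extend_kv_cache (re-allocation that first calls torch.xpu.empty_cache()), so the per-model forwards no longer carry their own XPU-specific branches. Below is a minimal, self-contained sketch of how the three utilities are meant to cooperate in an attention forward pass. It is written against plain PyTorch so it runs without bigdl-llm installed: the init_kv_cache/extend_kv_cache bodies mirror the utils.py hunk above, while this append_kv_cache body, the hasattr guard, and the driver shapes are illustrative assumptions only.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # same constant each patched model file defines


def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    # Allocate max_length slots up front, then expose only the first
    # current_length of them as a strided view over the big storage.
    key_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                    dtype=dtype, device=device)
    value_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                      dtype=dtype, device=device)
    key_cache = key_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
                                             key_cache_storage.stride(), storage_offset=0)
    value_cache = value_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
                                                 value_cache_storage.stride(), storage_offset=0)
    return key_cache, value_cache


def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    # Re-allocation path: release cached device blocks first, then allocate.
    # The hasattr guard is added here only so the sketch also runs on builds without torch.xpu.
    if device.type == 'xpu' and hasattr(torch, 'xpu'):
        torch.xpu.empty_cache()
    return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)


def append_kv_cache(cache_k, cache_v, key_states, value_states):
    # Hypothetical stand-in for bigdl's append_kv_cache (its body is outside this diff):
    # widen the strided views (the backing storage already has room) and copy the
    # new tokens into the tail positions.
    new_len = cache_k.size(2) + key_states.size(2)
    new_cache_k = cache_k.as_strided((cache_k.size(0), cache_k.size(1), new_len, cache_k.size(3)),
                                     cache_k.stride(), storage_offset=0)
    new_cache_v = cache_v.as_strided((cache_v.size(0), cache_v.size(1), new_len, cache_v.size(3)),
                                     cache_v.stride(), storage_offset=0)
    new_cache_k[:, :, -key_states.size(2):, :] = key_states
    new_cache_v[:, :, -value_states.size(2):, :] = value_states
    return new_cache_k, new_cache_v


if __name__ == '__main__':
    bsz, num_heads, head_dim = 1, 4, 8
    device = torch.device('cpu')  # the patched model files run this on torch.device('xpu')
    # prefill: cache the 5 prompt tokens, leaving headroom for later decode steps
    k = torch.randn(bsz, num_heads, 5, head_dim)
    v = torch.randn(bsz, num_heads, 5, head_dim)
    cache_k, cache_v = init_kv_cache(bsz, num_heads, head_dim, 5,
                                     5 + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                     dtype=k.dtype, device=device)
    cache_k[:] = k
    cache_v[:] = v
    # decode step: same capacity check the model files use; once the view has
    # consumed its backing storage, extend_kv_cache allocates a bigger block and
    # the old contents are copied across before appending.
    new_k = torch.randn(bsz, num_heads, 1, head_dim)
    new_v = torch.randn(bsz, num_heads, 1, head_dim)
    if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
        bigger_k, bigger_v = extend_kv_cache(bsz, num_heads, head_dim, cache_k.size(2),
                                             cache_k.size(2) + 1 + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                             dtype=k.dtype, device=device)
        bigger_k[:] = cache_k
        bigger_v[:] = cache_v
        cache_k, cache_v = bigger_k, bigger_v
    cache_k, cache_v = append_kv_cache(cache_k, cache_v, new_k, new_v)
    print(cache_k.shape)  # torch.Size([1, 4, 6, 8])

Keeping torch.xpu.empty_cache() inside extend_kv_cache, rather than at every call site, means the cache is only flushed on the comparatively rare re-allocation path, which is the memory-versus-latency trade-off the renamed helpers encode.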