LLM: Fix empty cache (#9024)

* fix

* fix

* update example
Ruonan Wang 2023-09-21 17:16:07 +08:00 committed by GitHub
parent 714884414e
commit bf51ec40b2
12 changed files with 22 additions and 4 deletions
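Across the model sources, this change moves the torch.xpu.empty_cache() call out of create_kv_cache and into each attention forward, so the XPU allocator cache is only released when an existing KV cache block has run out of room and a larger one is about to be allocated; the two GPU examples additionally drop the explicit optimize_model=False argument from from_pretrained. Hedged sketches of the updated example usage and of the resulting allocation pattern follow the diffs below.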


@@ -42,7 +42,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')


@@ -46,7 +46,6 @@ if __name__ == '__main__':
     # to obtain optimal performance with BigDL-LLM INT4 optimizations
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')
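For reference, a minimal sketch of how a GPU example loads and runs a model after this change. Only the from_pretrained arguments and the move to 'xpu' come from the diffs; the model path, prompt, tokenizer usage, and generation call are placeholders, and the bigdl.llm.transformers import path is assumed from the project's public API.

import torch
import intel_extension_for_pytorch as ipex  # assumed prerequisite for the 'xpu' device
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM  # import path assumed

model_path = "path/to/model"  # placeholder, not from the diff

# Load with BigDL-LLM INT4 optimizations; optimize_model is now left at its default.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True,
                                             use_cache=True)
model = model.to('xpu')

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to('xpu')
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))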


@@ -70,6 +70,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_heads,
@@ -168,6 +170,8 @@ def baichuan_attention_forward_13b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_heads,


@@ -82,6 +82,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_heads,
@@ -177,6 +179,8 @@ def baichuan_attention_forward_13b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_heads,


@@ -105,6 +105,8 @@ def bloom_attention_forward(
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(
                 batch_size,


@@ -67,6 +67,8 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,


@@ -151,6 +151,8 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,


@@ -97,6 +97,8 @@ def rw_attention_forward_7b(
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(
                 batch_size,


@@ -144,6 +144,8 @@ def gptj_attention_forward(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,


@@ -90,6 +90,8 @@ def gptneox_attention_forward(
         past_key = layer_past[0]
         past_value = layer_past[1]
         if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_past_key, new_past_value = create_kv_cache(bsz,
                                                            self.num_attention_heads,


@@ -112,6 +112,8 @@ def llama_attention_forward_4_31(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA


@@ -19,8 +19,6 @@ from bigdl.llm.utils.common import invalidInputError
 def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
-    if device.type == 'xpu':
-        torch.xpu.empty_cache()
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
                                     dtype=dtype, device=device)
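Taken together, the diffs relocate the cache-emptying step rather than add new behaviour: previously create_kv_cache called torch.xpu.empty_cache() on every invocation, including the first pre-allocation; now each attention forward calls it only on XPU and only when the stride check shows the pre-allocated block is full and a larger one must be created, presumably to free cached allocator memory just before the bigger request while keeping the call off the common decode path. Below is a self-contained sketch of that pattern. Only the create_kv_cache signature, the stride check, and the placement of the empty_cache call follow the diffs; the block-length constant, the as_strided bookkeeping, and the helper name extend_and_append_kv_cache are assumptions made for illustration.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed pre-allocation block size


def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    # Pre-allocate max_length slots and return views over the first current_length,
    # leaving spare capacity that later steps can grow into without reallocating.
    key_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                    dtype=dtype, device=device)
    value_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                      dtype=dtype, device=device)
    return (key_cache_storage[:, :, :current_length, :],
            value_cache_storage[:, :, :current_length, :])


def extend_and_append_kv_cache(cache_k, cache_v, key_states, value_states):
    # cache_k/cache_v: [bsz, num_heads, past_length, head_dim] views from create_kv_cache;
    # key_states/value_states: the new tokens' keys/values in the same layout
    # (assumes they fit in the spare capacity, as in single-token decoding).
    bsz, num_heads, past_length, head_dim = cache_k.shape
    new_length = past_length + key_states.size(2)
    if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
        # The pre-allocated block is exhausted: on XPU, release cached allocator
        # memory right before requesting a larger block (the gating this commit adds).
        if cache_k.device.type == 'xpu':
            torch.xpu.empty_cache()
        new_cache_k, new_cache_v = create_kv_cache(bsz, num_heads, head_dim,
                                                   past_length,
                                                   new_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                   dtype=cache_k.dtype,
                                                   device=cache_k.device)
        new_cache_k[...] = cache_k
        new_cache_v[...] = cache_v
        cache_k, cache_v = new_cache_k, new_cache_v
    # Grow the views over the pre-allocated storage and write the new tokens in place.
    size = (bsz, num_heads, new_length, head_dim)
    cache_k = cache_k.as_strided(size, cache_k.stride())
    cache_v = cache_v.as_strided(size, cache_v.stride())
    cache_k[:, :, past_length:, :] = key_states
    cache_v[:, :, past_length:, :] = value_states
    return cache_k, cache_v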