diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm.py b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
index 6c1a0a8a..89525697 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
@@ -67,6 +67,8 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index d43452cb..5de558e9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -151,6 +151,8 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)
 
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index b9b94af3..415ca4e0 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -129,6 +129,8 @@ def llama_attention_forward_4_31(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA
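
For context, all three hunks apply the same pattern: just before the pre-allocated KV cache buffer runs out of room and a larger block must be allocated, cached-but-unused allocator memory on Intel XPU is released with torch.xpu.empty_cache(), lowering the peak device memory the new allocation needs. Below is a minimal sketch of that pattern; grow_kv_cache, maybe_free_device_cache, and the KV_CACHE_ALLOC_BLOCK_LENGTH value are illustrative stand-ins rather than BigDL's actual create_kv_cache helper, and the CUDA branch is an added assumption not present in this patch.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed pre-allocation block size, not BigDL's constant


def maybe_free_device_cache(device: torch.device) -> None:
    # Return cached allocator blocks to the driver before a large allocation,
    # which reduces fragmentation-driven OOMs when growing the KV cache.
    if device.type == 'xpu':
        torch.xpu.empty_cache()   # the call the patch adds for Intel GPUs
    elif device.type == 'cuda':
        torch.cuda.empty_cache()  # analogous CUDA call; not part of the patch


def grow_kv_cache(cache_k: torch.Tensor, cache_v: torch.Tensor, cur_length: int):
    # cache_k / cache_v: [batch, heads, past_seq, head_dim] views into a
    # pre-allocated buffer that is longer along the sequence dimension.
    past_length = cache_k.size(2)
    # Same check as in the diff: stride()[1] equals the buffer's sequence
    # capacity times head_dim, so this is true only when no free rows remain.
    if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
        maybe_free_device_cache(cache_k.device)
        max_len = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        bsz, heads, _, head_dim = cache_k.shape
        new_k = torch.empty(bsz, heads, max_len, head_dim,
                            dtype=cache_k.dtype, device=cache_k.device)
        new_v = torch.empty_like(new_k)
        new_k[:, :, :past_length] = cache_k
        new_v[:, :, :past_length] = cache_v
        # Return views sized to the valid region; new tokens can then be
        # written in place until the enlarged buffer fills up again.
        cache_k = new_k[:, :, :past_length]
        cache_v = new_v[:, :, :past_length]
    return cache_k, cache_v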