LLM: reduce GPU memory for optimize_model=True (#8965)

* reduce GPU memory usage for LLaMA & ChatGLM by emptying the XPU cache before re-allocating the KV cache

* guard the cache clearing on the device type
Ruonan Wang 2023-09-13 17:27:09 +08:00 committed by GitHub
parent be29c75c18
commit dd57623650
3 changed files with 6 additions and 0 deletions
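
The change is the same two-line guard in each of the three attention paths below: right before a new, larger KV-cache block is allocated, the XPU caching allocator is asked to release the memory it is still holding, so the fresh allocation can reuse it instead of growing peak GPU memory. A minimal sketch of that pattern, assuming a PyTorch build with Intel XPU support (the helper name is illustrative, not from the repo):

    import torch

    def maybe_release_cached_blocks(device: torch.device) -> None:
        # torch.xpu is available in builds with Intel XPU support (e.g. via
        # intel_extension_for_pytorch); guarding on the device type keeps
        # CPU-only runs untouched.
        if device.type == 'xpu':
            torch.xpu.empty_cache()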


@@ -67,6 +67,8 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,


@@ -151,6 +151,8 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,


@@ -129,6 +129,8 @@ def llama_attention_forward_4_31(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA
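
All three hunks guard the same re-allocation path: the KV cache is created with KV_CACHE_ALLOC_BLOCK_LENGTH positions of headroom, and the stride test detects when that headroom has been used up. Below is a self-contained sketch of that mechanism under assumed names and sizes; make_cache stands in for the repo's create_kv_cache, and the block length value is illustrative:

    import torch

    KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # illustrative headroom, not the repo's constant

    def make_cache(batch, heads, head_dim, used_len, max_len):
        # Allocate storage for max_len positions, return a view of the used part.
        buf_k = torch.empty(batch, heads, max_len, head_dim)
        buf_v = torch.empty(batch, heads, max_len, head_dim)
        return buf_k[:, :, :used_len, :], buf_v[:, :, :used_len, :]

    max_len = 16 + KV_CACHE_ALLOC_BLOCK_LENGTH

    # While headroom remains: stride()[1] counts the elements between consecutive
    # heads in the underlying buffer (max_len * head_dim), while size(2) * size(3)
    # counts the elements actually in use (used_len * head_dim), so the
    # re-allocation branch is skipped.
    cache_k, _ = make_cache(1, 32, 128, used_len=16, max_len=max_len)
    assert cache_k.stride()[1] > cache_k.size(2) * cache_k.size(3)

    # Once the view fills the whole buffer, the two quantities are equal, the
    # condition in the hunks fires, and a new, larger cache is allocated
    # (preceded on XPU by torch.xpu.empty_cache()).
    cache_k, _ = make_cache(1, 32, 128, used_len=max_len, max_len=max_len)
    assert cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)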