LLM: reduce GPU memory for optimize_model=True (#8965)
* reduce GPU memory for llama & chatglm
* change the check to use the device type
parent be29c75c18
commit dd57623650

3 changed files with 6 additions and 0 deletions
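All three hunks make the same two-line addition: immediately before the existing code allocates a new, larger KV-cache buffer (the stride check on cache_k detects that the current buffer has no spare room along the sequence dimension), the patch calls torch.xpu.empty_cache(), guarded by a device.type check so non-XPU devices are unaffected. Releasing the allocator's cached blocks before the larger allocation keeps peak GPU memory lower. A minimal sketch of the pattern, under stated assumptions, follows the diff below.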
@@ -67,6 +67,8 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
@@ -151,6 +151,8 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)

         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
@@ -129,6 +129,8 @@ def llama_attention_forward_4_31(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA
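Below is a minimal, self-contained sketch of the pattern the hunks above introduce, for illustration only; it is not the bigdl-llm implementation. grow_kv_cache and maybe_grow_kv_cache are hypothetical stand-ins for create_kv_cache and the surrounding attention code, and KV_CACHE_ALLOC_BLOCK_LENGTH is given an arbitrary value here.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # illustrative value, not the library's constant


def grow_kv_cache(cache_k, cache_v, max_cache_length):
    # Hypothetical stand-in for create_kv_cache: allocate a larger buffer
    # along the sequence dimension and copy the existing cache into it.
    b, h, t, d = cache_k.shape
    new_k = torch.empty(b, h, max_cache_length, d,
                        dtype=cache_k.dtype, device=cache_k.device)
    new_v = torch.empty_like(new_k)
    new_k[:, :, :t, :] = cache_k
    new_v[:, :, :t, :] = cache_v
    return new_k, new_v


def maybe_grow_kv_cache(cache_k, cache_v, cur_length, device):
    past_length = cache_k.size(2)
    # Same condition as the patched code: when stride()[1] is no larger than
    # size(2) * size(3), the buffer has no spare room along the sequence axis.
    if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
        if device.type == 'xpu':
            # Release blocks held by the XPU caching allocator before the
            # larger allocation below, keeping peak GPU memory lower.
            torch.xpu.empty_cache()
        max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        cache_k, cache_v = grow_kv_cache(cache_k, cache_v, max_cache_length)
    return cache_k, cache_v

As with torch.cuda.empty_cache(), the call only returns unoccupied cached blocks to the device; tensors that are still referenced are untouched, which is why it is safe to place it right before the larger allocation.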