LLM: add memory optimization for llama. (#10592)

* add initial memory optimization.

* fix logic.

* fix logic,

* remove env var check in mlp split.
This commit is contained in:
Cengguang Zhang 2024-04-02 09:07:50 +08:00 committed by GitHub
parent 01f491757a
commit e567956121
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 13 additions and 7 deletions

View file

@ -543,13 +543,13 @@ class LowBitLinear(nn.Linear):
def forward(self, x: torch.Tensor):
# empty cache before and after lm_head at first token when input > 1024
# on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time.
# on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
if self.device is None:
self.device = get_xpu_device_type(self.weight.data)
# TODO: may remove BIGDL_LOW_MEMORY_MODE here, probably not necessary
# TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
self.low_memory_mode = \
self.low_memory_mode and \
(self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
(self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
# we should check both self.training and torch.is_inference_mode_enabled().
is_training = self.training and not torch.is_inference_mode_enabled()

View file

@ -83,7 +83,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
_ipex_version = None
@ -186,7 +186,11 @@ def llama_mlp_forward(
hidden_states = attn_output.view(x.shape)
return hidden_states
else:
out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
a = self.act_fn(self.gate_proj(x))
b = self.up_proj(x)
c = a * b
del a, b
out = self.down_proj(c)
if residual is not None:
return out + residual
else:

View file

@ -72,8 +72,10 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
return int(os.environ["BIGDL_QUANTIZE_KV_CACHE"]) == 1
if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
return os.environ["IPEX_LLM_LOW_MEM"] == "1"
elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
else:
return x.device.type == 'xpu' and kv_cache_device_check(x) \
and hasattr(linear, "qtype") and \