From e567956121012103ba3d7aa7ae901700830b96af Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Tue, 2 Apr 2024 09:07:50 +0800
Subject: [PATCH] LLM: add memory optimization for llama. (#10592)

* add initial memory optimization.
* fix logic.
* fix logic.
* remove env var check in mlp split.
---
 python/llm/src/ipex_llm/transformers/low_bit_linear.py | 6 +++---
 python/llm/src/ipex_llm/transformers/models/llama.py   | 8 ++++++--
 python/llm/src/ipex_llm/transformers/models/utils.py   | 6 ++++--
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 797abcb3..91b8cda6 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -543,13 +543,13 @@ class LowBitLinear(nn.Linear):
 
     def forward(self, x: torch.Tensor):
         # empty cache before and after lm_head at first token when input > 1024
-        # on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time.
+        # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
         if self.device is None:
             self.device = get_xpu_device_type(self.weight.data)
-            # TODO: may remove BIGDL_LOW_MEMORY_MODE here, probably not necessary
+            # TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
             self.low_memory_mode = \
                 self.low_memory_mode and \
-                (self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
+                (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 1221012e..862e09b1 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -83,7 +83,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
                                                            n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 
 _ipex_version = None
 
@@ -186,7 +186,11 @@ def llama_mlp_forward(
         hidden_states = attn_output.view(x.shape)
         return hidden_states
     else:
-        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        a = self.act_fn(self.gate_proj(x))
+        b = self.up_proj(x)
+        c = a * b
+        del a, b
+        out = self.down_proj(c)
         if residual is not None:
             return out + residual
         else:
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 222634ef..da4081ec 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -72,8 +72,10 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 
 
 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
-    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
-        return int(os.environ["BIGDL_QUANTIZE_KV_CACHE"]) == 1
+    if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
+    elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
+        return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
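
Below is a standalone sketch, not part of the patch, illustrating the two ideas the diff relies on: the environment variables it introduces (IPEX_LLM_LOW_MEM, IPEX_LLM_QUANTIZE_KV_CACHE, KV_CACHE_ALLOC_BLOCK_LENGTH) and the rewrite of the fused MLP expression into named temporaries that are released with del before down_proj runs. The ToyMLP module and its sizes are hypothetical stand-ins, not the actual ipex-llm classes.

# Standalone sketch; ToyMLP is a hypothetical stand-in for ipex-llm's llama MLP.
import os

import torch
import torch.nn as nn

# Same convention as the patch: the KV cache block length can be overridden through an
# environment variable; the int() cast assumes a plain integer string such as "256".
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))

# The patch keys both the lm_head cache clearing and the quantized-KV-cache decision
# on IPEX_LLM_LOW_MEM being set to "1".
LOW_MEM = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"


class ToyMLP(nn.Module):
    """Gate/up/down MLP in the LLaMA style; for illustration only."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same shape of rewrite as llama_mlp_forward in the patch: name the intermediate
        # activations and drop them with del as soon as they are consumed, so their
        # memory can be reclaimed before down_proj allocates its output.
        a = self.act_fn(self.gate_proj(x))
        b = self.up_proj(x)
        c = a * b
        del a, b
        return self.down_proj(c)


if __name__ == "__main__":
    mlp = ToyMLP(hidden_size=64, intermediate_size=172)
    out = mlp(torch.randn(2, 16, 64))
    print(out.shape, KV_CACHE_ALLOC_BLOCK_LENGTH, LOW_MEM)

At inference time one would export IPEX_LLM_LOW_MEM=1 or IPEX_LLM_QUANTIZE_KV_CACHE=1 before loading the model; per the utils.py hunk, IPEX_LLM_LOW_MEM takes precedence when both are set, and without either variable the quantized KV cache falls back to the xpu device and qtype checks.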