LLM: add memory optimization for llama. (#10592)
* add initial memory optimization.
* fix logic.
* fix logic.
* remove env var check in mlp split.
This commit is contained in:
parent
01f491757a
commit
e567956121
3 changed files with 13 additions and 7 deletions
@@ -543,13 +543,13 @@ class LowBitLinear(nn.Linear):
     def forward(self, x: torch.Tensor):
         # empty cache before and after lm_head at first token when input > 1024
-        # on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time.
+        # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
         if self.device is None:
             self.device = get_xpu_device_type(self.weight.data)
-            # TODO: may remove BIGDL_LOW_MEMORY_MODE here, probably not necessary
+            # TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
             self.low_memory_mode = \
                 self.low_memory_mode and \
-                (self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
+                (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
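The hunk above renames the BIGDL_LOW_MEMORY_MODE switch to IPEX_LLM_LOW_MEM: the low-memory cache clearing around lm_head stays enabled only on Arc GPUs or when the variable is exported as 1. A minimal standalone sketch of that gate, assuming a plain device-type string in place of get_xpu_device_type(self.weight.data); the helper name below is illustrative, not an ipex-llm API:

import os

def low_mem_gate(device_type: str, low_memory_mode: bool = True) -> bool:
    # Restates the gate from the hunk above: keep the low-memory path only on
    # Arc GPUs or when IPEX_LLM_LOW_MEM=1 is set in the environment.
    return low_memory_mode and (
        device_type == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
    )

# Example: on a non-Arc device the gate only opens when the variable is exported.
os.environ["IPEX_LLM_LOW_MEM"] = "1"
assert low_mem_gate("other") is True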
@@ -83,7 +83,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
                                                            n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 _ipex_version = None
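The hunk above makes the KV-cache allocation block length configurable through the environment instead of hard-coding 256. Note that os.environ.get returns a str when the variable is set; a hedged variant (not the commit's code) that keeps the 256 default but coerces the value for callers doing arithmetic with it:

import os

# Sketch only: coerce the environment override to int so downstream size
# calculations always see an integer, with 256 as the default block length.
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))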
@@ -186,7 +186,11 @@ def llama_mlp_forward(
         hidden_states = attn_output.view(x.shape)
         return hidden_states
     else:
-        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        a = self.act_fn(self.gate_proj(x))
+        b = self.up_proj(x)
+        c = a * b
+        del a, b
+        out = self.down_proj(c)
         if residual is not None:
             return out + residual
         else:
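The hunk above unfuses the SwiGLU MLP expression into explicit steps so the gate and up-projection outputs can be dropped with del before down_proj allocates its result. A self-contained sketch of the same pattern on a generic SwiGLU-style module (the class and its sizes are illustrative, not llama's actual configuration):

import torch
import torch.nn as nn

class SwigluMLP(nn.Module):
    # Generic SwiGLU-style MLP used only to illustrate the split above.
    def __init__(self, hidden_size: int = 64, intermediate_size: int = 172):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.act_fn(self.gate_proj(x))   # gate branch
        b = self.up_proj(x)                  # up branch
        c = a * b                            # elementwise gating
        del a, b                             # release both branch outputs here
        return self.down_proj(c)             # only c is still referenced

out = SwigluMLP()(torch.randn(1, 16, 64))

The explicit del drops the Python references to the two branch outputs, so their memory can be reused while down_proj produces its output tensor.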
@@ -72,8 +72,10 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 
 
 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
-    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
-        return int(os.environ["BIGDL_QUANTIZE_KV_CACHE"]) == 1
+    if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
+    elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
+        return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
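The hunk above rewrites use_quantize_kv_cache so that IPEX_LLM_LOW_MEM is consulted first, then IPEX_LLM_QUANTIZE_KV_CACHE, with the existing device/qtype heuristic as the fallback; both variables are compared against the string "1". A hedged usage sketch of that precedence (environment values only, no ipex-llm call shown):

import os

# If IPEX_LLM_LOW_MEM is set at all, its value decides the answer, even when
# IPEX_LLM_QUANTIZE_KV_CACHE requests quantization.
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"  # would enable the quantized KV cache
os.environ["IPEX_LLM_LOW_MEM"] = "0"            # takes precedence and disables it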