LLM: add memory optimization for llama. (#10592)
* add initial memory optimization. * fix logic. * fix logic, * remove env var check in mlp split.
This commit is contained in:
parent
01f491757a
commit
e567956121
3 changed files with 13 additions and 7 deletions
|
|
@ -543,13 +543,13 @@ class LowBitLinear(nn.Linear):
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
# empty cache before and after lm_head at first token when input > 1024
|
# empty cache before and after lm_head at first token when input > 1024
|
||||||
# on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time.
|
# on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
self.device = get_xpu_device_type(self.weight.data)
|
self.device = get_xpu_device_type(self.weight.data)
|
||||||
# TODO: may remove BIGDL_LOW_MEMORY_MODE here, probably not necessary
|
# TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
|
||||||
self.low_memory_mode = \
|
self.low_memory_mode = \
|
||||||
self.low_memory_mode and \
|
self.low_memory_mode and \
|
||||||
(self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
|
(self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
|
||||||
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
|
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
|
||||||
# we should check both self.training and torch.is_inference_mode_enabled().
|
# we should check both self.training and torch.is_inference_mode_enabled().
|
||||||
is_training = self.training and not torch.is_inference_mode_enabled()
|
is_training = self.training and not torch.is_inference_mode_enabled()
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
n_rep, slen, head_dim)
|
n_rep, slen, head_dim)
|
||||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
|
||||||
|
|
||||||
|
|
||||||
_ipex_version = None
|
_ipex_version = None
|
||||||
|
|
@ -186,7 +186,11 @@ def llama_mlp_forward(
|
||||||
hidden_states = attn_output.view(x.shape)
|
hidden_states = attn_output.view(x.shape)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
else:
|
else:
|
||||||
out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
a = self.act_fn(self.gate_proj(x))
|
||||||
|
b = self.up_proj(x)
|
||||||
|
c = a * b
|
||||||
|
del a, b
|
||||||
|
out = self.down_proj(c)
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
return out + residual
|
return out + residual
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -72,8 +72,10 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
|
||||||
|
|
||||||
|
|
||||||
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
||||||
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
|
if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
|
||||||
return int(os.environ["BIGDL_QUANTIZE_KV_CACHE"]) == 1
|
return os.environ["IPEX_LLM_LOW_MEM"] == "1"
|
||||||
|
elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
|
||||||
|
return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
|
||||||
else:
|
else:
|
||||||
return x.device.type == 'xpu' and kv_cache_device_check(x) \
|
return x.device.type == 'xpu' and kv_cache_device_check(x) \
|
||||||
and hasattr(linear, "qtype") and \
|
and hasattr(linear, "qtype") and \
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue