LLM: unify memory optimization env variables. (#11549)
* LLM: unify memory optimization env variables. * fix comments.
This commit is contained in:
parent
51f2effb05
commit
70ab1a6f1a
3 changed files with 9 additions and 3 deletions
|
|
@@ -327,9 +327,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
optimize_lm_head = False
|
optimize_lm_head = False
|
||||||
if is_lm_head(name, model_config, out_features):
|
if is_lm_head(name, model_config, out_features):
|
||||||
model_type = getattr(model_config, "model_type", None)
|
model_type = getattr(model_config, "model_type", None)
|
||||||
if model_type in ["gptj", "llama", "qwen2"] and \
|
if model_type in ["gptj", "llama", "qwen2"]:
|
||||||
os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1":
|
if os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) is not None:
|
||||||
optimize_lm_head = True
|
optimize_lm_head = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1"
|
||||||
|
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
|
||||||
|
optimize_lm_head = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
new_linear = None
|
new_linear = None
|
||||||
is_gptq = is_gptq_linear(module)
|
is_gptq = is_gptq_linear(module)
|
||||||
|
|
|
||||||
|
|
@@ -286,6 +286,8 @@ def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, out
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
|
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
|
||||||
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
|
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
|
||||||
|
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
|
||||||
|
return os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
|
||||||
elif query_states.dtype == torch.float16 and \
|
elif query_states.dtype == torch.float16 and \
|
||||||
query_states.shape[2] >= 6800:
|
query_states.shape[2] >= 6800:
|
||||||
# split tensor for memory block limitation
|
# split tensor for memory block limitation
|
||||||
|
|
|
||||||
|
|
@@ -92,6 +92,8 @@ def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, out
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
|
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
|
||||||
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
|
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
|
||||||
|
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
|
||||||
|
return os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
|
||||||
elif query_states.dtype == torch.float16 and \
|
elif query_states.dtype == torch.float16 and \
|
||||||
query_states.shape[2] >= 6300:
|
query_states.shape[2] >= 6300:
|
||||||
# split tensor for memory block limitation
|
# split tensor for memory block limitation
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue