optimize qwen2 memory (#11535)
parent 2929eb262e
commit 99b2802d3b

1 changed file with 3 additions and 13 deletions
@@ -37,6 +37,7 @@
 # limitations under the License.
 #
 
+import os
 import math
 from typing import Optional, Tuple, Union, List
 
@@ -55,7 +56,7 @@ from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repea
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers import logging
 
 
@@ -339,20 +340,9 @@ def merge_qkv(module: torch.nn.Module):
 
     del module.q_proj, module.k_proj, module.v_proj
 
-    # Qwen2 uses pre-computed rope table to accelerate rope,
-    # original `cos_cached` and `sin_cached` are added by `register_buffer`,
-    # so they will move to xpu during `model.to('xpu')`.
-    # But gpu fuse kernel doesn't need this rope table, only cpu needs them,
-    # so delete them then add them with `=`, so that they will be pinned on CPU,
-    # this can save about 0.5GB gpu memory usage when running Qwen2
-    if hasattr(module.rotary_emb, "cos_cached"):
-        cos_cached = module.rotary_emb.cos_cached
+    if os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
         del module.rotary_emb.cos_cached
-        module.rotary_emb.cos_cached = cos_cached
-    if hasattr(module.rotary_emb, "sin_cached"):
-        sin_cached = module.rotary_emb.sin_cached
         del module.rotary_emb.sin_cached
-        module.rotary_emb.sin_cached = sin_cached
 
 
 def padding_mlp(module: torch.nn.Module):
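
A note on the pattern being changed (editor's sketch, not part of the commit): `cos_cached` and `sin_cached` are attached via `register_buffer`, so `model.to('xpu')` moves them to device memory along with the weights. The deleted code unregistered them and re-attached the same tensors as plain attributes so they stayed pinned on the CPU; the new code instead drops them entirely when `IPEX_LLM_LOW_MEM=1`, since the fused GPU kernel never reads them, which is where the roughly 0.5GB saving comes from. A minimal runnable sketch of why the delete-then-reassign trick works, using a hypothetical `RotaryDemo` module in place of the real rotary embedding:

import torch

class RotaryDemo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A registered buffer is part of module state, so .to(device)
        # moves it together with the parameters.
        self.register_buffer("cos_cached", torch.randn(4096, 128))

m = RotaryDemo()

# Old approach: unregister the buffer, then re-attach the same tensor.
# nn.Module.__setattr__ only routes an assignment into the buffer
# registry while the name is still registered, so after the del the
# tensor becomes an ordinary attribute that .to(device) ignores.
cos = m.cos_cached
del m.cos_cached       # removes the buffer registration
m.cos_cached = cos     # plain attribute now, stays on the CPU

# New approach (this commit): under IPEX_LLM_LOW_MEM=1 the tables are
# deleted outright, so no copy lives on either device.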
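
Usage (an assumption from the diff, not shown in the commit): the flag is read from the process environment inside `merge_qkv`, so it presumably has to be set before ipex-llm's optimization pass runs, e.g.:

import os

# Hypothetical usage: set before the model is loaded/optimized so that
# merge_qkv sees the flag when it runs.
os.environ["IPEX_LLM_LOW_MEM"] = "1"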