optimize qwen2 memory (#11535)
parent 2929eb262e
commit 99b2802d3b

1 changed file with 3 additions and 13 deletions
@@ -37,6 +37,7 @@
 # limitations under the License.
 #
 
+import os
 import math
 from typing import Optional, Tuple, Union, List
 
@@ -55,7 +56,7 @@ from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers import logging
 
 
@@ -339,20 +340,9 @@ def merge_qkv(module: torch.nn.Module):
 
         del module.q_proj, module.k_proj, module.v_proj
 
-        # Qwen2 uses pre-computed rope table to accelerate rope,
-        # original `cos_cached` and `sin_cached` are added by `register_buffer`,
-        # so they will move to xpu during `model.to('xpu')`.
-        # But gpu fuse kernel doesn't need this rope table, only cpu needs them,
-        # so delete them then add them with `=`, so that they will be pinned on CPU,
-        # this can save about 0.5GB gpu memory usage when running Qwen2
-        if hasattr(module.rotary_emb, "cos_cached"):
-            cos_cached = module.rotary_emb.cos_cached
+        if os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
             del module.rotary_emb.cos_cached
-            module.rotary_emb.cos_cached = cos_cached
-        if hasattr(module.rotary_emb, "sin_cached"):
-            sin_cached = module.rotary_emb.sin_cached
             del module.rotary_emb.sin_cached
-            module.rotary_emb.sin_cached = sin_cached
 
 
 def padding_mlp(module: torch.nn.Module):
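
The comment block removed above documented the old trick: tensors added with `register_buffer` travel with `model.to('xpu')`, while plain attributes assigned with `=` stay put, so re-assigning the rope tables kept them on the CPU and saved about 0.5GB of GPU memory. A minimal standalone sketch of that mechanism (toy module and sizes, not code from this repo):

import torch

class ToyRope(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a registered buffer moves together with the module on .to(device)
        self.register_buffer("cos_cached", torch.randn(4096, 128))

rope = ToyRope()

# the old trick: remove the table from the buffer registry, then
# re-assign it as a plain attribute so .to('xpu') will skip it
cos = rope.cos_cached
del rope.cos_cached        # nn.Module.__delattr__ drops the _buffers entry
rope.cos_cached = cos      # ordinary attribute now, stays on the CPU

# rope.to('xpu')           # would leave rope.cos_cached behind on the CPU
print(rope.cos_cached.device)  # cpu

This commit replaces that unconditional pinning with an opt-in deletion: per the old comment, the GPU fused rope kernel never reads these tables, so under IPEX_LLM_LOW_MEM=1 they are now dropped outright.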
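
Usage note: merge_qkv reads the variable when the model is optimized, so it must be set before the model is loaded. A sketch:

import os

# opt in to the low-memory rope handling; set this before the Qwen2 model
# is loaded/optimized, since merge_qkv checks it at optimization time
os.environ["IPEX_LLM_LOW_MEM"] = "1"

# ... then load and optimize the model as usual (e.g. via ipex_llm.transformers)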