LLM: unify memory optimization env variables. (#11549)
* LLM: unify memory optimization env variables.
* Fix comments.
parent 51f2effb05
commit 70ab1a6f1a

3 changed files with 9 additions and 3 deletions
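
The practical effect of the change is that a single variable can switch on both memory optimizations at once. A minimal usage sketch (an assumption for illustration: the variable is set before the model is loaded and converted by ipex-llm):

    import os

    # Unified switch handled by this patch: with no feature-specific variable set,
    # IPEX_LLM_LOW_MEM=1 enables both the lm_head optimization (for gptj/llama/qwen2)
    # and QKV tensor splitting in should_split_qkv_tensor.
    os.environ["IPEX_LLM_LOW_MEM"] = "1"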
				
			
@@ -327,9 +327,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             optimize_lm_head = False
             if is_lm_head(name, model_config, out_features):
                 model_type = getattr(model_config, "model_type", None)
-                if model_type in ["gptj", "llama", "qwen2"] and \
-                        os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1":
-                    optimize_lm_head = True
+                if model_type in ["gptj", "llama", "qwen2"]:
+                    if os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) is not None:
+                        optimize_lm_head = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1"
+                    elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+                        optimize_lm_head = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
             with init_empty_weights():
                 new_linear = None
                 is_gptq = is_gptq_linear(module)
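
Read as a rule, the new lookup gives the feature-specific variable priority over the unified one. A minimal sketch of that precedence (the helper below is hypothetical; the patch inlines this logic rather than defining a function):

    import os

    def resolve_memory_flag(specific_var, unified_var="IPEX_LLM_LOW_MEM"):
        # Hypothetical helper, not part of the patch: a feature-specific variable
        # such as IPEX_LLM_LAST_LM_HEAD wins when set; otherwise the unified
        # IPEX_LLM_LOW_MEM variable decides; if neither is set, return None so the
        # caller keeps its existing default.
        if os.environ.get(specific_var, None) is not None:
            return os.environ.get(specific_var, None) == "1"
        if os.environ.get(unified_var, None) is not None:
            return os.environ.get(unified_var, None) == "1"
        return None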
@@ -286,6 +286,8 @@ def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, out
     if not output_attentions:
         if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
             return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+        elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+            return os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
         elif query_states.dtype == torch.float16 and \
                 query_states.shape[2] >= 6800:
             # split tensor for memory block limitation

@@ -92,6 +92,8 @@ def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, out
     if not output_attentions:
         if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
             return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+        elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+            return os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
         elif query_states.dtype == torch.float16 and \
                 query_states.shape[2] >= 6300:
             # split tensor for memory block limitation
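
Because the specific variables still take precedence, the unified flag can be combined with an opt-out for a single feature. A hedged example of such a combination (environment values only; the variable names are the ones touched by this patch):

    import os

    # Turn on the unified low-memory behaviour...
    os.environ["IPEX_LLM_LOW_MEM"] = "1"       # lm_head optimization enabled for supported model types
    # ...but opt back out of QKV splitting for this run.
    os.environ["IPEX_LLM_SPLIT_QKV"] = "0"     # should_split_qkv_tensor returns False when attentions are not output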