LLM: add memory optimization for llama. (#10592)
* add initial memory optimization.
* fix logic.
* fix logic.
* remove env var check in mlp split.
parent 01f491757a
commit e567956121

3 changed files with 13 additions and 7 deletions
@@ -543,13 +543,13 @@ class LowBitLinear(nn.Linear):
     def forward(self, x: torch.Tensor):
         # empty cache before and after lm_head at first token when input > 1024
-        # on arc or BIGDL_LOW_MEMORY_MODE is set to 1 at inference time.
+        # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
         if self.device is None:
             self.device = get_xpu_device_type(self.weight.data)
-            # TODO: may remove BIGDL_LOW_MEMORY_MODE here, probably not necessary
+            # TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
             self.low_memory_mode = \
                 self.low_memory_mode and \
-                (self.device == "arc" or os.environ.get("BIGDL_LOW_MEMORY_MODE", None) == "1")
+                (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
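The hunk above renames the low-memory switch from BIGDL_LOW_MEMORY_MODE to IPEX_LLM_LOW_MEM and keeps the same gating: the mode stays enabled only on Arc or when the variable is explicitly set to 1. A minimal standalone sketch of that boolean, assuming a plain device string rather than the real get_xpu_device_type() result (low_memory_enabled is a hypothetical name, not part of the repo):

```python
import os

def low_memory_enabled(device: str, requested: bool = True) -> bool:
    # Sketch of the condition in the diff above: keep low-memory mode only if
    # it was requested AND we are on Arc or IPEX_LLM_LOW_MEM is set to "1".
    return requested and (
        device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
    )

print(low_memory_enabled("arc"))    # True
print(low_memory_enabled("other"))  # False unless IPEX_LLM_LOW_MEM=1 is exported
```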
@@ -83,7 +83,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
                                                            n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 _ipex_version = None
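KV_CACHE_ALLOC_BLOCK_LENGTH now comes from the environment with 256 as the fallback. Note that os.environ.get returns a str when the variable is set, so the constant's type depends on whether the override is present; a hedged sketch of reading such a knob with an explicit conversion (read_int_env is a hypothetical helper, not taken from the repo):

```python
import os

def read_int_env(name: str, default: int) -> int:
    # os.environ.get returns a str when the variable is set; convert explicitly
    # so the block length is always an int regardless of how it was supplied.
    value = os.environ.get(name)
    return default if value is None else int(value)

KV_CACHE_ALLOC_BLOCK_LENGTH = read_int_env("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
```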
@@ -186,7 +186,11 @@ def llama_mlp_forward(
             hidden_states = attn_output.view(x.shape)
         return hidden_states
     else:
-        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        a = self.act_fn(self.gate_proj(x))
+        b = self.up_proj(x)
+        c = a * b
+        del a, b
+        out = self.down_proj(c)
         if residual is not None:
             return out + residual
         else:
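Splitting the fused expression into a, b, and c lets the two intermediate projections be released (del a, b) before down_proj allocates its output, lowering peak memory in the MLP at the cost of slightly more verbose code. A self-contained sketch of the same pattern with ordinary nn.Linear layers; the SplitMLP name and the dimensions are made up for illustration:

```python
import torch
import torch.nn as nn

class SplitMLP(nn.Module):
    # Illustrative LLaMA-style gated MLP; module layout mirrors the diff above,
    # but the sizes are placeholders.
    def __init__(self, hidden: int = 64, intermediate: int = 256):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.act_fn(self.gate_proj(x))  # gate branch
        b = self.up_proj(x)                 # up branch
        c = a * b                           # element-wise gating
        del a, b                            # drop the branches before the final matmul
        return self.down_proj(c)

x = torch.randn(2, 8, 64)
print(SplitMLP()(x).shape)  # torch.Size([2, 8, 64])
```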
@@ -72,8 +72,10 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 
 
 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
-    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
-        return int(os.environ["BIGDL_QUANTIZE_KV_CACHE"]) == 1
+    if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
+    elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
+        return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
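use_quantize_kv_cache now checks IPEX_LLM_LOW_MEM first, then IPEX_LLM_QUANTIZE_KV_CACHE, and only falls back to the device/qtype heuristic when neither variable is set. A reduced sketch of just that precedence, with the fallback stubbed out because kv_cache_device_check and the qtype test live elsewhere in the repo:

```python
import os

def should_quantize_kv_cache(heuristic_result: bool) -> bool:
    # Precedence mirrors the diff above: IPEX_LLM_LOW_MEM wins, then
    # IPEX_LLM_QUANTIZE_KV_CACHE, then the device/qtype heuristic.
    if os.environ.get("IPEX_LLM_LOW_MEM") is not None:
        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
    if os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE") is not None:
        return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
    return heuristic_result  # stand-in for the xpu/qtype check in the real code

os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"
os.environ["IPEX_LLM_LOW_MEM"] = "0"
print(should_quantize_kv_cache(False))  # False: the low-memory flag takes precedence
```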