LLM: update split tensor conditions. (#10872)
* LLM: update split tensor condition.
* add cond for split tensor.
* update priority of env.
* fix style.
* update env name.
parent 71f51ce589
commit 75dbf240ec

3 changed files with 34 additions and 17 deletions
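The change introduces IPEX_LLM_SPLIT_QKV as an explicit override for the query/key/value split heuristics and re-orders the KV-cache environment variables so that IPEX_LLM_QUANTIZE_KV_CACHE (or the deprecated BIGDL_QUANTIZE_KV_CACHE) is consulted before IPEX_LLM_LOW_MEM. A minimal usage sketch, assuming the variables are read at call time as in the diffs below:

import os

# Force the split-tensor attention path on ("1") or off ("0"); when unset,
# the dtype / sequence-length / 4 GiB heuristics in the diffs below decide.
os.environ["IPEX_LLM_SPLIT_QKV"] = "1"

# Quantized KV cache: IPEX_LLM_QUANTIZE_KV_CACHE now takes priority over
# IPEX_LLM_LOW_MEM (see the use_quantize_kv_cache hunks at the end).
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"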
				
			
@@ -96,6 +96,19 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
     return key, value
 
 
+def should_split_qkv_tensor(query_layer, bsz, n_head, seq_len):
+    if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
+        return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+    elif query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        return True
+    elif query_layer.element_size()*bsz*n_head*seq_len*seq_len >= 4*1024**3:
+        # attn_weight size larger than memory block limitation 4GB
+        return True
+    return False
+
+
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
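For dtypes other than fp16, the new element_size()-based branch is what triggers the split. The numbers below are illustrative only and show roughly where an fp32 attn_weight of shape (bsz, n_head, seq_len, seq_len) crosses the 4 GiB limit:

import torch

# Illustrative shapes (not from the repository): attn_weight occupies
# element_size * bsz * n_head * seq_len * seq_len bytes.
bsz, n_head, seq_len = 1, 32, 6000
element_size = torch.tensor([], dtype=torch.float32).element_size()  # 4 bytes
attn_bytes = element_size * bsz * n_head * seq_len * seq_len         # ~4.6 GB
print(attn_bytes >= 4 * 1024 ** 3)  # True -> should_split_qkv_tensor returns True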
@@ -250,9 +263,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
         else:
             key, value = key_layer, value_layer
 
-        # split tensor for memory block limitation
-        # support fp16 and set input length threshold at 5000 for now
-        if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+        if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
             # split second dim to block size = 8
             block_size = 8
             query_split = torch.split(query_layer, block_size, dim=1)
@@ -529,9 +540,8 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         query_layer = query_layer.permute(1, 2, 0, 3)
         L, S = query_layer.shape[2], key_layer.shape[2]
         if attention_mask is None and L == S:
-            # split tensor for memory block limitation
-            # support fp16 and set input length threshold at 5000 for now
-            if query_layer.dtype == torch.float16 and L >= 5000:
+            batch_size, n_head, seq_len, head_dim = query_layer.shape
+            if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
                 # split second dim to block size = 8
                 block_size = 8
                 query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
@@ -215,12 +215,18 @@ def should_use_fast_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
-def should_split_qkv_tensor(query_states, output_attentions):
-    if not output_attentions and query_states.dtype == torch.float16 and \
-            query_states.shape[2] >= 6800:
-        # split tensor for memory block limitation
-        # support fp16 and set input length threshold at 6800 for now
-        return True
+def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
+    if not output_attentions:
+        if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
+            return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+        elif query_states.dtype == torch.float16 and \
+                query_states.shape[2] >= 6800:
+            # split tensor for memory block limitation
+            # support fp16 and set input length threshold at 6800 for now
+            return True
+        elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
+            # attn_weight size larger than memory block limitation 4GB
+            return True
     return False
 
 
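A small sketch of how the updated llama-side predicate behaves, assuming the patched should_split_qkv_tensor above is in scope and that query_states is laid out as (bsz, num_heads, q_len, head_dim) as at the call sites below; the shapes and values are hypothetical:

import os
import torch

# should_split_qkv_tensor as defined in the hunk above is assumed to be imported.
bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 32, 7000, 7000, 128
query_states = torch.empty(bsz, num_heads, q_len, head_dim, dtype=torch.float16)

os.environ.pop("IPEX_LLM_SPLIT_QKV", None)
# fp16 with an input length >= 6800 -> the heuristic asks for the split path
print(should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, False))  # True

os.environ["IPEX_LLM_SPLIT_QKV"] = "0"
# the explicit environment override now wins over the heuristics
print(should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, False))  # False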
@@ -1029,7 +1035,8 @@ def llama_attention_forward_4_36_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        if should_split_qkv_tensor(query_states, output_attentions):
+        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                   q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                          repeated_value_states, attention_mask,
                                                          bsz, q_len, kv_seq_len, self.head_dim,
@@ -1403,7 +1410,7 @@ def llama_attention_forward_4_36_original(
 
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
-    if should_split_qkv_tensor(query, output_attentions):
+    if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
         return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
@@ -74,9 +74,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 
 
 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
-    if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
-        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
-    elif os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
+    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
             "Please use `IPEX_LLM_QUANTIZE_KV_CACHE` instead."
@@ -84,6 +82,8 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
+    elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
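With the last two hunks, use_quantize_kv_cache checks the environment variables in a fixed order before falling back to the device/qtype heuristic. A sketch of the intended priority (user-side usage, not repository code):

import os

# Priority after this patch:
#   1. BIGDL_QUANTIZE_KV_CACHE    (deprecated, still honoured, emits a warning)
#   2. IPEX_LLM_QUANTIZE_KV_CACHE
#   3. IPEX_LLM_LOW_MEM
#   4. xpu device / qtype heuristic
os.environ.pop("BIGDL_QUANTIZE_KV_CACHE", None)
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"  # enable quantized KV cache
os.environ["IPEX_LLM_LOW_MEM"] = "0"            # ignored: the variable above wins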