LLM: update split tensor conditions. (#10872)
* LLM: update split tensor condition.
* add cond for split tensor.
* update priority of env.
* fix style.
* update env name.
parent 71f51ce589
commit 75dbf240ec
3 changed files with 34 additions and 17 deletions
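The commit adds an `IPEX_LLM_SPLIT_QKV` override that is checked before the runtime heuristics in both the chatglm2 and llama attention paths: "1" forces the block-wise split, any other value disables it, and leaving it unset falls back to the dtype/length and 4 GiB checks in the hunks below. A minimal usage sketch (illustrative only; set the variable before running the patched attention forward):

import os

# Force the block-wise QKV split regardless of dtype or sequence length.
os.environ["IPEX_LLM_SPLIT_QKV"] = "1"
# os.environ["IPEX_LLM_SPLIT_QKV"] = "0"   # force the non-split path instead
# Leave it unset to fall back to the heuristics added in this commit.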
@@ -96,6 +96,19 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
     return key, value


+def should_split_qkv_tensor(query_layer, bsz, n_head, seq_len):
+    if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
+        return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+    elif query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        return True
+    elif query_layer.element_size()*bsz*n_head*seq_len*seq_len >= 4*1024**3:
+        # attn_weight size larger than memory block limitation 4GB
+        return True
+    return False
+
+
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
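To get a feel for when the new `element_size()*bsz*n_head*seq_len*seq_len` branch fires: the attention weight tensor is `[bsz, n_head, seq_len, seq_len]`, so its size grows quadratically with sequence length. A back-of-the-envelope check with assumed shapes (bf16, which also has a 2-byte element size, bsz=1, 32 heads; illustrative numbers, not taken from the commit):

element_size, bsz, n_head = 2, 1, 32   # bf16 is 2 bytes per element (assumed shapes)
limit = 4 * 1024**3                    # the 4 GiB memory-block limit in the check
break_even = int((limit / (element_size * bsz * n_head)) ** 0.5)
print(break_even)                      # 8192 -> a bf16 prompt around 8k tokens reaches the limit

That is the kind of case the old fp16-only 5000-token check could not catch; the environment variable still overrides both heuristics.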
@@ -250,9 +263,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
     else:
         key, value = key_layer, value_layer

-    # split tensor for memory block limitation
-    # support fp16 and set input length threshold at 5000 for now
-    if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+    if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
         # split second dim to block size = 8
         block_size = 8
         query_split = torch.split(query_layer, block_size, dim=1)
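Both chatglm2 call sites then run attention one block of heads at a time instead of all heads at once. A minimal standalone sketch of that pattern (not the repo's helper; it assumes a `[batch, n_head, seq_len, head_dim]` layout and uses `scaled_dot_product_attention` for the per-block math):

import torch
import torch.nn.functional as F

def split_heads_attention(query, key, value, block_size=8):
    # Split dim=1 (the head dimension) into blocks of `block_size` heads,
    # run attention per block, and concatenate the per-block outputs.
    # Heads are independent, so the result matches unsplit attention while
    # each intermediate attn_weight block stays well under the memory limit.
    outputs = []
    for q, k, v in zip(torch.split(query, block_size, dim=1),
                       torch.split(key, block_size, dim=1),
                       torch.split(value, block_size, dim=1)):
        outputs.append(F.scaled_dot_product_attention(q, k, v, is_causal=True))
    return torch.cat(outputs, dim=1)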
@@ -529,9 +540,8 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
     query_layer = query_layer.permute(1, 2, 0, 3)
     L, S = query_layer.shape[2], key_layer.shape[2]
     if attention_mask is None and L == S:
-        # split tensor for memory block limitation
-        # support fp16 and set input length threshold at 5000 for now
-        if query_layer.dtype == torch.float16 and L >= 5000:
+        batch_size, n_head, seq_len, head_dim = query_layer.shape
+        if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
             # split second dim to block size = 8
             block_size = 8
             query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
@@ -215,12 +215,18 @@ def should_use_fast_rope(self, query_states, position_ids):
     return use_fuse_rope


-def should_split_qkv_tensor(query_states, output_attentions):
-    if not output_attentions and query_states.dtype == torch.float16 and \
-            query_states.shape[2] >= 6800:
-        # split tensor for memory block limitation
-        # support fp16 and set input length threshold at 6800 for now
-        return True
+def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
+    if not output_attentions:
+        if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
+            return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+        elif query_states.dtype == torch.float16 and \
+                query_states.shape[2] >= 6800:
+            # split tensor for memory block limitation
+            # support fp16 and set input length threshold at 6800 for now
+            return True
+        elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
+            # attn_weight size larger than memory block limitation 4GB
+            return True
     return False


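In the llama helper the size check uses `q_len*kv_seq_len` rather than the square of one sequence length, so prefill and decode behave differently. A quick comparison with assumed shapes (fp16, bsz=1, 32 heads, an 8k-token prompt; illustrative only):

element_size, bsz, num_heads = 2, 1, 32
prefill = element_size * bsz * num_heads * 8192 * 8192   # q_len == kv_seq_len during prefill
decode = element_size * bsz * num_heads * 1 * 8192       # q_len == 1 on each decode step
print(prefill >= 4 * 1024**3)   # True  -> the long prompt takes the split path
print(decode >= 4 * 1024**3)    # False -> single-token decode steps do not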
@@ -1029,7 +1035,8 @@ def llama_attention_forward_4_36_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        if should_split_qkv_tensor(query_states, output_attentions):
+        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                   q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                          repeated_value_states, attention_mask,
                                                          bsz, q_len, kv_seq_len, self.head_dim,
@@ -1403,7 +1410,7 @@ def llama_attention_forward_4_36_original(

 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
-    if should_split_qkv_tensor(query, output_attentions):
+    if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
         return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
@@ -74,9 +74,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):


 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
-    if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
-        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
-    elif os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
+    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
             "Please use `IPEX_LLM_QUANTIZE_KV_CACHE` instead."
@@ -84,6 +82,8 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
+    elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
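The last two hunks also change the priority of the KV-cache environment variables in `use_quantize_kv_cache`: the deprecated `BIGDL_QUANTIZE_KV_CACHE` is still honoured first (with a warning), then `IPEX_LLM_QUANTIZE_KV_CACHE`, then `IPEX_LLM_LOW_MEM`, and only after all three the xpu/qtype heuristic. A small sketch of what the new order means in practice (hypothetical values):

import os

os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"
os.environ["IPEX_LLM_LOW_MEM"] = "0"
# With both set, the quantize-KV variable is checked first after this commit,
# so the KV cache is quantized even though IPEX_LLM_LOW_MEM says otherwise.
# IPEX_LLM_LOW_MEM alone still controls the behaviour when the quantize-KV
# variables are unset.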