LLM: update split tensor conditions. (#10872)

* LLM: update split tensor condition.

* add cond for split tensor.

* update priority of env.

* fix style.

* update env name.
Cengguang Zhang 2024-04-30 17:07:21 +08:00 committed by GitHub
parent 71f51ce589
commit 75dbf240ec
3 changed files with 34 additions and 17 deletions
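
For reference, a minimal user-side sketch (not part of this patch) of the override added below: per the diff, IPEX_LLM_SPLIT_QKV takes priority over the built-in heuristics, with "1" forcing the split path and any other value disabling it.

import os

# hypothetical toggle, set before the attention forward pass runs;
# "1" forces QKV splitting, anything else disables it, and leaving it
# unset defers to the fp16/sequence-length and 4 GB attn_weight heuristics
os.environ["IPEX_LLM_SPLIT_QKV"] = "1"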


@@ -96,6 +96,19 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
     return key, value
 
 
+def should_split_qkv_tensor(query_layer, bsz, n_head, seq_len):
+    if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
+        return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+    elif query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        return True
+    elif query_layer.element_size()*bsz*n_head*seq_len*seq_len >= 4*1024**3:
+        # attn_weight size larger than memory block limitation 4GB
+        return True
+    return False
+
+
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
@@ -250,9 +263,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
     else:
         key, value = key_layer, value_layer
 
-    # split tensor for memory block limitation
-    # support fp16 and set input length threshold at 5000 for now
-    if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+    if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
         # split second dim to block size = 8
         block_size = 8
         query_split = torch.split(query_layer, block_size, dim=1)
@@ -529,9 +540,8 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         query_layer = query_layer.permute(1, 2, 0, 3)
     L, S = query_layer.shape[2], key_layer.shape[2]
     if attention_mask is None and L == S:
-        # split tensor for memory block limitation
-        # support fp16 and set input length threshold at 5000 for now
-        if query_layer.dtype == torch.float16 and L >= 5000:
+        batch_size, n_head, seq_len, head_dim = query_layer.shape
+        if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
             # split second dim to block size = 8
             block_size = 8
             query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
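
A quick numeric check of the new 4 GB branch above (assumed shapes, not repository code): the estimate is element_size * bsz * n_head * seq_len * seq_len, so with batch size 1 and 32 heads an fp16 attn_weight reaches the 4*1024**3 limit at exactly seq_len = 8192, while fp32 crosses it around seq_len ≈ 5800 (fp16 inputs beyond 5000 tokens are already caught by the preceding branch).

import torch

def estimated_attn_weight_bytes(query_layer, bsz, n_head, seq_len):
    # mirrors the size estimate used by should_split_qkv_tensor above
    return query_layer.element_size() * bsz * n_head * seq_len * seq_len

q = torch.empty(0, dtype=torch.float16)   # only element_size() matters for the estimate
print(estimated_attn_weight_bytes(q, bsz=1, n_head=32, seq_len=8192) >= 4 * 1024**3)  # True: 2*32*8192*8192 bytes == 4 GiB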


@@ -215,12 +215,18 @@ def should_use_fast_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
-def should_split_qkv_tensor(query_states, output_attentions):
-    if not output_attentions and query_states.dtype == torch.float16 and \
-            query_states.shape[2] >= 6800:
-        # split tensor for memory block limitation
-        # support fp16 and set input length threshold at 6800 for now
-        return True
+def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
+    if not output_attentions:
+        if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
+            return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+        elif query_states.dtype == torch.float16 and \
+                query_states.shape[2] >= 6800:
+            # split tensor for memory block limitation
+            # support fp16 and set input length threshold at 6800 for now
+            return True
+        elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
+            # attn_weight size larger than memory block limitation 4GB
+            return True
     return False
@@ -1029,7 +1035,8 @@ def llama_attention_forward_4_36_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        if should_split_qkv_tensor(query_states, output_attentions):
+        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                   q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                          repeated_value_states, attention_mask,
                                                          bsz, q_len, kv_seq_len, self.head_dim,
@@ -1403,7 +1410,7 @@ def llama_attention_forward_4_36_original(
 
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
-    if should_split_qkv_tensor(query, output_attentions):
+    if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
         return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
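
As a rough illustration of the reworked LLaMA check (assumed values, not repository code): because it uses q_len * kv_seq_len rather than a squared sequence length, a single-token decode step stays far below the 4 GB limit even with a long KV cache, while a long non-fp16 prefill can cross it.

import torch

def attn_weight_bytes(query_states, bsz, num_heads, q_len, kv_seq_len):
    # same size estimate as the new branch in should_split_qkv_tensor above
    return query_states.element_size() * bsz * num_heads * q_len * kv_seq_len

q = torch.empty(0, dtype=torch.float32)   # only the dtype matters for the estimate
limit = 4 * 1024**3
# decode: one query token against a 32k KV cache -> tiny attn_weight, no split
print(attn_weight_bytes(q, bsz=1, num_heads=32, q_len=1, kv_seq_len=32768) >= limit)    # False
# prefill: 8k fp32 prompt, 32 heads -> 4*32*8192*8192 bytes = 8 GiB, split
print(attn_weight_bytes(q, bsz=1, num_heads=32, q_len=8192, kv_seq_len=8192) >= limit)  # True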


@@ -74,9 +74,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 
 
 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
-    if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
-        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
-    elif os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
+    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
             "Please use `IPEX_LLM_QUANTIZE_KV_CACHE` instead."
@@ -84,6 +82,8 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
+    elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
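
After this change the precedence in use_quantize_kv_cache is: the deprecated BIGDL_QUANTIZE_KV_CACHE first (with a warning), then IPEX_LLM_QUANTIZE_KV_CACHE, then IPEX_LLM_LOW_MEM, and finally the device/qtype heuristic. A minimal sketch of that ordering (simplified; the fallback argument stands in for the real xpu/qtype check):

import os

def kv_cache_quantization_enabled(fallback_heuristic=False):
    # mirrors the updated priority order in use_quantize_kv_cache
    for var in ("BIGDL_QUANTIZE_KV_CACHE",      # deprecated, still honored first
                "IPEX_LLM_QUANTIZE_KV_CACHE",
                "IPEX_LLM_LOW_MEM"):
        value = os.environ.get(var)
        if value is not None:
            return value == "1"
    return fallback_heuristic                   # placeholder for the xpu/qtype check

os.environ["IPEX_LLM_LOW_MEM"] = "1"
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
print(kv_cache_quantization_enabled())  # False: IPEX_LLM_QUANTIZE_KV_CACHE now wins over IPEX_LLM_LOW_MEM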