LLM: optimize chatglm2 8k input. (#10723)

* LLM: optimize chatglm2 8k input.

* rename.
Author: Cengguang Zhang (committed by GitHub)
Date:   2024-04-10 16:59:06 +08:00
Commit: 4b024b7aac
Parent: cd22cb8257


@@ -252,10 +252,31 @@ def chatglm2_quantized_attention_forward_8eb45c(
         else:
             key, value = key_layer, value_layer
-        if attention_mask is None:
-            context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+            # split second dim to block size = 8
+            block_size = 8
+            query_split = torch.split(query_layer, block_size, dim=1)
+            key_split = torch.split(key, block_size, dim=1)
+            value_split = torch.split(value, block_size, dim=1)
+            context_layer = torch.empty(batch_size, n_head,
+                                        seq_len, head_dim).to(query_layer.device)
+            idx = 0
+            for q, k, v in zip(query_split, key_split, value_split):
+                if attention_mask is None:
+                    result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+                else:
+                    result = F.scaled_dot_product_attention(q, k, v, attention_mask)
+                context_layer[:, idx:idx+q.shape[1], :, :] = result
+                idx = idx + q.shape[1]
         else:
-            context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+            if attention_mask is None:
+                context_layer = F.scaled_dot_product_attention(query_layer, key,
+                                                               value, is_causal=True)
+            else:
+                context_layer = F.scaled_dot_product_attention(query_layer, key,
+                                                               value, attention_mask)
         context_layer = context_layer.to(query_layer.dtype)
         if use_cache:
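
Attention is computed independently per head, so splitting query/key/value along the head dimension (dim=1) and running scaled_dot_product_attention one block at a time gives the same result while capping the size of any single kernel launch. Below is a minimal, self-contained sketch of that idea; the fp32/CPU setup, tensor shapes, and tolerance are assumptions for a runnable demo and not taken from the patch, which targets fp16 prompts of 5000+ tokens.

import torch
import torch.nn.functional as F

batch_size, n_head, seq_len, head_dim = 1, 32, 128, 64
query = torch.randn(batch_size, n_head, seq_len, head_dim)
key = torch.randn(batch_size, n_head, seq_len, head_dim)
value = torch.randn(batch_size, n_head, seq_len, head_dim)

# Reference: attention over all heads at once.
reference = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# Blocked version: split the head dimension into chunks of 8 heads,
# preallocate the output, and fill it slice by slice to bound peak memory.
block_size = 8
context = torch.empty(batch_size, n_head, seq_len, head_dim)
idx = 0
for q, k, v in zip(torch.split(query, block_size, dim=1),
                   torch.split(key, block_size, dim=1),
                   torch.split(value, block_size, dim=1)):
    context[:, idx:idx + q.shape[1], :, :] = \
        F.scaled_dot_product_attention(q, k, v, is_causal=True)
    idx += q.shape[1]

# Causal masking acts on the sequence dims, so splitting heads does not change it.
print(torch.allclose(reference, context, atol=1e-5))
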
@@ -517,15 +538,19 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         # split tensor for memory block limitation
         # support fp16 and set input length threshold at 5000 for now
         if query_layer.dtype == torch.float16 and L >= 5000:
-            # split first dim 32 -> 8
-            query_sp = torch.split(query_layer.to(key_layer.dtype), 8, dim=1)
-            key_sp = torch.split(key_layer, 8, dim=1)
-            value_sp = torch.split(value_layer, 8, dim=1)
-            results = []
-            for q, k, v in zip(query_sp, key_sp, value_sp):
+            # split second dim to block size = 8
+            block_size = 8
+            query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
+            key_split = torch.split(key_layer, block_size, dim=1)
+            value_split = torch.split(value_layer, block_size, dim=1)
+            batch_size, n_head, seq_len, head_dim = query_layer.shape
+            context_layer = torch.empty(batch_size, n_head, seq_len,
+                                        head_dim).to(query_layer.device).to(key_layer.dtype)
+            idx = 0
+            for q, k, v in zip(query_split, key_split, value_split):
                 result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
-                results.append(result)
-            context_layer = torch.cat(results, dim=1)
+                context_layer[:, idx:idx+q.shape[1], :, :] = result
+                idx = idx + q.shape[1]
         else:
             context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                            key_layer,
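
The core_attn_forward hunk keeps the existing head-wise splitting but replaces the results list plus torch.cat assembly with a preallocated context_layer that each block writes into, so only one full-size output tensor exists instead of the block list and the concatenated copy living side by side. A hedged sketch of the two assembly patterns follows; the shapes and fp32/CPU setup are made up for a runnable comparison and are not measurements from the patch.

import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 256, 64)
k = torch.randn(1, 32, 256, 64)
v = torch.randn(1, 32, 256, 64)
block_size = 8

# Old pattern: collect per-block outputs, then torch.cat allocates another
# full-size tensor while the whole block list is still alive.
blocks = [F.scaled_dot_product_attention(qb, kb, vb, is_causal=True)
          for qb, kb, vb in zip(q.split(block_size, 1),
                                k.split(block_size, 1),
                                v.split(block_size, 1))]
out_cat = torch.cat(blocks, dim=1)

# New pattern: write each block straight into a preallocated output, so the
# full-size context tensor is allocated exactly once.
out_pre = torch.empty_like(q)
idx = 0
for qb, kb, vb in zip(q.split(block_size, 1),
                      k.split(block_size, 1),
                      v.split(block_size, 1)):
    out_pre[:, idx:idx + qb.shape[1]] = \
        F.scaled_dot_product_attention(qb, kb, vb, is_causal=True)
    idx += qb.shape[1]

# Both strategies assemble the same per-head results.
print(torch.allclose(out_cat, out_pre))
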