LLM: optimize chatglm2 8k input. (#10723)
* LLM: optimize chatglm2 8k input.
* rename.
parent cd22cb8257
commit 4b024b7aac
1 changed file with 36 additions and 11 deletions
@@ -252,10 +252,31 @@ def chatglm2_quantized_attention_forward_8eb45c(
         else:
             key, value = key_layer, value_layer
 
-            if attention_mask is None:
-                context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
+            # split tensor for memory block limitation
+            # support fp16 and set input length threshold at 5000 for now
+            if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+                # split second dim to block size = 8
+                block_size = 8
+                query_split = torch.split(query_layer, block_size, dim=1)
+                key_split = torch.split(key, block_size, dim=1)
+                value_split = torch.split(value, block_size, dim=1)
+                context_layer = torch.empty(batch_size, n_head,
+                                            seq_len, head_dim).to(query_layer.device)
+                idx = 0
+                for q, k, v in zip(query_split, key_split, value_split):
+                    if attention_mask is None:
+                        result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+                    else:
+                        result = F.scaled_dot_product_attention(q, k, v, attention_mask)
+                    context_layer[:, idx:idx+q.shape[1], :, :] = result
+                    idx = idx + q.shape[1]
             else:
-                context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+                if attention_mask is None:
+                    context_layer = F.scaled_dot_product_attention(query_layer, key,
+                                                                   value, is_causal=True)
+                else:
+                    context_layer = F.scaled_dot_product_attention(query_layer, key,
+                                                                   value, attention_mask)
             context_layer = context_layer.to(query_layer.dtype)
 
         if use_cache:
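For reference, a minimal standalone sketch of the blocking scheme used in this hunk (the function name blocked_sdpa and the tensor shapes are illustrative assumptions, not the patched module): query, key and value are split along the head dimension (dim=1) into chunks of 8 heads, scaled_dot_product_attention runs once per chunk, and each partial result is written into a preallocated output buffer, so the temporary attention-score tensor stays bounded for long fp16 prompts. The chunked result matches a single full call.

import torch
import torch.nn.functional as F

def blocked_sdpa(query, key, value, block_size=8):
    # query/key/value: [batch, n_head, seq_len, head_dim]
    batch_size, n_head, seq_len, head_dim = query.shape
    out = torch.empty(batch_size, n_head, seq_len, head_dim,
                      dtype=query.dtype, device=query.device)
    idx = 0
    for q, k, v in zip(torch.split(query, block_size, dim=1),
                       torch.split(key, block_size, dim=1),
                       torch.split(value, block_size, dim=1)):
        # each chunk only materializes scores for block_size heads at a time
        out[:, idx:idx + q.shape[1], :, :] = \
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        idx += q.shape[1]
    return out

if __name__ == "__main__":
    # small illustrative shapes; the patch itself targets ~8k-token fp16 inputs
    q, k, v = (torch.randn(1, 32, 256, 64) for _ in range(3))
    full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(blocked_sdpa(q, k, v), full, atol=1e-5)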
@@ -517,15 +538,19 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         # split tensor for memory block limitation
         # support fp16 and set input length threshold at 5000 for now
         if query_layer.dtype == torch.float16 and L >= 5000:
-            # split first dim 32 -> 8
-            query_sp = torch.split(query_layer.to(key_layer.dtype), 8, dim=1)
-            key_sp = torch.split(key_layer, 8, dim=1)
-            value_sp = torch.split(value_layer, 8, dim=1)
-            results = []
-            for q, k, v in zip(query_sp, key_sp, value_sp):
+            # split second dim to block size = 8
+            block_size = 8
+            query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
+            key_split = torch.split(key_layer, block_size, dim=1)
+            value_split = torch.split(value_layer, block_size, dim=1)
+            batch_size, n_head, seq_len, head_dim = query_layer.shape
+            context_layer = torch.empty(batch_size, n_head, seq_len,
+                                        head_dim).to(query_layer.device).to(key_layer.dtype)
+            idx = 0
+            for q, k, v in zip(query_split, key_split, value_split):
                 result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
-                results.append(result)
-            context_layer = torch.cat(results, dim=1)
+                context_layer[:, idx:idx+q.shape[1], :, :] = result
+                idx = idx + q.shape[1]
         else:
             context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                            key_layer,
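The second hunk additionally replaces the accumulate-then-torch.cat pattern with the same preallocated buffer. A hedged sketch of that difference (the names cat_style/prealloc_style and the shapes are assumptions for illustration, not the repo's code): with torch.cat, the list of per-chunk results and the concatenated tensor coexist briefly, roughly doubling peak memory for the output, whereas writing each chunk into a torch.empty buffer keeps only one full-size output tensor alive.

import torch
import torch.nn.functional as F

def cat_style(q, k, v, block_size=8):
    # pre-patch pattern: collect per-chunk outputs, then concatenate;
    # the chunk list and the concatenated tensor coexist at the end.
    results = []
    for qs, ks, vs in zip(torch.split(q, block_size, dim=1),
                          torch.split(k, block_size, dim=1),
                          torch.split(v, block_size, dim=1)):
        results.append(F.scaled_dot_product_attention(qs, ks, vs, is_causal=True))
    return torch.cat(results, dim=1)

def prealloc_style(q, k, v, block_size=8):
    # post-patch pattern: write each chunk into a preallocated buffer,
    # so only one full-size output tensor is ever allocated.
    b, h, s, d = q.shape
    out = torch.empty(b, h, s, d, dtype=q.dtype, device=q.device)
    idx = 0
    for qs, ks, vs in zip(torch.split(q, block_size, dim=1),
                          torch.split(k, block_size, dim=1),
                          torch.split(v, block_size, dim=1)):
        out[:, idx:idx + qs.shape[1]] = F.scaled_dot_product_attention(
            qs, ks, vs, is_causal=True)
        idx += qs.shape[1]
    return out

if __name__ == "__main__":
    q, k, v = (torch.randn(1, 32, 256, 64) for _ in range(3))
    assert torch.allclose(cat_style(q, k, v), prealloc_style(q, k, v))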