LLM: support int4 fp16 chatglm2-6b 8k input. (#10648)
parent ab87b6ab21
commit 1a9b8204a4
1 changed file with 17 additions and 4 deletions
@@ -512,6 +512,19 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
     query_layer = query_layer.permute(1, 2, 0, 3)
     L, S = query_layer.shape[2], key_layer.shape[2]
     if attention_mask is None and L == S:
-        context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
-                                                       key_layer,
-                                                       value_layer,
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        if query_layer.dtype == torch.float16 and L >= 5000:
+            # split first dim 32 -> 8
+            query_sp = torch.split(query_layer.to(key_layer.dtype), 8, dim=1)
+            key_sp = torch.split(key_layer, 8, dim=1)
+            value_sp = torch.split(value_layer, 8, dim=1)
+            results = []
+            for q, k, v in zip(query_sp, key_sp, value_sp):
+                result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
+                results.append(result)
+            context_layer = torch.cat(results, dim=1)
+        else:
+            context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
+                                                           key_layer,
+                                                           value_layer,
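For context, the new branch works around a per-launch memory limit by chunking query/key/value along the head dimension (dim 1, 32 heads split into groups of 8) and concatenating the per-chunk attention outputs; causal masking depends only on the sequence dimensions, so the chunked result matches a single full call. A minimal standalone sketch of that pattern follows; the chunked_sdpa helper, tensor shapes, and tolerance are illustrative and not part of this patch.

import torch
import torch.nn.functional as F

def chunked_sdpa(query, key, value, chunk_size=8):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    # Split along the head dimension (dim 1); causal masking acts on the
    # sequence dimensions, so each chunk computes the same values it would
    # contribute inside one full scaled_dot_product_attention call.
    outputs = []
    for q, k, v in zip(torch.split(query, chunk_size, dim=1),
                       torch.split(key, chunk_size, dim=1),
                       torch.split(value, chunk_size, dim=1)):
        outputs.append(F.scaled_dot_product_attention(q, k, v, is_causal=True))
    return torch.cat(outputs, dim=1)

# Illustrative shapes only: 1 batch, 32 heads, 128 tokens, 64-dim heads.
q = torch.randn(1, 32, 128, 64)
k = torch.randn(1, 32, 128, 64)
v = torch.randn(1, 32, 128, 64)
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
chunked = chunked_sdpa(q, k, v)
print(torch.allclose(full, chunked, atol=1e-4))

Each chunked launch only materializes attention scores for chunk_size of the 32 heads at a time, which is presumably why the patch gates this path on fp16 inputs of 5000+ tokens, where a single full-size launch would exceed the memory block limit.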