LLM: support int4 fp16 chatglm2-6b 8k input. (#10648)
parent ab87b6ab21
commit 1a9b8204a4
1 changed file with 17 additions and 4 deletions
@@ -512,6 +512,19 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
    query_layer = query_layer.permute(1, 2, 0, 3)
    L, S = query_layer.shape[2], key_layer.shape[2]
    if attention_mask is None and L == S:
        # split tensors to work around the memory block limitation
        # fp16 only, with the input length threshold set at 5000 for now
        if query_layer.dtype == torch.float16 and L >= 5000:
            # split the head dimension: 32 heads -> chunks of 8
            query_sp = torch.split(query_layer.to(key_layer.dtype), 8, dim=1)
            key_sp = torch.split(key_layer, 8, dim=1)
            value_sp = torch.split(value_layer, 8, dim=1)
            results = []
            for q, k, v in zip(query_sp, key_sp, value_sp):
                result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
                results.append(result)
            context_layer = torch.cat(results, dim=1)
        else:
            context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                           key_layer,
                                                           value_layer,
                                                           is_causal=True)
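Why this helps at 8k input: when PyTorch's scaled_dot_product_attention falls back to the unfused math path, it materializes the full per-head attention-score matrix, so with 32 heads and L = S = 8192 the scores alone take roughly 32 x 8192 x 8192 x 2 bytes, about 4 GiB in fp16; running 8 heads per call caps that transient buffer at about 1 GiB. This is a rough, assumption-laden estimate: it ignores the softmax and output buffers, and it does not apply when a fused attention kernel avoids materializing the scores.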
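Below is a minimal, self-contained sketch of the same head-chunking pattern (illustration only; the function name chunked_causal_sdpa and all shapes here are made up, not part of the patch). Because attention is computed independently per head, splitting along the head dimension and concatenating the per-chunk outputs is numerically equivalent to one big call:

import torch
import torch.nn.functional as F

def chunked_causal_sdpa(q, k, v, chunk_heads=8):
    # q, k, v: (batch, num_heads, seq_len, head_dim); causal attention, no mask
    outs = []
    for qc, kc, vc in zip(torch.split(q, chunk_heads, dim=1),
                          torch.split(k, chunk_heads, dim=1),
                          torch.split(v, chunk_heads, dim=1)):
        # each call only sees chunk_heads heads, so its score matrix is smaller
        outs.append(F.scaled_dot_product_attention(qc, kc, vc, is_causal=True))
    return torch.cat(outs, dim=1)

# quick equivalence check with small shapes (the real path only triggers at L >= 5000)
q = torch.randn(1, 32, 128, 64)
k = torch.randn(1, 32, 128, 64)
v = torch.randn(1, 32, 128, 64)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = chunked_causal_sdpa(q, k, v)
print(torch.allclose(ref, out, atol=1e-5))  # True: chunking the head dim is exact

As in the patch, the chunked path keeps is_causal=True with no attention mask, which matches the guarding attention_mask is None and L == S condition.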