LLM: add esimd sdp support for chatglm3 (#10205)
* add esimd sdp support
* fix style
parent 7cbc2429a6
commit 34ee1aa91f
1 changed file with 25 additions and 16 deletions
@@ -25,6 +25,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
+from bigdl.llm.transformers.models.utils import use_esimd_sdp


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -515,7 +516,15 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
             context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                            key_layer,
                                                            value_layer,
-                                                           is_causal=True)
+                                                           is_causal=True).to(key_layer.dtype)
+        else:
+            if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
+                             query_layer.shape[-1], query_layer):
+                import linear_fp16_esimd
+                attn_output = linear_fp16_esimd.sdp_forward(query_layer,
+                                                            key_layer,
+                                                            value_layer)
+                context_layer = attn_output.view(query_layer.shape)
             else:
                 head_dim = query_layer.size(-1)
                 attn = torch.matmul(query_layer.to(key_layer.dtype),
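For readers skimming the hunk above, the new dispatch order in core_attn_forward_8eb45c can be summarised as plain Python. The sketch below is a minimal approximation, not the file's exact code: the function name core_attn_sketch, the [batch, n_head, seq_len, head_dim] layout, and the simplified mask handling in the final fallback are assumptions, while the use_esimd_sdp check and the linear_fp16_esimd.sdp_forward call are taken directly from the diff.

import math

import torch
import torch.nn.functional as F

# Gating helper added by this commit; needs bigdl-llm installed.
from bigdl.llm.transformers.models.utils import use_esimd_sdp


def core_attn_sketch(query_layer, key_layer, value_layer, attention_mask=None):
    # Shapes assumed to be [batch, n_head, seq_len, head_dim], matching the
    # layout core_attn_forward_8eb45c uses after permuting the query.
    if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
        # Prefill path: PyTorch SDPA, cast back to the KV dtype as the diff now does.
        return F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                              key_layer,
                                              value_layer,
                                              is_causal=True).to(key_layer.dtype)
    if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
                     query_layer.shape[-1], query_layer):
        # fp16 ESIMD kernel used by bigdl-llm on Intel GPUs; imported lazily,
        # exactly as in the diff, so it is only loaded when this path is taken.
        import linear_fp16_esimd
        attn_output = linear_fp16_esimd.sdp_forward(query_layer,
                                                    key_layer,
                                                    value_layer)
        return attn_output.view(query_layer.shape)
    # Generic fallback (simplified; the real fallback's mask handling is not
    # fully visible in the hunk): explicit scaled dot-product attention, with
    # attention_mask assumed to be a boolean mask where True means "keep".
    head_dim = query_layer.size(-1)
    attn = torch.matmul(query_layer.to(key_layer.dtype),
                        key_layer.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn = attn.masked_fill(~attention_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value_layer.dtype)
    return torch.matmul(attn, value_layer)

Keeping import linear_fp16_esimd inside the branch means the ESIMD module is only required when use_esimd_sdp actually selects that path; installs without the Intel GPU kernels never import it.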