add sdp support for stablelm 3b (#11473)

Yishuo Wang 2024-07-01 14:56:15 +08:00 committed by GitHub
parent cf8eb7b128
commit 39bcb33a67
2 changed files with 3 additions and 3 deletions


@@ -93,7 +93,7 @@ def stablelm_model_forward(
 ):
     # IPEX-LLM OPT: kv cache and quantize kv cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 96, 128]
+    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
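
The change above widens the quantized-KV-cache gate so that head_dim 80 is accepted; StableLM 3B models (e.g. hidden_size 2560 with 32 attention heads) have exactly that head dimension, which is why they previously fell through to the unquantized path. A minimal standalone sketch of the gate, under the assumption of a StableLM-3B-style geometry; SUPPORTED_HEAD_DIMS and should_quantize_kv are illustrative names, not IPEX-LLM API:

# Sketch only: the real check lives inline in stablelm_model_forward.
SUPPORTED_HEAD_DIMS = [64, 80, 96, 128]

def should_quantize_kv(head_dim: int, runtime_check_passed: bool) -> bool:
    # Both conditions of the gate: a kernel-supported head_dim AND the
    # use_quantize_kv_cache() runtime check on the mlp.up_proj weights.
    return head_dim in SUPPORTED_HEAD_DIMS and runtime_check_passed

hidden_size, num_heads = 2560, 32                            # StableLM-3B-style config (assumed)
print(should_quantize_kv(hidden_size // num_heads, True))    # True  -> head_dim == 80
print(should_quantize_kv(100, True))                         # False -> unsupported head_dim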


@@ -329,7 +329,7 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
-        and head_dim in [64, 96, 128]
+        and head_dim in [64, 80, 96, 128]
         and q_len != kv_len  # next token
         and q_len <= 32  # lookup
     )
@@ -347,7 +347,7 @@ def use_sdp_fp8(q_len, kv_len, query_states):
 def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
-        and head_dim in [64, 96, 128]  # for now
+        and head_dim in [64, 80, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
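
Together, use_sdp and use_sdp_causal decide whether the fused SDP kernels can be used for a given attention call: use_sdp_causal covers the prefill step (q_len == kv_len) and use_sdp covers decode/lookup steps (q_len != kv_len, q_len <= 32), and both now accept head_dim 80, so StableLM 3B qualifies on XPU. A hedged sketch of how a caller might branch on these predicates; pick_attention_path and its return labels are illustrative, not the actual IPEX-LLM dispatch code, and the import path is assumed:

import torch
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal  # module path assumed

def pick_attention_path(query_states, q_len, kv_len, head_dim, training=False):
    if use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
        return "fused SDP, causal prefill"      # first token, q_len == kv_len
    if use_sdp(q_len, kv_len, head_dim, query_states):
        return "fused SDP, decode/lookup"       # next token(s), q_len <= 32
    return "native attention fallback"          # CPU, unsupported head_dim, training, ...

# Example: a StableLM-3B-shaped decode step (head_dim 80, single new token).
# On CPU this prints the fallback; on an XPU device with fp16 tensors the
# decode predicate would now accept head_dim 80.
q = torch.randn(1, 32, 1, 80, dtype=torch.half)  # (batch, heads, q_len, head_dim)
print(pick_attention_path(q, q_len=1, kv_len=128, head_dim=80))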