add sdp support for stablelm 3b (#11473)

Yishuo Wang 2024-07-01 14:56:15 +08:00 committed by GitHub
parent cf8eb7b128
commit 39bcb33a67
2 changed files with 3 additions and 3 deletions

@@ -93,7 +93,7 @@ def stablelm_model_forward(
 ):
     # IPEX-LLM OPT: kv cache and quantize kv cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 96, 128]
+    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
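For context, the StableLM 3B checkpoints (e.g. stabilityai/stablelm-zephyr-3b) use a hidden size of 2560 split across 32 attention heads, giving a head dimension of 80 that the previous whitelist rejected. A quick sanity check of that arithmetic, with the config values taken as assumptions:

# Assumed StableLM 3B config values (hidden_size / num_attention_heads)
hidden_size = 2560
num_attention_heads = 32
head_dim = hidden_size // num_attention_heads

print(head_dim)                       # 80
print(head_dim in [64, 80, 96, 128])  # True: quantized kv cache is now considered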

@@ -329,7 +329,7 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
-        and head_dim in [64, 96, 128]
+        and head_dim in [64, 80, 96, 128]
         and q_len != kv_len  # next token
         and q_len <= 32  # lookup
     )
@@ -347,7 +347,7 @@ def use_sdp_fp8(q_len, kv_len, query_states):
 def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
-        and head_dim in [64, 96, 128]  # for now
+        and head_dim in [64, 80, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
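
Read together, use_sdp gates the decode path (a short query attending to a longer kv cache) while use_sdp_causal gates the prefill path (q_len == kv_len), and both now accept a head dimension of 80. A minimal sketch of that dispatch, keeping only the shape checks from the diff (the device, dtype, and training conditions are omitted); the function name and return labels are illustrative, not ipex-llm APIs:

SUPPORTED_HEAD_DIMS = [64, 80, 96, 128]

def sdp_path(q_len, kv_len, head_dim):
    # Simplified sketch: ignores the xpu / fp16 / not-training checks shown above.
    if head_dim not in SUPPORTED_HEAD_DIMS:
        return "fallback"        # unsupported head size
    if q_len == kv_len:
        return "sdp_causal"      # first token / prefill
    if q_len <= 32:
        return "sdp"             # next-token decode with a short query
    return "fallback"

# StableLM 3B (head_dim 80): 256-token prompt, then a single-token decode step
print(sdp_path(q_len=256, kv_len=256, head_dim=80))  # sdp_causal
print(sdp_path(q_len=1, kv_len=257, head_dim=80))    # sdp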