add sdp support for stablelm 3b (#11473)
commit 39bcb33a67 (parent cf8eb7b128)
2 changed files with 3 additions and 3 deletions
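StableLM 3B uses 32 attention heads over a hidden size of 2560, so its head_dim is 80; before this commit, 80 was missing from the supported head_dim lists, which kept the model off the quantized-kv-cache and SDP fast paths. A minimal sanity check of that arithmetic, assuming the Hugging Face model id stabilityai/stablelm-zephyr-3b and the standard StableLmConfig fields (both assumptions, not taken from this commit):

from transformers import AutoConfig

# Hypothetical check: model id and config field names are assumptions.
config = AutoConfig.from_pretrained("stabilityai/stablelm-zephyr-3b")
head_dim = config.hidden_size // config.num_attention_heads
print(head_dim)                        # 2560 // 32 == 80
print(head_dim in [64, 80, 96, 128])   # True with this commit's updated lists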
@@ -93,7 +93,7 @@ def stablelm_model_forward(
 ):
     # IPEX-LLM OPT: kv cache and quantize kv cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 96, 128]
+    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
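The gate above picks between the fp8 (quantized) kv cache and the regular one based on the first layer's head_dim and on whether use_quantize_kv_cache approves the up_proj linear for the given inputs. A self-contained sketch of that decision, with the ipex-llm helper stubbed out; the stub and pick_kv_cache are illustrative names, not part of the commit:

SUPPORTED_HEAD_DIMS = [64, 80, 96, 128]

def use_quantize_kv_cache(linear, input_ids):
    # Stub for illustration; the real helper inspects the layer's
    # quantization type and the input shape.
    return True

def pick_kv_cache(head_dim, linear, input_ids):
    if head_dim in SUPPORTED_HEAD_DIMS and use_quantize_kv_cache(linear, input_ids):
        return "fp8"      # DynamicFp8Cache in the real code
    return "normal"

print(pick_kv_cache(80, None, None))  # "fp8" now that head_dim 80 is supported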
@@ -329,7 +329,7 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
-        and head_dim in [64, 96, 128]
+        and head_dim in [64, 80, 96, 128]
         and q_len != kv_len  # next token
         and q_len <= 32  # lookup
     )
@@ -347,7 +347,7 @@ def use_sdp_fp8(q_len, kv_len, query_states):
 def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
-        and head_dim in [64, 80, 96, 128]  # for now
+        and head_dim in [64, 80, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
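The two predicates split by inference phase: use_sdp_causal covers prefill (q_len == kv_len, the "first token"), while use_sdp covers decode (q_len != kv_len with a short query window). A runnable restatement of both gates with head_dim-80 tensors; on a machine without an XPU device the device check alone makes both print False:

import torch

def use_sdp(q_len, kv_len, head_dim, query_states):
    return (query_states.device.type == "xpu"
            and query_states.dtype in [torch.float, torch.half]
            and head_dim in [64, 80, 96, 128]
            and q_len != kv_len      # decode: new tokens vs. a longer cache
            and q_len <= 32)         # short query window (lookup decoding)

def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
    return (q_len == kv_len          # prefill: no cached tokens yet
            and head_dim in [64, 80, 96, 128]
            and query_states.device.type == "xpu"
            and query_states.dtype in [torch.float, torch.half]
            and not query_states.requires_grad and not training)

# [batch, heads, q_len, head_dim] shapes with StableLM 3B's head_dim of 80
prefill_q = torch.empty(1, 32, 16, 80, dtype=torch.half)
decode_q = torch.empty(1, 32, 1, 80, dtype=torch.half)
print(use_sdp_causal(16, 16, 80, prefill_q, training=False))  # prefill gate
print(use_sdp(1, 17, 80, decode_q))                           # decode gate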