diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 29520b44..dd9fca75 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -230,7 +230,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
     if (
         device.type == "xpu"
         and dtype in [torch.float, torch.half]
-        and head_dim in [64, 80, 96, 128]
+        and head_dim in [64, 80, 96, 128, 192, 256]
     ):
         # prepare scale
         scale = 1 / math.sqrt(head_dim) if scale is None else scale
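
For context, a minimal sketch (not part of the patch) of the guard this hunk modifies. `xpu_sdpa_eligible` is a hypothetical helper that mirrors the condition as written above; it shows that head dimensions 192 and 256 now satisfy the check and take the optimized XPU attention path instead of the fallback.

```python
import torch

def xpu_sdpa_eligible(device_type: str, dtype: torch.dtype, head_dim: int) -> bool:
    # Mirrors the guard in scaled_dot_product_attention after this patch:
    # the optimized XPU path is used only for fp32/fp16 tensors with a
    # supported head dimension.
    return (
        device_type == "xpu"
        and dtype in [torch.float, torch.half]
        and head_dim in [64, 80, 96, 128, 192, 256]  # 192 and 256 added by this patch
    )

# head_dim == 256 now qualifies; before this change it fell through to the
# generic implementation. Unsupported sizes (e.g. 112) still fall back.
print(xpu_sdpa_eligible("xpu", torch.half, 256))  # True after this patch
print(xpu_sdpa_eligible("xpu", torch.half, 112))  # False: unsupported head size
```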