diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index 3aa99e2e..78e10c2c 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -395,7 +395,7 @@ class FP4Params(torch.nn.Parameter): return self._shape @overload - def to(self: T, device: Optional[Union[int, device]]=..., + def to(self: T, device: Optional[Union[int, torch.device]]=..., dtype: Optional[Union[dtype, str]]=..., non_blocking: bool=...,) -> T: ...