LLM: replace torch.float32 with auto type (#9261)

This commit is contained in:
Jin Qiao 2023-10-24 17:12:13 +08:00 committed by GitHub
parent bd5215d75b
commit 90162264a3

View file

@ -104,7 +104,7 @@ def load_model(
device, load_8bit, cpu_offloading device, load_8bit, cpu_offloading
) )
if device == "cpu": if device == "cpu":
kwargs = {"torch_dtype": torch.float32} kwargs = {"torch_dtype": "auto"}
if CPU_ISA in ["avx512_bf16", "amx"]: if CPU_ISA in ["avx512_bf16", "amx"]:
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex