LLM: Fix CPU qlora dtype convert issue (#9394)

This commit is contained in:
Wang, Jian4 2023-11-09 14:34:01 +08:00 committed by GitHub
parent 34449cb4bb
commit 40cead6b5b
2 changed files with 1 addition and 3 deletions

View file

@@ -368,8 +368,6 @@ class MatMulLowBitCPU(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, weight):
-        if torch.is_autocast_enabled():
-            A = A.to(torch.get_autocast_dtype())
         ctx.is_empty = False
         x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
                                          weight._shape[0] * weight._shape[1])

View file

@@ -142,7 +142,7 @@ def get_autocast_dtype(x):
        else:
            return None
    elif x.device.type == "cpu":
-       if torch.is_autocast_enabled():
+       if torch.is_autocast_cpu_enabled():
            return torch.get_autocast_cpu_dtype()
        else:
            return None