LLM: Fix CPU qlora dtype convert issue (#9394)
parent 34449cb4bb
commit 40cead6b5b
2 changed files with 1 addition and 3 deletions
@@ -368,8 +368,6 @@ class MatMulLowBitCPU(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, A, weight):
-        if torch.is_autocast_enabled():
-            A = A.to(torch.get_autocast_dtype())
         ctx.is_empty = False
         x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
                                          weight._shape[0] * weight._shape[1])
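For context, a hypothetical forward-only sketch of the pattern after the hunk above: forward no longer converts A to an autocast dtype itself; it dequantizes the packed weight and performs the matmul in whatever dtype the input already carries. The function name, the `dequantize` callable, and the plain fp32 tensor standing in for a ggml int4 weight are all illustrative stand-ins, not the repository's code.

import torch

def toy_low_bit_matmul_cpu(A: torch.Tensor, packed_weight, dequantize):
    # Dequantize the packed weight to fp32 (stand-in for ggml_int4_convert_fp32),
    # then matmul in the input's existing dtype; no autocast conversion of A here.
    w_fp32 = dequantize(packed_weight)        # shape: (out_features, in_features)
    return torch.matmul(A, w_fp32.t().to(A.dtype))

# Usage: a plain fp32 tensor stands in for the dequantized int4 weight.
A = torch.randn(4, 8, dtype=torch.bfloat16)
W = torch.randn(16, 8)
out = toy_low_bit_matmul_cpu(A, W, dequantize=lambda w: w)
assert out.dtype == torch.bfloat16 and out.shape == (4, 16)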
@@ -142,7 +142,7 @@ def get_autocast_dtype(x):
         else:
             return None
     elif x.device.type == "cpu":
-        if torch.is_autocast_enabled():
+        if torch.is_autocast_cpu_enabled():
             return torch.get_autocast_cpu_dtype()
         else:
             return None
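The hunk above switches the CPU branch to the CPU-specific autocast query, since torch.is_autocast_enabled() has historically reported the GPU autocast context rather than the CPU one. A minimal sketch of the same idea, using only public PyTorch calls and a hypothetical helper name (not the repository's exact function):

import torch

def autocast_dtype_for(x: torch.Tensor):
    # On CPU the autocast state must be read through the *_cpu_* APIs.
    if x.device.type == "cpu":
        if torch.is_autocast_cpu_enabled():
            return torch.get_autocast_cpu_dtype()
        return None
    if x.device.type == "cuda":
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
        return None
    return None

# Under CPU autocast the helper now reports the active dtype instead of None.
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    assert autocast_dtype_for(torch.ones(2, 2)) == torch.bfloat16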