LLM: Fix CPU qlora dtype convert issue (#9394)
This commit is contained in:
parent
34449cb4bb
commit
40cead6b5b
2 changed files with 1 addition and 3 deletions
|
|
@@ -368,8 +368,6 @@ class MatMulLowBitCPU(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, A, weight):
|
def forward(ctx, A, weight):
|
||||||
if torch.is_autocast_enabled():
|
|
||||||
A = A.to(torch.get_autocast_dtype())
|
|
||||||
ctx.is_empty = False
|
ctx.is_empty = False
|
||||||
x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
|
x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
|
||||||
weight._shape[0] * weight._shape[1])
|
weight._shape[0] * weight._shape[1])
|
||||||
|
|
|
||||||
|
|
@@ -142,7 +142,7 @@ def get_autocast_dtype(x):
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
elif x.device.type == "cpu":
|
elif x.device.type == "cpu":
|
||||||
if torch.is_autocast_enabled():
|
if torch.is_autocast_cpu_enabled():
|
||||||
return torch.get_autocast_cpu_dtype()
|
return torch.get_autocast_cpu_dtype()
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue