Initial support for quantized forward on CPU when quantization_group_size=0 (#12282)
* Initial support for quantized forward on CPU when quantization_group_size=0 * Style fix * Style fix * Small fix * Small fix
This commit is contained in:
parent
3feb58d1e4
commit
5a15098835
3 changed files with 141 additions and 36 deletions
|
|
@ -133,6 +133,7 @@ class _BaseAutoModelClass:
|
|||
modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
|
||||
mixed_precision = kwargs.pop('mixed_precision', False)
|
||||
quantization_group_size = kwargs.pop("quantization_group_size", 0)
|
||||
mock_device = kwargs.pop('device', None) # For mock on CPU
|
||||
|
||||
invalidInputError(
|
||||
quantization_group_size in [0, 32, 64, 128],
|
||||
|
|
@ -141,7 +142,6 @@ class _BaseAutoModelClass:
|
|||
f"but got {quantization_group_size}"
|
||||
)
|
||||
)
|
||||
|
||||
_args = copy.deepcopy(args)
|
||||
_kwargs = copy.deepcopy(kwargs)
|
||||
|
||||
|
|
@ -165,6 +165,15 @@ class _BaseAutoModelClass:
|
|||
model.config.update({"bigdl_lcmu_enabled": False})
|
||||
|
||||
logger.info(f"Converting model, it may takes up to several minutes ...")
|
||||
|
||||
if mock_device == "cpu":
|
||||
with torch.no_grad():
|
||||
# Only mock quantization_group_size=0 for now
|
||||
cls.load_convert_cpu(qtype, model, "cpu", modules_to_not_convert, 0,
|
||||
*args, **kwargs)
|
||||
model = model.eval()
|
||||
logger.info(f"Finish to convert model")
|
||||
else:
|
||||
from intel_npu_acceleration_library.compiler import create_npu_kernels
|
||||
|
||||
if optimize_model:
|
||||
|
|
@ -274,6 +283,15 @@ class _BaseAutoModelClass:
|
|||
modules_to_not_convert=modules_to_not_convert,
|
||||
group_size=group_size)
|
||||
|
||||
@classmethod
|
||||
def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert,
|
||||
group_size=0, *arg, **kwarg):
|
||||
from ipex_llm.transformers.npu_models.convert import replace_with_DequantizedLinear
|
||||
|
||||
replace_with_DequantizedLinear(optimize_model, q_k, device=device,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
group_size=group_size)
|
||||
|
||||
@classmethod
|
||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||
def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -75,6 +75,19 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
|
|||
group_size=group_size)
|
||||
|
||||
|
||||
@module_optimization
|
||||
def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
|
||||
group_size):
|
||||
from ipex_llm.transformers.npu_models.linear import DequantizedLinear
|
||||
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
|
||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||
iqtype = ggml_tensor_qtype[qtype]
|
||||
if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
|
||||
qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
|
||||
iqtype, device=device)
|
||||
return DequantizedLinear(qweights, scale, layer.bias)
|
||||
|
||||
|
||||
def convert_forward(m, target_m, new_forward):
|
||||
if m.__class__ == target_m:
|
||||
bound_method = new_forward.__get__(m, m.__class__)
|
||||
|
|
|
|||
|
|
@ -200,3 +200,77 @@ class QuantizedLinear(torch.nn.Module):
|
|||
if self.bias is None:
|
||||
return out
|
||||
return out + self.bias
|
||||
|
||||
|
||||
class DequantizedLinear(torch.nn.Module):
|
||||
"""Torch Quantized Linear operation NPU backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Initialize the DequantizedLinear class.
|
||||
Args:
|
||||
weight (torch.Tensor): Linear operation quantized weight
|
||||
scale (torch.Tensor): Quantization scale
|
||||
bias (Optional[torch.Tensor], optional): Linear operation optional bias.
|
||||
Defaults to None.
|
||||
Raises:
|
||||
RuntimeError: Quantized weight must be in torch.int8 format
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if weight.dtype not in (torch.int8, torch.uint8):
|
||||
invalidInputError(
|
||||
False,
|
||||
(
|
||||
f"Quantized weight must be in torch.(u)int8"
|
||||
" dtype instead of {self.weight.dtype}"
|
||||
)
|
||||
)
|
||||
|
||||
if weight.dtype == torch.uint8:
|
||||
weight = weight.view(torch.int8)
|
||||
high_4bits = weight >> 4
|
||||
low_4bits = (weight << 4) >> 4
|
||||
|
||||
combined_weight = torch.cat((low_4bits.unsqueeze(2), high_4bits.unsqueeze(2)), dim=2)
|
||||
decompressed_weight = combined_weight.view(combined_weight.size(0), -1)
|
||||
dequantized_weight = decompressed_weight.to(torch.float32) * \
|
||||
torch.unsqueeze(scale.to(torch.float32), dim=1)
|
||||
self.weight = Parameter(dequantized_weight, requires_grad=False).contiguous()
|
||||
else:
|
||||
dequantized_weight = weight.to(torch.float32) * \
|
||||
torch.unsqueeze(scale.to(torch.float32), dim=1)
|
||||
self.weight = Parameter(dequantized_weight.to(torch.float32),
|
||||
requires_grad=False).contiguous()
|
||||
|
||||
self.bias = bias
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Torch module forward method.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor
|
||||
Raises:
|
||||
RuntimeError: Training is not supported for DequantizedLinear layer.
|
||||
Use `.eval()` to do inference only
|
||||
Returns:
|
||||
torch.Tensor: result
|
||||
"""
|
||||
|
||||
if self.training:
|
||||
invalidInputError(
|
||||
False,
|
||||
(
|
||||
"Training is not supported for DequantizedLinear layer."
|
||||
"Use `.eval()` to do inference only"
|
||||
)
|
||||
)
|
||||
|
||||
out = torch.matmul(x.to(torch.float32), torch.transpose(self.weight.data, 0, 1))
|
||||
|
||||
if self.bias is None:
|
||||
return out
|
||||
return out + self.bias
|
||||
|
|
|
|||
Loading…
Reference in a new issue