Initial support for quantized forward on CPU when quantization_group_size=0 (#12282)
* Initial support for quantized forward on CPU when quantization_group_size=0
* Style fix
* Style fix
* Small fix
* Small fix
parent 3feb58d1e4
commit 5a15098835
3 changed files with 141 additions and 36 deletions
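
For reference, a minimal sketch of how the new CPU mock path might be invoked. Only the device and quantization_group_size keyword handling is taken from this diff; the entry point, model path, and low-bit type below are assumptions for illustration.

    # Sketch only: the AutoModelForCausalLM entry point, model path and
    # load_in_low_bit value are assumed, not confirmed by this commit.
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/model",             # hypothetical model path
        load_in_low_bit="sym_int4",  # assumed low-bit qtype
        device="cpu",                # popped as mock_device; "cpu" selects load_convert_cpu
        quantization_group_size=0,   # only group size 0 is mocked on CPU for now
        trust_remote_code=True,
    )
    model = model.eval()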
@@ -133,6 +133,7 @@ class _BaseAutoModelClass:
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         mixed_precision = kwargs.pop('mixed_precision', False)
         quantization_group_size = kwargs.pop("quantization_group_size", 0)
+        mock_device = kwargs.pop('device', None)  # For mock on CPU
         invalidInputError(
             quantization_group_size in [0, 32, 64, 128],

@@ -141,7 +142,6 @@ class _BaseAutoModelClass:
                 f"but got {quantization_group_size}"
             )
         )

         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)

@@ -165,6 +165,15 @@ class _BaseAutoModelClass:
         model.config.update({"bigdl_lcmu_enabled": False})

         logger.info(f"Converting model, it may takes up to several minutes ...")
+        if mock_device == "cpu":
+            with torch.no_grad():
+                # Only mock quantization_group_size=0 for now
+                cls.load_convert_cpu(qtype, model, "cpu", modules_to_not_convert, 0,
+                                     *args, **kwargs)
+                model = model.eval()
+            logger.info(f"Finish to convert model")
+        else:
             from intel_npu_acceleration_library.compiler import create_npu_kernels
             if optimize_model:

@@ -274,6 +283,15 @@ class _BaseAutoModelClass:
                                        modules_to_not_convert=modules_to_not_convert,
                                        group_size=group_size)

+    @classmethod
+    def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert,
+                         group_size=0, *arg, **kwarg):
+        from ipex_llm.transformers.npu_models.convert import replace_with_DequantizedLinear
+
+        replace_with_DequantizedLinear(optimize_model, q_k, device=device,
+                                       modules_to_not_convert=modules_to_not_convert,
+                                       group_size=group_size)
+
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
     def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):

@@ -75,6 +75,19 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
                                      group_size=group_size)


+@module_optimization
+def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
+                                   group_size):
+    from ipex_llm.transformers.npu_models.linear import DequantizedLinear
+    from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+    iqtype = ggml_tensor_qtype[qtype]
+    if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
+        qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
+                                             iqtype, device=device)
+        return DequantizedLinear(qweights, scale, layer.bias)
+
+
 def convert_forward(m, target_m, new_forward):
     if m.__class__ == target_m:
         bound_method = new_forward.__get__(m, m.__class__)
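
The @module_optimization decorator applied above is defined elsewhere in convert.py and is not part of this diff; it is assumed to lift the per-layer function into a recursive pass over a whole model. A rough, hypothetical sketch of that kind of traversal (the helper name and signature are made up):

    import torch

    def _replace_modules(model: torch.nn.Module, replace_fn, modules_to_not_convert=()):
        # Hypothetical stand-in for what @module_optimization is assumed to do:
        # visit each child module, swap it when replace_fn returns a new module,
        # otherwise recurse into the child.
        for name, child in model.named_children():
            if name in modules_to_not_convert:
                continue
            new_child = replace_fn(child)
            if new_child is not None:
                setattr(model, name, new_child)
            else:
                _replace_modules(child, replace_fn, modules_to_not_convert)
        return model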

@@ -200,3 +200,77 @@ class QuantizedLinear(torch.nn.Module):
         if self.bias is None:
             return out
         return out + self.bias
+
+
+class DequantizedLinear(torch.nn.Module):
+    """Torch Quantized Linear operation NPU backend."""
+
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """Initialize the DequantizedLinear class.
+
+        Args:
+            weight (torch.Tensor): Linear operation quantized weight
+            scale (torch.Tensor): Quantization scale
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                                                     Defaults to None.
+
+        Raises:
+            RuntimeError: Quantized weight must be in torch.int8 format
+        """
+        super().__init__()
+
+        if weight.dtype not in (torch.int8, torch.uint8):
+            invalidInputError(
+                False,
+                (
+                    "Quantized weight must be in torch.(u)int8"
+                    f" dtype instead of {weight.dtype}"
+                )
+            )
+
+        if weight.dtype == torch.uint8:
+            weight = weight.view(torch.int8)
+            high_4bits = weight >> 4
+            low_4bits = (weight << 4) >> 4
+
+            combined_weight = torch.cat((low_4bits.unsqueeze(2), high_4bits.unsqueeze(2)), dim=2)
+            decompressed_weight = combined_weight.view(combined_weight.size(0), -1)
+            dequantized_weight = decompressed_weight.to(torch.float32) * \
+                torch.unsqueeze(scale.to(torch.float32), dim=1)
+            self.weight = Parameter(dequantized_weight, requires_grad=False).contiguous()
+        else:
+            dequantized_weight = weight.to(torch.float32) * \
+                torch.unsqueeze(scale.to(torch.float32), dim=1)
+            self.weight = Parameter(dequantized_weight.to(torch.float32),
+                                    requires_grad=False).contiguous()
+
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Torch module forward method.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Raises:
+            RuntimeError: Training is not supported for DequantizedLinear layer.
+                          Use `.eval()` to do inference only
+
+        Returns:
+            torch.Tensor: result
+        """
+        if self.training:
+            invalidInputError(
+                False,
+                (
+                    "Training is not supported for DequantizedLinear layer. "
+                    "Use `.eval()` to do inference only"
+                )
+            )
+
+        out = torch.matmul(x.to(torch.float32), torch.transpose(self.weight.data, 0, 1))
+
+        if self.bias is None:
+            return out
+        return out + self.bias
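
In the torch.uint8 branch of DequantizedLinear.__init__ above, each byte packs two signed 4-bit values, low nibble first, and the arithmetic shifts sign-extend them before scaling. A standalone sketch of that unpacking, with made-up values:

    import torch

    packed = torch.tensor([[0x2F, 0x81]], dtype=torch.uint8)  # one row, two packed bytes
    w = packed.view(torch.int8)

    high = w >> 4        # arithmetic shift keeps the sign of the high nibble
    low = (w << 4) >> 4  # shift up and back down to sign-extend the low nibble

    unpacked = torch.cat((low.unsqueeze(2), high.unsqueeze(2)), dim=2).view(w.size(0), -1)
    print(unpacked)      # tensor([[-1,  2,  1, -8]], dtype=torch.int8)

Each unpacked row is then multiplied by its broadcast scale to recover the float32 weight used by the matmul in forward.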