Initial support for quantized forward on CPU when quantization_group_size=0 (#12282)

* Initial support for quantized forward on CPU when quantization_group_size=0

* Style fix

* Style fix

* Small fix

* Small fix
Yuwen Hu 2024-10-29 19:40:17 +08:00 committed by GitHub
parent 3feb58d1e4
commit 5a15098835
3 changed files with 141 additions and 36 deletions
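
Not part of the commit itself: a minimal usage sketch of the new mock-on-CPU path, assuming the usual NPU entry point ipex_llm.transformers.npu_model.AutoModelForCausalLM and the load_in_low_bit keyword; the model id and generation settings are illustrative. Passing device="cpu" is what routes from_pretrained into the new branch added below.

    # Hypothetical usage sketch of the CPU mock path (quantization_group_size=0 only).
    # Model id and load_in_low_bit value are assumptions for illustration.
    import torch
    from transformers import AutoTokenizer
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # assumed model id
        load_in_low_bit="sym_int4",        # weights quantized to 4 bit, then dequantized on CPU
        quantization_group_size=0,         # only group_size=0 is mocked for now (see diff)
        device="cpu",                      # triggers the new mock-on-CPU branch
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                              trust_remote_code=True)
    inputs = tokenizer("What is AI?", return_tensors="pt")
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))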

python/llm/src/ipex_llm/transformers/npu_model.py

@@ -133,6 +133,7 @@ class _BaseAutoModelClass:
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         mixed_precision = kwargs.pop('mixed_precision', False)
         quantization_group_size = kwargs.pop("quantization_group_size", 0)
+        mock_device = kwargs.pop('device', None)  # For mock on CPU
 
         invalidInputError(
             quantization_group_size in [0, 32, 64, 128],
@@ -141,7 +142,6 @@ class _BaseAutoModelClass:
                 f"but got {quantization_group_size}"
             )
         )
-
 
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
@@ -165,45 +165,54 @@ class _BaseAutoModelClass:
         model.config.update({"bigdl_lcmu_enabled": False})
 
         logger.info(f"Converting model, it may takes up to several minutes ...")
-        from intel_npu_acceleration_library.compiler import create_npu_kernels
-        if optimize_model:
-            invalidInputError(
-                max_prompt_len < max_context_len,
-                (
-                    f"max_prompt_len ({max_prompt_len}) should be less"
-                    " than max_context_len ({max_context_len})"
-                ),
-            )
-            optimize_kwargs = {
-                "model": model,
-                "qtype": qtype,
-                "mixed_precision": mixed_precision,
-                "quantization_group_size": quantization_group_size,
-                "modules_to_not_convert": modules_to_not_convert,
-                "pipeline": pipeline,
-                "max_context_len": max_context_len,
-                "max_prompt_len": max_prompt_len,
-                "inter_pp": inter_pp,
-                "intra_pp": intra_pp,
-                "transpose_value_cache": transpose_value_cache
-            }
-            model = cls.optimize_npu_model(*args, **optimize_kwargs)
-        else:
-            from ipex_llm.transformers.npu_models.convert import optimize_llm
-            optimize_llm(model)
-            with torch.no_grad():
-                cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
-                                 quantization_group_size, *args, **kwargs)
-                if hasattr(model, "llm"):
-                    create_npu_kernels(model.llm)
-                else:
-                    create_npu_kernels(model)
-            model = model.eval()
-
-        logger.info(f"Finish to convert model")
-        model.config.update({"bigdl_transformers_low_bit": qtype})
-
-        # add save_low_bit to pretrained model dynamically
-        model.save_low_bit = types.MethodType(save_low_bit, model)
+        if mock_device == "cpu":
+            with torch.no_grad():
+                # Only mock quantization_group_size=0 for now
+                cls.load_convert_cpu(qtype, model, "cpu", modules_to_not_convert, 0,
+                                     *args, **kwargs)
+            model = model.eval()
+            logger.info(f"Finish to convert model")
+        else:
+            from intel_npu_acceleration_library.compiler import create_npu_kernels
+            if optimize_model:
+                invalidInputError(
+                    max_prompt_len < max_context_len,
+                    (
+                        f"max_prompt_len ({max_prompt_len}) should be less"
+                        " than max_context_len ({max_context_len})"
+                    ),
+                )
+                optimize_kwargs = {
+                    "model": model,
+                    "qtype": qtype,
+                    "mixed_precision": mixed_precision,
+                    "quantization_group_size": quantization_group_size,
+                    "modules_to_not_convert": modules_to_not_convert,
+                    "pipeline": pipeline,
+                    "max_context_len": max_context_len,
+                    "max_prompt_len": max_prompt_len,
+                    "inter_pp": inter_pp,
+                    "intra_pp": intra_pp,
+                    "transpose_value_cache": transpose_value_cache
+                }
+                model = cls.optimize_npu_model(*args, **optimize_kwargs)
+            else:
+                from ipex_llm.transformers.npu_models.convert import optimize_llm
+                optimize_llm(model)
+                with torch.no_grad():
+                    cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
+                                     quantization_group_size, *args, **kwargs)
+                    if hasattr(model, "llm"):
+                        create_npu_kernels(model.llm)
+                    else:
+                        create_npu_kernels(model)
+                model = model.eval()
+
+            logger.info(f"Finish to convert model")
+            model.config.update({"bigdl_transformers_low_bit": qtype})
+
+            # add save_low_bit to pretrained model dynamically
+            model.save_low_bit = types.MethodType(save_low_bit, model)
 
         return model
@@ -274,6 +283,15 @@ class _BaseAutoModelClass:
                                      modules_to_not_convert=modules_to_not_convert,
                                      group_size=group_size)
 
+    @classmethod
+    def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert,
+                         group_size=0, *arg, **kwarg):
+        from ipex_llm.transformers.npu_models.convert import replace_with_DequantizedLinear
+        replace_with_DequantizedLinear(optimize_model, q_k, device=device,
+                                       modules_to_not_convert=modules_to_not_convert,
+                                       group_size=group_size)
+
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
     def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):

python/llm/src/ipex_llm/transformers/npu_models/convert.py

@@ -75,6 +75,19 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
                                group_size=group_size)
 
 
+@module_optimization
+def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
+                                   group_size):
+    from ipex_llm.transformers.npu_models.linear import DequantizedLinear
+    from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+    iqtype = ggml_tensor_qtype[qtype]
+    if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
+        qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
+                                             iqtype, device=device)
+        return DequantizedLinear(qweights, scale, layer.bias)
+
+
 def convert_forward(m, target_m, new_forward):
     if m.__class__ == target_m:
         bound_method = new_forward.__get__(m, m.__class__)
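
A note on @module_optimization, which is used above but not shown in this diff: helpers of this kind normally walk the module tree and swap each eligible torch.nn.Linear for whatever the decorated function returns. The sketch below is a simplified, hypothetical stand-in (the real decorator in ipex_llm also handles modules_to_not_convert matching and keyword forwarding); _replace_linears is an illustrative name, not an ipex_llm API.

    # Simplified, hypothetical stand-in for an @module_optimization-style pass:
    # recursively visit children and swap nn.Linear layers for whatever the
    # conversion function returns (e.g. a DequantizedLinear).
    import torch


    def _replace_linears(model, convert_fn, modules_to_not_convert=()):
        for name, child in model.named_children():
            if name in modules_to_not_convert:
                continue
            new_child = convert_fn(child) if isinstance(child, torch.nn.Linear) else None
            if new_child is not None:
                setattr(model, name, new_child)   # swap the layer in place
            else:
                _replace_linears(child, convert_fn, modules_to_not_convert)
        return model

In this commit, replace_with_DequantizedLinear plays the role of convert_fn, returning a DequantizedLinear built from the ggml-quantized weight and scale.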

python/llm/src/ipex_llm/transformers/npu_models/linear.py

@@ -200,3 +200,77 @@ class QuantizedLinear(torch.nn.Module):
         if self.bias is None:
             return out
         return out + self.bias
+
+
+class DequantizedLinear(torch.nn.Module):
+    """Torch Quantized Linear operation NPU backend."""
+
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """Initialize the DequantizedLinear class.
+
+        Args:
+            weight (torch.Tensor): Linear operation quantized weight
+            scale (torch.Tensor): Quantization scale
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                                                     Defaults to None.
+
+        Raises:
+            RuntimeError: Quantized weight must be in torch.int8 format
+        """
+        super().__init__()
+        if weight.dtype not in (torch.int8, torch.uint8):
+            invalidInputError(
+                False,
+                (
+                    f"Quantized weight must be in torch.(u)int8"
+                    " dtype instead of {self.weight.dtype}"
+                )
+            )
+        if weight.dtype == torch.uint8:
+            weight = weight.view(torch.int8)
+            high_4bits = weight >> 4
+            low_4bits = (weight << 4) >> 4
+            combined_weight = torch.cat((low_4bits.unsqueeze(2), high_4bits.unsqueeze(2)), dim=2)
+            decompressed_weight = combined_weight.view(combined_weight.size(0), -1)
+            dequantized_weight = decompressed_weight.to(torch.float32) * \
+                torch.unsqueeze(scale.to(torch.float32), dim=1)
+            self.weight = Parameter(dequantized_weight, requires_grad=False).contiguous()
+        else:
+            dequantized_weight = weight.to(torch.float32) * \
+                torch.unsqueeze(scale.to(torch.float32), dim=1)
+            self.weight = Parameter(dequantized_weight.to(torch.float32),
+                                    requires_grad=False).contiguous()
+
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Torch module forward method.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Raises:
+            RuntimeError: Training is not supported for DequantizedLinear layer.
+                          Use `.eval()` to do inference only
+
+        Returns:
+            torch.Tensor: result
+        """
+        if self.training:
+            invalidInputError(
+                False,
+                (
+                    "Training is not supported for DequantizedLinear layer."
+                    "Use `.eval()` to do inference only"
+                )
+            )
+        out = torch.matmul(x.to(torch.float32), torch.transpose(self.weight.data, 0, 1))
+
+        if self.bias is None:
+            return out
+        return out + self.bias
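
To make the bit manipulation in the torch.uint8 branch of DequantizedLinear.__init__ concrete: each packed byte holds two signed 4-bit values, low nibble first, which are sign-extended, interleaved, and rescaled per output row. A small self-contained sketch with made-up packed bytes and scale (not part of the commit):

    # Demonstration of the nibble unpacking used in DequantizedLinear (values are made up).
    import torch

    packed = torch.tensor([[0x2F, 0x81]], dtype=torch.uint8)  # one output row, two packed bytes
    weight = packed.view(torch.int8)                          # reinterpret the raw bits as int8

    high_4bits = weight >> 4                                  # arithmetic shift keeps the sign
    low_4bits = (weight << 4) >> 4                            # isolate and sign-extend the low nibble

    # Interleave (low, high) per byte and flatten each row: 2 bytes -> 4 signed 4-bit values
    combined = torch.cat((low_4bits.unsqueeze(2), high_4bits.unsqueeze(2)), dim=2)
    unpacked = combined.view(combined.size(0), -1)            # -> [[-1, 2, 1, -8]] (int8)

    # One scale per output row, broadcast across that row, as in the commit
    scale = torch.tensor([0.5])
    dequantized = unpacked.to(torch.float32) * scale.unsqueeze(1)
    # -> [[-0.5, 1.0, 0.5, -4.0]]

    # The forward pass is then a plain fp32 matmul against the dequantized weight
    x = torch.ones(1, 4)
    out = torch.matmul(x, dequantized.t())                    # -> [[-3.0]]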