diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 56a53a7e..3e793f4c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -133,6 +133,7 @@ class _BaseAutoModelClass:
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         mixed_precision = kwargs.pop('mixed_precision', False)
         quantization_group_size = kwargs.pop("quantization_group_size", 0)
+        mock_device = kwargs.pop('device', None)  # For mock on CPU
         invalidInputError(
             quantization_group_size in [0, 32, 64, 128],
             (
@@ -141,7 +142,6 @@ class _BaseAutoModelClass:
                 f"but got {quantization_group_size}"
             )
         )
-
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
 
@@ -165,45 +165,54 @@ class _BaseAutoModelClass:
         model.config.update({"bigdl_lcmu_enabled": False})
 
         logger.info(f"Converting model, it may takes up to several minutes ...")
-        from intel_npu_acceleration_library.compiler import create_npu_kernels
-        if optimize_model:
-            invalidInputError(
-                max_prompt_len < max_context_len,
-                (
-                    f"max_prompt_len ({max_prompt_len}) should be less"
-                    " than max_context_len ({max_context_len})"
-                ),
-            )
-            optimize_kwargs = {
-                "model": model,
-                "qtype": qtype,
-                "mixed_precision": mixed_precision,
-                "quantization_group_size": quantization_group_size,
-                "modules_to_not_convert": modules_to_not_convert,
-                "pipeline": pipeline,
-                "max_context_len": max_context_len,
-                "max_prompt_len": max_prompt_len,
-                "inter_pp": inter_pp,
-                "intra_pp": intra_pp,
-                "transpose_value_cache": transpose_value_cache
-            }
-            model = cls.optimize_npu_model(*args, **optimize_kwargs)
-        else:
-            from ipex_llm.transformers.npu_models.convert import optimize_llm
-            optimize_llm(model)
+        if mock_device == "cpu":
             with torch.no_grad():
-                cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
-                                 quantization_group_size, *args, **kwargs)
-                if hasattr(model, "llm"):
-                    create_npu_kernels(model.llm)
-                else:
-                    create_npu_kernels(model)
+                # Only mock quantization_group_size=0 for now
+                cls.load_convert_cpu(qtype, model, "cpu", modules_to_not_convert, 0,
+                                     *args, **kwargs)
             model = model.eval()
             logger.info(f"Finish to convert model")
 
-        model.config.update({"bigdl_transformers_low_bit": qtype})
-        # add save_low_bit to pretrained model dynamically
-        model.save_low_bit = types.MethodType(save_low_bit, model)
+        else:
+            from intel_npu_acceleration_library.compiler import create_npu_kernels
+
+            if optimize_model:
+                invalidInputError(
+                    max_prompt_len < max_context_len,
+                    (
+                        f"max_prompt_len ({max_prompt_len}) should be less"
+                        " than max_context_len ({max_context_len})"
+                    ),
+                )
+                optimize_kwargs = {
+                    "model": model,
+                    "qtype": qtype,
+                    "mixed_precision": mixed_precision,
+                    "quantization_group_size": quantization_group_size,
+                    "modules_to_not_convert": modules_to_not_convert,
+                    "pipeline": pipeline,
+                    "max_context_len": max_context_len,
+                    "max_prompt_len": max_prompt_len,
+                    "inter_pp": inter_pp,
+                    "intra_pp": intra_pp,
+                    "transpose_value_cache": transpose_value_cache
+                }
+                model = cls.optimize_npu_model(*args, **optimize_kwargs)
+            else:
+                from ipex_llm.transformers.npu_models.convert import optimize_llm
+                optimize_llm(model)
+                with torch.no_grad():
+                    cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
+                                     quantization_group_size, *args, **kwargs)
+                    if hasattr(model, "llm"):
+                        create_npu_kernels(model.llm)
+                    else:
+                        create_npu_kernels(model)
+                model = model.eval()
+                logger.info(f"Finish to convert model")
+        model.config.update({"bigdl_transformers_low_bit": qtype})
+        # add save_low_bit to pretrained model dynamically
+        model.save_low_bit = types.MethodType(save_low_bit, model)
 
         return model
 
@@ -274,6 +283,15 @@ class _BaseAutoModelClass:
                                      modules_to_not_convert=modules_to_not_convert,
                                      group_size=group_size)
 
+    @classmethod
+    def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert,
+                         group_size=0, *arg, **kwarg):
+        from ipex_llm.transformers.npu_models.convert import replace_with_DequantizedLinear
+
+        replace_with_DequantizedLinear(optimize_model, q_k, device=device,
+                                       modules_to_not_convert=modules_to_not_convert,
+                                       group_size=group_size)
+
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
     def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index a6b7a1cb..32f682ea 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -75,6 +75,19 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
                                group_size=group_size)
 
 
+@module_optimization
+def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
+                                   group_size):
+    from ipex_llm.transformers.npu_models.linear import DequantizedLinear
+    from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+    iqtype = ggml_tensor_qtype[qtype]
+    if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
+        qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
+                                             iqtype, device=device)
+        return DequantizedLinear(qweights, scale, layer.bias)
+
+
 def convert_forward(m, target_m, new_forward):
     if m.__class__ == target_m:
         bound_method = new_forward.__get__(m, m.__class__)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/linear.py b/python/llm/src/ipex_llm/transformers/npu_models/linear.py
index 9fb5d525..2c4b5f37 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/linear.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/linear.py
@@ -200,3 +200,77 @@ class QuantizedLinear(torch.nn.Module):
         if self.bias is None:
             return out
         return out + self.bias
+
+
+class DequantizedLinear(torch.nn.Module):
+    """Linear layer with weights dequantized to float, for mocking QuantizedLinear on CPU."""
+
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """Initialize the DequantizedLinear class.
+        Args:
+            weight (torch.Tensor): Linear operation quantized weight
+            scale (torch.Tensor): Quantization scale
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                                                     Defaults to None.
+        Raises:
+            RuntimeError: Quantized weight must be in torch.int8 or torch.uint8 format
+        """
+        super().__init__()
+
+        if weight.dtype not in (torch.int8, torch.uint8):
+            invalidInputError(
+                False,
+                (
+                    "Quantized weight must be in torch.(u)int8"
+                    f" dtype instead of {weight.dtype}"
+                )
+            )
+
+        if weight.dtype == torch.uint8:
+            weight = weight.view(torch.int8)
+            high_4bits = weight >> 4
+            low_4bits = (weight << 4) >> 4
+
+            combined_weight = torch.cat((low_4bits.unsqueeze(2), high_4bits.unsqueeze(2)), dim=2)
+            decompressed_weight = combined_weight.view(combined_weight.size(0), -1)
+            dequantized_weight = decompressed_weight.to(torch.float32) * \
+                torch.unsqueeze(scale.to(torch.float32), dim=1)
+            self.weight = Parameter(dequantized_weight, requires_grad=False).contiguous()
+        else:
+            dequantized_weight = weight.to(torch.float32) * \
+                torch.unsqueeze(scale.to(torch.float32), dim=1)
+            self.weight = Parameter(dequantized_weight.to(torch.float32),
+                                    requires_grad=False).contiguous()
+
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Torch module forward method.
+        Args:
+            x (torch.Tensor): Input tensor
+        Raises:
+            RuntimeError: Training is not supported for DequantizedLinear layer.
+                          Use `.eval()` to do inference only
+        Returns:
+            torch.Tensor: result
+        """
+
+        if self.training:
+            invalidInputError(
+                False,
+                (
+                    "Training is not supported for DequantizedLinear layer."
+                    " Use `.eval()` to do inference only"
+                )
+            )
+
+        out = torch.matmul(x.to(torch.float32), torch.transpose(self.weight.data, 0, 1))
+
+        if self.bias is None:
+            return out
+        return out + self.bias
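
For reference, the standalone sketch below (not part of the patch) walks through the 4-bit unpacking and dequantization that the uint8 branch of DequantizedLinear.__init__ performs. The packing order (low nibble first) and the one-scale-per-output-row layout are read off the code above and should be treated as assumptions; the tensor values are toy numbers chosen for illustration.

# Mirrors the uint8 branch of DequantizedLinear.__init__ on a 1x2 packed tensor.
import torch

packed = torch.tensor([[0x21, 0xF3]], dtype=torch.uint8)  # one output row, two bytes -> four int4 values
scale = torch.tensor([0.5])                               # assumed: one scale per output row

w = packed.view(torch.int8)                               # reinterpret the raw bytes as signed
high_4bits = w >> 4                                       # arithmetic shift keeps the sign of the high nibble
low_4bits = (w << 4) >> 4                                 # shift up then back down to sign-extend the low nibble
combined = torch.cat((low_4bits.unsqueeze(2), high_4bits.unsqueeze(2)), dim=2)
int4_values = combined.view(combined.size(0), -1)         # tensor([[ 1,  2,  3, -1]])
dequantized = int4_values.to(torch.float32) * scale.unsqueeze(1)
print(dequantized)                                        # tensor([[ 0.5000,  1.0000,  1.5000, -0.5000]])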
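A rough end-to-end sketch of how the new mock path might be exercised follows. Only the device="cpu" keyword comes from this patch; the AutoModelForCausalLM entry point in ipex_llm.transformers.npu_model, the load_in_low_bit value, and the checkpoint/prompt are assumptions about the surrounding library, not something the diff shows.

# Hypothetical usage of the mock-on-CPU path; paths and generation settings are placeholders.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="sym_int4",   # qtype whose dequantized weights are mocked
    optimize_model=False,
    device="cpu",                 # new kwarg from this patch: dequantize on CPU, skip NPU kernels
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))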