[NPU] Add env to enable scale search (#12462)

* add env var to enable scale search

* address review comment

* move logic
Author: Yina Chen (committed by GitHub)
Date:   2024-11-28 11:06:00 +02:00
Commit: 1b533a105c
Parent: d272f6b471

@@ -14,6 +14,7 @@
 # limitations under the License.
+import os
 import torch
 import importlib
 from ipex_llm.transformers.npu_models.linear import QuantizedLinear
@@ -69,8 +70,10 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
                 (layer.in_features == 18944 and layer.out_features == 3584):
             qtype = "sym_int8_rtn"
         iqtype = ggml_tensor_qtype[qtype]
+        enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
         qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
-                                             iqtype, device=device)
+                                             iqtype, device=device,
+                                             enable_scale_search=enable_scale_search)
         return QuantizedLinear(qweights, scale, layer.bias,
                                group_size=group_size)
@@ -83,8 +86,10 @@ def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
     if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
+        enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
         qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
-                                             iqtype, device=device)
+                                             iqtype, device=device,
+                                             enable_scale_search=enable_scale_search)
         return DequantizedLinear(qweights, scale, layer.bias)
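
For context, the new flag follows a common env-gating pattern: IPEX_LLM_NPU_QUANTIZATION_OPT defaults to "0" (off), and any other value turns scale search on at conversion time, e.g. IPEX_LLM_NPU_QUANTIZATION_OPT=1 python your_script.py. Below is a minimal, self-contained sketch of that pattern; the quantize_weight helper and its toy scale search are illustrative assumptions, not ggml_convert_qtype itself — only the environment-variable check mirrors the diff.

import os
import torch

def _scale_search_enabled() -> bool:
    # Mirrors the diff: unset defaults to "0" (disabled); any other value enables it.
    return os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"

def quantize_weight(weight: torch.Tensor):
    # Hypothetical stand-in for ggml_convert_qtype: per-row symmetric int8
    # quantization, optionally refining the scale with a tiny grid search.
    w = weight.to(torch.float32)
    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
    if _scale_search_enabled():
        # Toy "scale search": try a few multipliers, keep the lowest-error scale.
        best_err, best_scale = None, scale
        for m in (0.9, 1.0, 1.1):
            s = scale * m
            q = torch.clamp(torch.round(w / s), -127, 127)
            err = ((q * s - w) ** 2).mean().item()
            if best_err is None or err < best_err:
                best_err, best_scale = err, s
        scale = best_scale
    qweights = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return qweights, scale

if __name__ == "__main__":
    os.environ["IPEX_LLM_NPU_QUANTIZATION_OPT"] = "1"   # enable the search
    q, s = quantize_weight(torch.randn(4, 8))
    print(q.dtype, s.squeeze())

Gating the search behind an environment variable lets users opt into the extra quantization work without any API or call-site change, which is why the diff only threads a boolean through to ggml_convert_qtype.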