Fix qwen2 & int4 on NPU (#11646)

binbin Deng 2024-07-24 13:14:39 +08:00 committed by GitHub
parent 1b3b46e54d
commit 777e61d8c8


@@ -57,6 +57,12 @@ def replace_with_QuantizedLinear(layer, qtype, device):
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
     if isinstance(layer, torch.nn.Linear):
+        if qtype == "sym_int4_rtn":
+            # workaround for qwen2 & int4
+            if (layer.in_features == 3584 and layer.out_features == 152064) or \
+                    (layer.in_features == 18944 and layer.out_features == 3584):
+                qtype = "sym_int8_rtn"
+                iqtype = ggml_tensor_qtype[qtype]
         qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, device=device)
         return QuantizedLinear(qweights, scale, layer.bias)
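For context, a minimal standalone sketch of the shape-based fallback this diff adds: when sym_int4_rtn is requested for a Linear layer whose shape is 3584x152064 or 18944x3584 (these match Qwen2-7B's lm_head and MLP down_proj), the layer is quantized as sym_int8_rtn instead. The names pick_qtype and QWEN2_INT8_FALLBACK_SHAPES are hypothetical, introduced only for this sketch; the real change inlines the check in replace_with_QuantizedLinear as shown above.

from types import SimpleNamespace

# (in_features, out_features) pairs that trip the workaround in this commit.
QWEN2_INT8_FALLBACK_SHAPES = {(3584, 152064), (18944, 3584)}

def pick_qtype(layer, qtype: str) -> str:
    """Mirror the commit's fallback: demote sym_int4_rtn to sym_int8_rtn
    for the two Qwen2-7B layer shapes that misbehave with int4 on NPU."""
    if qtype == "sym_int4_rtn" and \
            (layer.in_features, layer.out_features) in QWEN2_INT8_FALLBACK_SHAPES:
        return "sym_int8_rtn"
    return qtype

# Lightweight stand-ins so the sketch runs without allocating the real
# (very large) nn.Linear weights; only the shape attributes matter here.
lm_head = SimpleNamespace(in_features=3584, out_features=152064)
down_proj = SimpleNamespace(in_features=18944, out_features=3584)
q_proj = SimpleNamespace(in_features=3584, out_features=3584)

assert pick_qtype(lm_head, "sym_int4_rtn") == "sym_int8_rtn"
assert pick_qtype(down_proj, "sym_int4_rtn") == "sym_int8_rtn"
assert pick_qtype(q_proj, "sym_int4_rtn") == "sym_int4_rtn"
assert pick_qtype(lm_head, "sym_int8_rtn") == "sym_int8_rtn"  # non-int4 requests pass through

Note that the fallback only fires on an exact shape match, so every other layer of the model still gets the originally requested sym_int4_rtn quantization.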