[NPU] lm head to cpu (#11943)

* lm head to cpu

* qwen2

* mv logic and add param to disable cpu_lm_head

* use env and lm_head opt to mp file

* fix

* update

* remove print
Yina Chen 2024-08-28 11:34:07 +03:00 committed by GitHub
parent ec67ee7177
commit b38fb67bec
3 changed files with 31 additions and 3 deletions


@@ -153,7 +153,7 @@ class _BaseAutoModelClass:
             from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_pre
             with torch.no_grad():
-                optimize_llm_pre(model)
+                optimize_llm_pre(model, qtype)
                 cls.load_convert(qtype, model, "cpu", *args, **kwargs)
                 create_npu_kernels(model)
             model = model.eval()
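For orientation, a minimal sketch of the NPU loading path that reaches this call site. The module path, checkpoint name, and load_in_low_bit value are assumptions for illustration, not taken from the diff; the diff itself only shows the internal qtype string "sym_int4_rtn".

# Hedged usage sketch (assumed API surface; only the "sym_int4_rtn"
# qtype string is grounded in the diff above).
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",     # hypothetical checkpoint
    load_in_low_bit="sym_int4",   # assumed to map to "sym_int4_rtn" internally
    trust_remote_code=True,
)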


@@ -56,7 +56,7 @@ def replace_with_QuantizedLinear(layer, qtype, device):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
-    if isinstance(layer, torch.nn.Linear):
+    if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
         if qtype == "sym_int4_rtn":
             # workaround for qwen2 & int4
             if (layer.in_features == 3584 and layer.out_features == 152064) or \

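For context on the new guard: by the time replace_with_QuantizedLinear walks the model, the lm_head has already been swapped for a LowBitLinear in optimize_llm_pre (third file below), and that replacement carries a qtype attribute. The hasattr check therefore skips already-converted modules instead of quantizing them twice. A self-contained sketch of the pattern, with a hypothetical helper name:

import torch

# Idempotency guard: a plain nn.Linear without a `qtype` attribute still
# needs conversion; an already-converted low-bit layer is skipped.
def needs_conversion(layer: torch.nn.Module) -> bool:  # hypothetical name
    return isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype")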

@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import importlib
+from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params


 def convert_forward(m, target_m, new_forward):
@@ -25,7 +27,7 @@ def convert_forward(m, target_m, new_forward):
             convert_forward(sub_m, target_m, new_forward)


-def optimize_llm_pre(model: torch.nn.Module):
+def optimize_llm_pre(model: torch.nn.Module, qtype):
     if model.config.model_type == "baichuan":
         # process NormHead module in Baichuan2 7B
         if hasattr(model, 'lm_head') and model.lm_head is not None:
@@ -40,6 +42,32 @@ def optimize_llm_pre(model: torch.nn.Module):
         from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)

+    # lm_head to cpu optimization
+    if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "1") != "0":
+        is_unsupported_model = (model.config.model_type == "llama"
+                                and model.vocab_size > 32000)
+        if not is_unsupported_model:
+            from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
+            if qtype == "sym_int4_rtn":
+                lm_qtype = SYM_INT4
+            else:
+                lm_qtype = SYM_INT8
+            # lm_head opt to mp opt (llama, qwen2)
+            optimize_lm_head = model.config.model_type not in ["llama", "qwen2"]
+            new_linear = LowBitLinear(model.lm_head.in_features,
+                                      model.lm_head.out_features,
+                                      lm_qtype,
+                                      False,
+                                      optimize_lm_head=optimize_lm_head)
+            paramsLowBit = FP4Params(data=model.lm_head.weight.data,
+                                     requires_grad=False,
+                                     quantized=False,
+                                     _shape=None,
+                                     qtype=lm_qtype,
+                                     in_features=model.lm_head.in_features).to("cpu")
+            new_linear._parameters['weight'] = paramsLowBit
+            model.lm_head = new_linear


 def optimize_llm(
     model: torch.nn.Module,
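The whole new block is gated on the IPEX_LLM_CPU_LM_HEAD environment variable and is enabled by default. A minimal sketch of opting out; the variable name and the "0" sentinel come straight from the diff, the rest is illustrative:

import os

# Must be set before the model is loaded so optimize_llm_pre sees it;
# any value other than "0" (including leaving it unset) keeps the
# CPU lm_head optimization enabled.
os.environ["IPEX_LLM_CPU_LM_HEAD"] = "0"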