[NPU] lm head to cpu (#11943)
* lm head to cpu
* qwen2
* mv logic and add param to disable cpu_lm_head
* use env and lm_head opt to mp file
* fix
* update
* remove print
parent ec67ee7177
commit b38fb67bec
3 changed files with 31 additions and 3 deletions
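The lm_head-to-CPU conversion is gated on an environment variable: the diff below checks os.environ.get("IPEX_LLM_CPU_LM_HEAD", "1") != "0", so the optimization is on by default. A minimal usage sketch of opting out, assuming the variable is set in the process environment before the model is loaded:

import os

# Switch checked in optimize_llm_pre: any value other than "0" (including
# leaving it unset) keeps the lm_head-to-CPU optimization enabled.
os.environ["IPEX_LLM_CPU_LM_HEAD"] = "0"  # disable the lm_head-to-CPU path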
@@ -153,7 +153,7 @@ class _BaseAutoModelClass:
         from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_pre

         with torch.no_grad():
-            optimize_llm_pre(model)
+            optimize_llm_pre(model, qtype)
             cls.load_convert(qtype, model, "cpu", *args, **kwargs)
             create_npu_kernels(model)
         model = model.eval()

@@ -56,7 +56,7 @@ def replace_with_QuantizedLinear(layer, qtype, device):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
-    if isinstance(layer, torch.nn.Linear):
+    if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
         if qtype == "sym_int4_rtn":
             # workaround for qwen2 & int4
             if (layer.in_features == 3584 and layer.out_features == 152064) or \

@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import torch
 import importlib
+from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params


 def convert_forward(m, target_m, new_forward):

@@ -25,7 +27,7 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)


-def optimize_llm_pre(model: torch.nn.Module):
+def optimize_llm_pre(model: torch.nn.Module, qtype):
     if model.config.model_type == "baichuan":
         # process NormHead module in Baichuan2 7B
         if hasattr(model, 'lm_head') and model.lm_head is not None:

@@ -40,6 +42,32 @@ def optimize_llm_pre(model: torch.nn.Module):
             from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
             model.apply(pre_compute_inv_freq)

+    # lm_head to cpu optimization
+    if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "1") != "0":
+        is_unsupported_model = (model.config.model_type == "llama"
+                                and model.vocab_size > 32000)
+        if not is_unsupported_model:
+            from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
+            if qtype == "sym_int4_rtn":
+                lm_qtype = SYM_INT4
+            else:
+                lm_qtype = SYM_INT8
+            # lm_head opt to mp opt (llama, qwen2)
+            optimize_lm_head = model.config.model_type not in ["llama", "qwen2"]
+            new_linear = LowBitLinear(model.lm_head.in_features,
+                                      model.lm_head.out_features,
+                                      lm_qtype,
+                                      False,
+                                      optimize_lm_head=optimize_lm_head)
+            paramsLowBit = FP4Params(data=model.lm_head.weight.data,
+                                     requires_grad=False,
+                                     quantized=False,
+                                     _shape=None,
+                                     qtype=lm_qtype,
+                                     in_features=model.lm_head.in_features).to("cpu")
+            new_linear._parameters['weight'] = paramsLowBit
+            model.lm_head = new_linear
+

 def optimize_llm(
     model: torch.nn.Module,
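Conceptually, the new block in optimize_llm_pre replaces model.lm_head with a low-bit linear layer whose quantized weight stays on the CPU (SYM_INT4 for "sym_int4_rtn", SYM_INT8 otherwise), so the final vocabulary projection runs there; llama models with a vocab_size above 32000 are excluded. A minimal plain-PyTorch sketch of the same idea, without the LowBitLinear/FP4Params quantization and purely for illustration:

import torch

class CpuLMHead(torch.nn.Module):
    # Keeps the original lm_head weight on CPU and moves hidden states
    # to CPU right before the final projection.
    def __init__(self, lm_head: torch.nn.Linear):
        super().__init__()
        self.lm_head = lm_head.to("cpu")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states.to("cpu"))

# Usage sketch: model.lm_head = CpuLMHead(model.lm_head)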