fix llama3-8b npu long input stuck (#11613)

Author: Yishuo Wang (committed 2024-07-18 11:08:17 +08:00 via GitHub)
parent e5c0058c0e
commit f4077fa905


@@ -26,6 +26,7 @@ from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
 from intel_npu_acceleration_library.backend import run_matmul
 from intel_npu_acceleration_library.dtypes import NPUDtype
 from typing import Optional, Union
+import os
 import torch
 from torch.nn import Parameter
 import uuid
@@ -177,6 +178,14 @@ class QuantizedLinear(torch.nn.Module):
         Returns:
             torch.Tensor: result
         """
+        # We assume a Linear is the lm_head when its out_features > 30000;
+        # when out_features > 100000, enable the lm_head optimization automatically.
+        if x.size(1) > 500 and (
+            (self.outC > 100_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") != "0") or
+            (self.outC > 30_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") == "1")
+        ):
+            x = x[:, -1:, :]  # keep only the last token for the lm_head matmul
         if self.training:
             invalidInputError(
                 False,
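
For context, below is a minimal, self-contained sketch of the technique this diff applies: during prefill of a long prompt, only the final position's logits are needed to sample the next token, so the activation entering the vocabulary projection (lm_head) can be sliced to its last token before the matmul. The thresholds mirror the committed values; the function name maybe_slice_lm_head_input and the example shapes are illustrative assumptions, not part of ipex-llm or intel_npu_acceleration_library.

import os
import torch

def maybe_slice_lm_head_input(x: torch.Tensor, out_features: int) -> torch.Tensor:
    # Heuristic from the diff: a Linear with a vocabulary-sized out_features is
    # treated as the lm_head. out_features > 100000 enables the slice by default
    # (opt out with IPEX_LLM_LAST_LM_HEAD=0); out_features > 30000 requires
    # opting in with IPEX_LLM_LAST_LM_HEAD=1.
    flag = os.environ.get("IPEX_LLM_LAST_LM_HEAD")
    is_lm_head = (out_features > 100_000 and flag != "0") or (
        out_features > 30_000 and flag == "1"
    )
    # Only slice long inputs (seq_len > 500); short prompts are unaffected.
    if x.size(1) > 500 and is_lm_head:
        x = x[:, -1:, :]  # (batch, seq, hidden) -> (batch, 1, hidden)
    return x

# Usage: a llama3-8b-style vocab head has out_features = 128256 (> 100000), so
# a 1024-token prefill activation is cut to its last position by default.
x = torch.randn(1, 1024, 4096)
print(maybe_slice_lm_head_input(x, 128256).shape)  # torch.Size([1, 1, 4096])

Slicing before the projection shrinks the lm_head GEMM from seq_len x hidden x vocab to 1 x hidden x vocab, which is presumably what keeps the NPU from stalling on long inputs while leaving the sampled next token unchanged.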