fix llama3-8b npu long input stuck (#11613)

Yishuo Wang 2024-07-18 11:08:17 +08:00 committed by GitHub
parent e5c0058c0e
commit f4077fa905

@@ -26,6 +26,7 @@ from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
 from intel_npu_acceleration_library.backend import run_matmul
 from intel_npu_acceleration_library.dtypes import NPUDtype
 from typing import Optional, Union
+import os
 import torch
 from torch.nn import Parameter
 import uuid
@@ -177,6 +178,14 @@ class QuantizedLinear(torch.nn.Module):
         Returns:
             torch.Tensor: result
         """
+        # we assume a Linear is lm_head when its out_features > 30000;
+        # if out_features > 100000, enable the lm_head optimization automatically
+        if x.size(1) > 500 and (
+            (self.outC > 100_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") != "0") or
+            (self.outC > 30_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") == "1")
+        ):
+            x = x[:, -1:, :]
         if self.training:
             invalidInputError(
                 False,
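
Note on the change: during prefill only the last position's logits are needed for sampling, so for a long prompt the huge vocabulary projection (lm_head) can run on just the final hidden state instead of the whole sequence, avoiding the oversized NPU matmul that made llama3-8b hang on long inputs. Below is a minimal, self-contained sketch of the same trick; the helper name maybe_slice_lm_head_input and the example shapes are illustrative and not part of the commit, while the thresholds and the IPEX_LLM_LAST_LM_HEAD environment variable come straight from the diff.

import os
import torch

def maybe_slice_lm_head_input(x: torch.Tensor, out_features: int) -> torch.Tensor:
    """Sketch of the lm_head shortcut applied in this commit.

    x is the (batch, seq_len, hidden) input to a linear layer; if that layer
    looks like an lm_head (very large out_features) and the prompt is long
    (> 500 tokens), keep only the last token's hidden state.
    """
    flag = os.environ.get("IPEX_LLM_LAST_LM_HEAD")
    looks_like_lm_head = (
        (out_features > 100_000 and flag != "0")    # e.g. llama3's ~128k vocab: on by default
        or (out_features > 30_000 and flag == "1")  # e.g. a ~32k vocab: opt-in only
    )
    if x.size(1) > 500 and looks_like_lm_head:
        x = x[:, -1:, :]  # keep only the last position
    return x

# Hypothetical usage: a 1024-token prompt entering a 128_256-way vocab projection.
x = torch.randn(1, 1024, 4096)
x = maybe_slice_lm_head_input(x, out_features=128_256)
assert x.shape == (1, 1, 4096)

Per the condition in the diff, the slice is automatic only for out_features > 100000; setting IPEX_LLM_LAST_LM_HEAD=1 opts layers with out_features > 30000 in as well, and IPEX_LLM_LAST_LM_HEAD=0 disables it entirely.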