fix llama3-8b npu long input stuck (#11613)
This commit is contained in:
parent e5c0058c0e
commit f4077fa905
1 changed file with 9 additions and 0 deletions
|
|
@@ -26,6 +26,7 @@ from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
 from intel_npu_acceleration_library.backend import run_matmul
 from intel_npu_acceleration_library.dtypes import NPUDtype
 from typing import Optional, Union
+import os
 import torch
 from torch.nn import Parameter
 import uuid
|
@@ -177,6 +178,14 @@ class QuantizedLinear(torch.nn.Module):
         Returns:
             torch.Tensor: result
         """
+
+        # we assume a Linear is lm_head when its out_features > 30000,
+        # if out_features > 100000, enable lm_head optimization automatically
+        if x.size(1) > 500 and (
+            (self.outC > 100_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") != "0") or
+            (self.outC > 30_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") == "1")
+        ):
+            x = x[:, -1:, :]
         if self.training:
             invalidInputError(
                 False,
|
|
||||||
Loading…
Reference in a new issue