Fix llama3-8b NPU hang on long inputs (#11613)
This commit is contained in:
parent
e5c0058c0e
commit
f4077fa905
1 changed files with 9 additions and 0 deletions
|
|
@ -26,6 +26,7 @@ from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
|
|||
from intel_npu_acceleration_library.backend import run_matmul
|
||||
from intel_npu_acceleration_library.dtypes import NPUDtype
|
||||
from typing import Optional, Union
|
||||
import os
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
import uuid
|
||||
|
|
@ -177,6 +178,14 @@ class QuantizedLinear(torch.nn.Module):
|
|||
Returns:
|
||||
torch.Tensor: result
|
||||
"""
|
||||
|
||||
# we assume a Linear is lm_head when its out_features > 30000,
|
||||
# if out_features > 100000, enable lm_head optimization automatically
|
||||
if x.size(1) > 500 and (
|
||||
(self.outC > 100_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") != "0") or
|
||||
(self.outC > 30_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") == "1")
|
||||
):
|
||||
x = x[:, -1:, :]
|
||||
if self.training:
|
||||
invalidInputError(
|
||||
False,
|
||||
|
|
|
|||
Loading…
Reference in a new issue