Fix llama3-8b NPU hang on long inputs (#11613)
This commit is contained in:
		
							parent
							
								
									e5c0058c0e
								
							
						
					
					
						commit
						f4077fa905
					
				
					 1 changed file with 9 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -26,6 +26,7 @@ from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
 | 
			
		|||
from intel_npu_acceleration_library.backend import run_matmul
 | 
			
		||||
from intel_npu_acceleration_library.dtypes import NPUDtype
 | 
			
		||||
from typing import Optional, Union
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
from torch.nn import Parameter
 | 
			
		||||
import uuid
 | 
			
		||||
| 
						 | 
				
			
			@ -177,6 +178,14 @@ class QuantizedLinear(torch.nn.Module):
 | 
			
		|||
        Returns:
 | 
			
		||||
            torch.Tensor: result
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # we assume a Linear is lm_head when its out_features > 30000,
 | 
			
		||||
        # if out_features > 100000, enable lm_head optimization automatically
 | 
			
		||||
        if x.size(1) > 500 and (
 | 
			
		||||
            (self.outC > 100_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") != "0") or
 | 
			
		||||
            (self.outC > 30_000 and os.environ.get("IPEX_LLM_LAST_LM_HEAD") == "1")
 | 
			
		||||
        ):
 | 
			
		||||
            x = x[:, -1:, :]
 | 
			
		||||
        if self.training:
 | 
			
		||||
            invalidInputError(
 | 
			
		||||
                False,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue