fix fuse mlp when using q5_0 or fp8 (#9689)
parent 82ac2dbf55
commit 9a330bfc2b

2 changed files with 4 additions and 0 deletions
@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu

@@ -75,6 +76,7 @@ def baichuan_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
         x_2d = x.view(-1, x.shape[-1])
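
The qtype check added to baichuan_mlp_forward is the substance of the fix: the fused MLP kernel imported from linear_q4_0 implements, as the module name suggests, only the q4_0 (sym_int4) weight layout, and before this commit the fast path was taken for every quantization type, so models loaded with q5_0 or fp8 weights were handed to a kernel that cannot decode them. The same guard is added for Qwen's w2 projection in the second file below. A minimal sketch of the guard as a standalone predicate (the helper name can_use_fused_mlp and its proj/training parameters are illustrative, not part of the codebase):

import torch
from bigdl.llm.ggml.quantize import ggml_tensor_qtype

def can_use_fused_mlp(x: torch.Tensor, proj, training: bool = False) -> bool:
    # Fast path only for a single-token decode step on an Intel GPU,
    # with fp32 activations, q4_0 (sym_int4) weights, and no autograd.
    # Any other qtype (q5_0, fp8, ...) now falls back to the generic
    # per-projection quantized matmuls, which support every qtype.
    return (x.shape[1] == 1
            and x.dtype == torch.float32
            and x.device.type == 'xpu'
            and proj.qtype == ggml_tensor_qtype["sym_int4"]
            and not (training and x.requires_grad))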

@@ -39,6 +39,7 @@ except ImportError:
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
 apply_rotary_emb_func = None
 
@@ -214,6 +215,7 @@ def qwen_attention_forward(
 
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
         x_2d = x.view(-1, x.shape[-1])
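
From the user's side nothing changes for q4_0 models; q5_0 and fp8 models simply stop taking the fused path and run the ordinary quantized matmuls instead of crashing. A hedged usage sketch with bigdl-llm's load_in_low_bit option (the model id and the exact low-bit aliases accepted by a given bigdl-llm release are assumptions here):

from bigdl.llm.transformers import AutoModelForCausalLM

# sym_int4 corresponds to ggml q4_0: single-token decode on XPU
# may use the fused MLP kernel.
model_q4 = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat",            # assumed model id
    load_in_low_bit="sym_int4",
    trust_remote_code=True,
).to('xpu')

# sym_int5 corresponds to ggml q5_0 (fp8 behaves the same way): the new
# qtype guard keeps these on the generic MLP path instead of the
# q4_0-only fused kernel.
model_q5 = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat",
    load_in_low_bit="sym_int5",
    trust_remote_code=True,
).to('xpu')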