FIX: Qwen1.5-GPTQ-Int4 inference error (#11432)
* skip merge_qkv (and padding_mlp) when quant_method is 'gptq'
* fix python style checks
* refactor
* update GPU example
parent 99cd16ef9f
commit ab9f7f3ac5

5 changed files with 40 additions and 17 deletions
@@ -18,7 +18,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 pip install transformers==4.34.0
 BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install optimum==1.14.0
 ```
 
 On Windows:
@@ -30,7 +30,7 @@ pip install --pre --upgrade ipex-llm[all]
 pip install transformers==4.34.0
 set BUILD_CUDA_EXT=0
 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install optimum==1.14.0
 ```
 
 ### 2. Run

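If an environment was already set up with the old pin, it is worth confirming that the corrected versions actually took effect. A small check, assuming a standard pip environment (the auto-gptq and ipex-llm distribution names below are assumed PyPI names, not taken from this diff):

```python
# Report the installed versions of the packages pinned in this README.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("transformers", "optimum", "auto-gptq", "ipex-llm"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")
```
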
@@ -19,7 +19,7 @@ import time
 import argparse
 
 from ipex_llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer, GPTQConfig
+from transformers import LlamaTokenizer, AutoTokenizer
 
 # you could tune the prompt based on your own model,
 # here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
@@ -50,7 +50,10 @@ if __name__ == '__main__':
                                                  trust_remote_code=True,)
 
     # Load tokenizer
-    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    if "qwen" in model_path.lower():
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    else:
+        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():

@@ -18,7 +18,7 @@ import torch
 import time
 import argparse
 from ipex_llm.transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer, GPTQConfig
+from transformers import AutoTokenizer, LlamaTokenizer
 
 # you could tune the prompt based on your own model,
 # here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
@@ -48,9 +48,11 @@ if __name__ == '__main__':
                                                  torch_dtype=torch.float,
                                                  trust_remote_code=True,).to("xpu")
 
-    print(model)
     # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    if "qwen" in model_path.lower():
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    else:
+        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():

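Put together, the updated GPU example follows the flow sketched below. This is a condensed illustration: the model id is a placeholder, and any `from_pretrained` arguments that are not visible in the hunks above are simply omitted rather than guessed.

```python
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer

model_path = "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4"  # illustrative GPTQ checkpoint id

# Load the GPTQ checkpoint through ipex-llm and move it to the Intel GPU.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             torch_dtype=torch.float,
                                             trust_remote_code=True).to("xpu")

# Qwen checkpoints ship their own tokenizer; Llama-family ones use LlamaTokenizer.
if "qwen" in model_path.lower():
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```
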
@@ -732,10 +732,17 @@ def _optimize_pre(model):
         model.apply(split_mlp)
     # for qwen2
     if model.config.model_type == "qwen2":
-        from ipex_llm.transformers.models.qwen2 import merge_qkv
-        model.apply(merge_qkv)
-        from ipex_llm.transformers.models.qwen2 import padding_mlp
-        model.apply(padding_mlp)
+        # Skip merge_qkv and padding_mlp if quant_method is 'gptq'
+        should_apply_merge_qkv = (
+            not hasattr(model.config, "quantization_config") or
+            not hasattr(model.config.quantization_config, "quant_method") or
+            model.config.quantization_config.quant_method != "gptq"
+        )
+        if should_apply_merge_qkv:
+            from ipex_llm.transformers.models.qwen2 import merge_qkv
+            model.apply(merge_qkv)
+            from ipex_llm.transformers.models.qwen2 import padding_mlp
+            model.apply(padding_mlp)
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)

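Read in isolation, the new guard amounts to a small predicate over the model config. A sketch with an illustrative helper name, using the same attributes the hunk checks:

```python
# Decide whether the Qwen2 qkv fusion (and mlp padding) should run.
# GPTQ checkpoints carry config.quantization_config.quant_method == "gptq";
# for those, the original q_proj/k_proj/v_proj modules are left untouched.
def should_merge_qkv(config) -> bool:
    quant_cfg = getattr(config, "quantization_config", None)
    quant_method = getattr(quant_cfg, "quant_method", None)
    return quant_method != "gptq"
```

With such a helper, `_optimize_pre` would apply `merge_qkv` and `padding_mlp` only when `should_merge_qkv(model.config)` is true; `getattr` with a default is an equivalent, slightly more compact alternative to the chained `hasattr` checks used in the patch.
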
@@ -405,12 +405,23 @@ def qwen2_attention_forward(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
 
-    qkv = self.qkv_proj(hidden_states)
-    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
-    qkv = qkv.transpose(1, 2)
-    query_states, key_states, value_states = qkv.split([self.num_heads,
-                                                        self.num_key_value_heads,
-                                                        self.num_key_value_heads], dim=1)
+    if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
+        qkv = self.qkv_proj(hidden_states)
+        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+        qkv = qkv.transpose(1, 2)
+        query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                            self.num_key_value_heads,
+                                                            self.num_key_value_heads], dim=1)
+    else:
+        # when quant_method is 'gptq'
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) \
+                               .transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) \
+                                   .transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:

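For reference, the two branches above yield query/key/value tensors of identical shapes. A self-contained check with made-up dimensions, where `nn.Linear` stands in both for the fused `qkv_proj` and for the quantized q/k/v projections of a GPTQ checkpoint:

```python
import torch

bsz, q_len, head_dim = 2, 5, 8
num_heads, num_key_value_heads = 4, 2
hidden_size = num_heads * head_dim
hidden_states = torch.randn(bsz, q_len, hidden_size)

# Fused path: one projection, reshape, then split along the head dimension.
qkv_proj = torch.nn.Linear(hidden_size, (num_heads + 2 * num_key_value_heads) * head_dim)
qkv = qkv_proj(hidden_states).view(bsz, q_len, num_heads + 2 * num_key_value_heads, head_dim)
qkv = qkv.transpose(1, 2)
q1, k1, v1 = qkv.split([num_heads, num_key_value_heads, num_key_value_heads], dim=1)

# Unfused path (the GPTQ fallback): separate projections, reshaped individually.
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim)
k_proj = torch.nn.Linear(hidden_size, num_key_value_heads * head_dim)
v_proj = torch.nn.Linear(hidden_size, num_key_value_heads * head_dim)
q2 = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k2 = k_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
v2 = v_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

assert q1.shape == q2.shape and k1.shape == k2.shape and v1.shape == v2.shape
print("query", q1.shape, "key", k1.shape, "value", v1.shape)
```
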