fix qlora finetune example (#12769)
This commit is contained in:
parent 094a25b740
commit 9697197f3e

8 changed files with 55 additions and 39 deletions
@@ -17,9 +17,9 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.36.0 datasets
+pip install transformers==4.45.0 "trl<0.12.0" datasets
 pip install peft==0.10.0
-pip install bitsandbytes scipy
+pip install bitsandbytes==0.45.1 scipy
 ```
 
 ### 2. Configures OneAPI environment variables
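Note on the pins above: the example now requires transformers 4.45.0, a trl release below 0.12.0, and bitsandbytes pinned to exactly 0.45.1. A post-install sanity check along these lines can confirm pip resolved the expected versions (a minimal sketch; the script and its name are illustrative, not part of this commit):

```python
# check_pins.py -- hypothetical helper, not part of this commit: confirm that
# the versions pinned in the install step above are what pip actually resolved.
import importlib.metadata as md

expected = {"transformers": "4.45.0", "peft": "0.10.0", "bitsandbytes": "0.45.1"}
for pkg, want in expected.items():
    got = md.version(pkg)
    print(f"{pkg}: want {want}, got {got} [{'OK' if got == want else 'MISMATCH'}]")
```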
@@ -18,7 +18,7 @@ import torch
 import os
 
 import transformers
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 from peft import LoraConfig
 from transformers import BitsAndBytesConfig
 from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
@@ -43,7 +43,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
     dataset_path = args.dataset
-    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     if dataset_path.endswith(".json") or dataset_path.endswith(".jsonl"):
         data = load_dataset("json", data_files=dataset_path)
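The import swap above is the core of the example fix: AutoTokenizer inspects the checkpoint's config and dispatches to the matching tokenizer class, so the script no longer assumes every base model ships a LlamaTokenizer. A minimal sketch of the pattern (the model id is a placeholder):

```python
# Sketch of the AutoTokenizer pattern the example moves to; the model id is a
# placeholder -- any Hugging Face checkpoint path works the same way.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf",
                                          trust_remote_code=True)
print(type(tokenizer).__name__)  # resolved per checkpoint, e.g. LlamaTokenizerFast
```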
@@ -17,9 +17,9 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.36.0 datasets
+pip install transformers==4.45.0 "trl<0.12.0" datasets
 pip install peft==0.10.0
-pip install bitsandbytes scipy trl==0.9.6
+pip install bitsandbytes==0.45.1 scipy
 ```
 
 ### 2. Configures OneAPI environment variables
@@ -18,7 +18,7 @@ import torch
 import os
 
 import transformers
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 from peft import LoraConfig
 from transformers import BitsAndBytesConfig
 from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
@@ -44,7 +44,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
     dataset_path = args.dataset
-    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     # Avoid tokenizer doesn't have a padding token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
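The pad-token fallback kept as context here is what makes the AutoTokenizer switch safe for batched fine-tuning: some checkpoints define no pad token at all. A short illustration (a sketch, assuming `tokenizer` was loaded as in the hunk above):

```python
# Padded batches need a pad id; some checkpoints define none, hence the
# EOS fallback above. Sketch assuming `tokenizer` was loaded as in the hunk.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

batch = tokenizer(["short prompt", "a noticeably longer prompt"],
                  padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # both rows padded to the longer sequence
```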
@@ -51,7 +51,8 @@ from torch import Tensor, dtype, nn
 from operator import mul
 from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name
+from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype
+from ipex_llm.transformers.utils import get_xpu_device_name
 from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -527,8 +528,8 @@ class MatMulLowBit(torch.autograd.Function):
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
-            if torch.xpu.is_autocast_xpu_enabled():
-                grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
+            if is_autocast_enabled("xpu"):
+                grad_output = grad_output.to(get_autocast_dtype("xpu"))
             if weight.qtype == NF4:
                 dequant_weight = xe_linear.dequant(A,
                                                    weight.data.view(torch.uint8),
@@ -615,7 +616,7 @@ class LowBitLinear(nn.Linear):
         is_training = self.training and not torch.is_inference_mode_enabled()
         if is_training:
             # below logic is only for training
-            autocast_dtype = get_autocast_dtype(x)
+            autocast_dtype = get_autocast_dtype(x.device.type)
             if self.compute_dtype is not None and x.device.type == "xpu":
                 x = x.to(self.compute_dtype)  # solve GC issue for unlora module
             elif autocast_dtype is not None:
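Both code hunks in this file move from tensor-based autocast queries to device-type strings. On torch builds new enough to satisfy the `>= '2.3'` guard used elsewhere in this commit, that maps onto PyTorch's device-generic introspection API, sketched below with "cpu" so it runs without an XPU:

```python
# Device-generic autocast introspection, the API shape these hunks migrate
# toward; "cpu" is used only so this sketch runs on any PyTorch build.
import torch

with torch.autocast("cpu", dtype=torch.bfloat16):
    print(torch.is_autocast_enabled("cpu"))  # True inside the context
    print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16
print(torch.is_autocast_enabled("cpu"))      # False again outside
```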
@@ -109,7 +109,7 @@ class LoraLowBitLinear(Module, LoraLayer):
             self.qa_pool = torch.nn.Identity()
 
     def forward(self, x: torch.Tensor):
-        autocast_dtype = get_autocast_dtype(x)
+        autocast_dtype = get_autocast_dtype(x.device.type)
         if x.device.type == "xpu":
             # force to use bf16 on gpu
             x = x.to(torch.bfloat16)
@@ -177,7 +177,7 @@ class LoraBF16Linear(Module, LoraLayer):
         self.is_target_conv_1d_layer = is_target_conv_1d_layer
 
     def forward(self, x: torch.Tensor):
-        autocast_dtype = get_autocast_dtype(x)
+        autocast_dtype = get_autocast_dtype(x.device.type)
         if x.device.type == "xpu":
             # force to use bf16 on gpu
             x = x.to(torch.bfloat16)
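Both LoRA forward methods now pass `x.device.type` instead of the tensor itself, matching the refactored helper signature diffed in the next file. Condensed, the dtype selection they perform looks roughly like this (the helper name is hypothetical, not from the commit):

```python
# Condensed sketch of the dtype selection in the two forward() methods above;
# resolve_compute_dtype is a hypothetical name, not part of this commit.
import torch
from ipex_llm.transformers.utils import get_autocast_dtype

def resolve_compute_dtype(x: torch.Tensor) -> torch.dtype:
    if x.device.type == "xpu":
        return torch.bfloat16  # force bf16 on Intel GPU, as above
    autocast_dtype = get_autocast_dtype(x.device.type)
    return autocast_dtype if autocast_dtype is not None else x.dtype
```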
@@ -138,26 +138,39 @@ def fix_key(key):
     return key
 
 
-def get_autocast_dtype(x):
+def is_autocast_enabled(device_type: str):
     if torch.__version__ >= '2.3':
-        if torch.is_autocast_enabled(x.device.type):
-            return torch.get_autocast_dtype(x.device.type)
+        return torch.is_autocast_enabled(device_type)
+    else:
+        if device_type == "xpu":
+            return torch.xpu.is_autocast_xpu_enabled()
+        elif device_type == "cpu":
+            return torch.is_autocast_cpu_enabled()
+        else:
+            invalidInputError(False,
+                              f"Device type {device_type} is not supported.")
+
+
+def get_autocast_dtype(device_type: str):
+    if torch.__version__ >= '2.3':
+        if torch.is_autocast_enabled(device_type):
+            return torch.get_autocast_dtype(device_type)
         else:
             return None
     else:
-        if x.device.type == "xpu":
+        if device_type == "xpu":
             if torch.xpu.is_autocast_xpu_enabled():
                 return torch.xpu.get_autocast_xpu_dtype()
             else:
                 return None
-        elif x.device.type == "cpu":
+        elif device_type == "cpu":
             if torch.is_autocast_cpu_enabled():
                 return torch.get_autocast_cpu_dtype()
             else:
                 return None
         else:
             invalidInputError(False,
-                              f"Device {x.device} is not supported.")
+                              f"Device type {device_type} is not supported.")
 
 
 def get_xpu_device_name(device: torch.device):
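The refactor splits the old tensor-based `get_autocast_dtype(x)` into two helpers keyed on a device-type string: `is_autocast_enabled` answers whether autocast is active, `get_autocast_dtype` returns the active dtype or None, and each branches on the torch version internally. Callers now combine them as in the MatMulLowBit.backward hunk earlier; a standalone sketch (the wrapper function is hypothetical):

```python
# Standalone sketch of the caller-side pattern; cast_to_autocast_dtype is a
# hypothetical wrapper mirroring the MatMulLowBit.backward hunk earlier.
from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype

def cast_to_autocast_dtype(t, device_type: str = "xpu"):
    # If autocast is active for this device type, cast to its working dtype.
    if is_autocast_enabled(device_type):
        t = t.to(get_autocast_dtype(device_type))
    return t
```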
@@ -107,6 +107,8 @@ except ModuleNotFoundError:
     np = None  # type: ignore[assignment]
 from typing import Any
 
+from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype
+
 
 def _cast(value, dtype):
     if isinstance(value, torch.Tensor):
@@ -155,12 +157,12 @@ def custom_fwd(fwd=None, *, cast_inputs=None):
 
     @functools.wraps(fwd)
     def decorate_fwd(*args, **kwargs):
-        args[0]._dtype = torch.xpu.get_autocast_xpu_dtype()
+        args[0]._dtype = get_autocast_dtype("xpu")
         if cast_inputs is None:
-            args[0]._fwd_used_autocast = torch.xpu.is_autocast_xpu_enabled()
+            args[0]._fwd_used_autocast = is_autocast_enabled("xpu")
             return fwd(*args, **kwargs)
         else:
-            autocast_context = torch.xpu.is_autocast_xpu_enabled()
+            autocast_context = is_autocast_enabled("xpu")
             args[0]._fwd_used_autocast = False
             if autocast_context:
                 with torch.xpu.autocast(enabled=False):
@@ -184,7 +186,7 @@ def custom_bwd(bwd):
 
     @functools.wraps(bwd)
     def decorate_bwd(*args, **kwargs):
-        with torch.xpu.autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
+        with torch.autocast("xpu", enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
             return bwd(*args, **kwargs)
 
     return decorate_bwd
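The final hunk replaces the legacy `torch.xpu.autocast(...)` context manager with `torch.autocast` plus an explicit device string, the portable spelling. A sketch of that form, using "cpu" so it runs on any build:

```python
# torch.autocast with an explicit device string, the form custom_bwd now uses;
# "cpu" stands in for "xpu" so the sketch runs on any PyTorch build.
import torch

x = torch.randn(8, 8)
with torch.autocast("cpu", enabled=True, dtype=torch.bfloat16):
    y = x @ x  # matmul runs under CPU autocast
print(y.dtype)  # torch.bfloat16
```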