Add axolotl v0.3.0 with ipex-llm on Intel GPU (#10717)
* Add axolotl v0.3.0 support on Intel GPU.
* Add a finetuning example for Llama-2-7B with the Alpaca dataset.
parent 0ccd7bfca9
commit b727767f00

5 changed files with 462 additions and 0 deletions

python/llm/example/GPU/LLM-Finetuning/README.md

@@ -9,6 +9,7 @@ This folder contains examples of running different training mode with IPEX-LLM o
- [DPO](DPO): examples of running DPO finetuning
- [common](common): common templates and utility classes in finetuning examples
- [HF-PEFT](HF-PEFT): run finetuning on Intel GPU using Hugging Face PEFT code without modification
- [axolotl](axolotl): LLM finetuning on Intel GPU using axolotl without writing code

## Troubleshooting
76  python/llm/example/GPU/LLM-Finetuning/axolotl/README.md  Normal file

@@ -0,0 +1,76 @@
# Finetune LLM on Intel GPU using axolotl without writing code

This example demonstrates how to easily run an LLM finetuning application using axolotl and IPEX-LLM 4-bit optimizations on [Intel GPUs](../../../README.md). By applying the IPEX-LLM patch, you can use axolotl with IPEX-LLM optimizations on Intel GPUs without writing any code.
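
The patch is a single call placed at the top of the `finetune.py` provided with this example, before any transformers or axolotl imports:

```python
from ipex_llm import llm_patch

# enable IPEX-LLM optimizations for finetuning on Intel GPU
llm_patch(train=True)
```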

Note that this example is only intended to illustrate the related usage and does not guarantee training convergence.

### 0. Requirements

To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../README.md#requirements) for more information.

### 1. Install

```bash
conda create -n llm python=3.11
conda activate llm
# the command below installs intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install transformers==4.34.0 datasets
pip install fire peft==0.5.0
# install axolotl v0.3.0
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
git checkout v0.3.0
# replace the default requirements.txt in axolotl to avoid dependency conflicts
cp ../requirements.txt .
pip install -e .
```
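
(Optional) As a quick sanity check that the XPU build installed above is usable, you can run the command below after sourcing the oneAPI variables in the next step. This assumes `intel_extension_for_pytorch` was pulled in by the `ipex-llm[xpu]` install; on a correctly configured machine it should print `True`:

```bash
python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.xpu.is_available())"
```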

### 2. Configure OneAPI environment variables and accelerate

```bash
source /opt/intel/oneapi/setvars.sh
```

Configure `accelerate`:

```bash
accelerate config
```

Ensure `use_cpu` is disabled in the generated config (`~/.cache/huggingface/accelerate/default_config.yaml`).
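
For reference, a single-GPU config produced by `accelerate config` might look roughly like the snippet below (field names and defaults vary with the `accelerate` version and the answers you give); the key point is that `use_cpu` is `false`:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: "NO"
mixed_precision: "no"
num_machines: 1
num_processes: 1
use_cpu: false
```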

### 3. Finetune

This example shows how to run [Alpaca QLoRA finetuning on Llama-2](https://github.com/artidoro/qlora) directly on an Intel GPU, based on the [axolotl Llama-2 QLoRA example](https://github.com/OpenAccess-AI-Collective/axolotl/blob/v0.3.0/examples/llama-2/qlora.yml).

Modify the parameters in `qlora.yml` to match your requirements.
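
For example, these entries in the provided `qlora.yml` control the effective batch size, number of epochs, and learning rate, and are typical ones to adjust:

```yaml
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 3
learning_rate: 0.0002
```

Then launch the finetuning: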

```bash
accelerate launch finetune.py qlora.yml
```

Output in console:

```
{'eval_loss': 0.9382301568984985, 'eval_runtime': 6.2513, 'eval_samples_per_second': 3.199, 'eval_steps_per_second': 3.199, 'epoch': 0.36}
{'loss': 0.944, 'learning_rate': 0.00019752490425051743, 'epoch': 0.38}
{'loss': 1.0179, 'learning_rate': 0.00019705675197106016, 'epoch': 0.4}
{'loss': 0.9346, 'learning_rate': 0.00019654872959986937, 'epoch': 0.41}
{'loss': 0.9747, 'learning_rate': 0.0001960010458282326, 'epoch': 0.43}
{'loss': 0.8928, 'learning_rate': 0.00019541392564000488, 'epoch': 0.45}
{'loss': 0.9317, 'learning_rate': 0.00019478761021918728, 'epoch': 0.47}
{'loss': 1.0534, 'learning_rate': 0.00019412235685085035, 'epoch': 0.49}
{'loss': 0.8777, 'learning_rate': 0.00019341843881544372, 'epoch': 0.5}
{'loss': 0.9447, 'learning_rate': 0.00019267614527653488, 'epoch': 0.52}
{'loss': 0.9651, 'learning_rate': 0.00019189578116202307, 'epoch': 0.54}
{'loss': 0.9067, 'learning_rate': 0.00019107766703887764, 'epoch': 0.56}
```

### 4. Other examples

Please refer to the [axolotl examples](https://github.com/OpenAccess-AI-Collective/axolotl/tree/v0.3.0/examples) for more models. Download the corresponding `xxx.yml` and use it in place of `qlora.yml`:

```bash
accelerate launch finetune.py xxx.yml
```
280  python/llm/example/GPU/LLM-Finetuning/axolotl/finetune.py  Normal file

@@ -0,0 +1,280 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import importlib
import logging
import os
import random
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from ipex_llm import llm_patch
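# IPEX-LLM patch for training; applied before the transformers/axolotl imports below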
llm_patch(train=True)
import fire
import torch
import transformers
import yaml

# add src to the pythonpath so we don't need to pip install this
from art import text2art
from transformers import GenerationConfig, TextStreamer

from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta, train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

configure_logging()
LOG = logging.getLogger("axolotl.scripts")

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


def print_axolotl_text_art(suffix=None):
    font = "nancyj"
    ascii_text = "  axolotl"
    if suffix:
        ascii_text += f"  x  {suffix}"
    ascii_art = text2art(" axolotl", font=font)

    if is_main_process():
        print(ascii_art)


def get_multi_line_input() -> Optional[str]:
    print("Give me an instruction (Ctrl + D to finish): ")
    instruction = ""
    for line in sys.stdin:
        instruction += line  # pylint: disable=consider-using-join
    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
    return instruction


def do_merge_lora(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    safe_serialization = cfg.save_safetensors is True

    LOG.info("running merge of LoRA with base model")
    model = model.merge_and_unload()
    model.to(dtype=torch.float16)

    if cfg.local_rank == 0:
        LOG.info("saving merged model")
        model.save_pretrained(
            str(Path(cfg.output_dir) / "merged"),
            safe_serialization=safe_serialization,
        )
        tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))


def shard(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    safe_serialization = cfg.save_safetensors is True
    LOG.debug("Re-saving model w/ sharding")
    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)


def do_inference(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    prompter = cli_args.prompter
    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

    for token, symbol in default_tokens.items():
        # If the token isn't already specified in the config, add it
        if not (cfg.special_tokens and token in cfg.special_tokens):
            tokenizer.add_special_tokens({token: symbol})

    prompter_module = None
    if prompter:
        prompter_module = getattr(
            importlib.import_module("axolotl.prompters"), prompter
        )

    if cfg.landmark_attention:
        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id

        set_model_mem_id(model, tokenizer)
        model.set_mem_cache_args(
            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
        )

    model = model.to(cfg.device)

    while True:
        print("=" * 80)
        # support for multiline inputs
        instruction = get_multi_line_input()
        if not instruction:
            return
        if prompter_module:
            prompt: str = next(
                prompter_module().build_prompt(instruction=instruction.strip("\n"))
            )
        else:
            prompt = instruction.strip()
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        print("=" * 40)
        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=1024,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            streamer = TextStreamer(tokenizer)
            generated = model.generate(
                inputs=batch["input_ids"].to(cfg.device),
                generation_config=generation_config,
                streamer=streamer,
            )
        print("=" * 40)
        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))


def choose_config(path: Path):
    yaml_files = list(path.glob("*.yml"))

    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    if len(yaml_files) == 1:
        print(f"Using default YAML file '{yaml_files[0]}'")
        return yaml_files[0]

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files):
        print(f"{idx + 1}. {file}")

    chosen_file = None
    while chosen_file is None:
        try:
            choice = int(input("Enter the number of your choice: "))
            if 1 <= choice <= len(yaml_files):
                chosen_file = yaml_files[choice - 1]
            else:
                print("Invalid choice. Please choose a number from the list.")
        except ValueError:
            print("Invalid input. Please enter a number.")

    return chosen_file


def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
    return not any(el in list2 for el in list1)


def load_cfg(config: Path = Path("examples/"), **kwargs):
    if Path(config).is_dir():
        config = choose_config(config)

    # load the config from the yaml file
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.safe_load(file))
    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
    # then overwrite the value
    cfg_keys = cfg.keys()
    for k, _ in kwargs.items():
        # if not strict, allow writing to cfg even if it's not in the yml already
        if k in cfg_keys or not cfg.strict:
            # handle booleans
            if isinstance(cfg[k], bool):
                cfg[k] = bool(kwargs[k])
            else:
                cfg[k] = kwargs[k]

    validate_config(cfg)

    normalize_config(cfg)

    setup_wandb_env_vars(cfg)
    return cfg


def load_datasets(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
    tokenizer = load_tokenizer(cfg)

    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)

    if cli_args.debug or cfg.debug:
        LOG.info("check_dataset_labels...")
        check_dataset_labels(
            train_dataset.select(
                [
                    random.randrange(0, len(train_dataset) - 1)  # nosec
                    for _ in range(cli_args.debug_num_examples)
                ]
            ),
            tokenizer,
            num_examples=cli_args.debug_num_examples,
            text_only=cli_args.debug_text_only,
        )

    return TrainDatasetMeta(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        total_num_steps=total_num_steps,
    )


def do_cli(config: Path = Path("examples/"), **kwargs):
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, **kwargs)
    parser = transformers.HfArgumentParser((TrainerCliArgs))
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    if parsed_cli_args.inference:
        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
    elif parsed_cli_args.merge_lora:
        do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
    elif parsed_cli_args.shard:
        shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
    else:
        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
        if parsed_cli_args.prepare_ds_only:
            return
        train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


if __name__ == "__main__":
    fire.Fire(do_cli)
73  python/llm/example/GPU/LLM-Finetuning/axolotl/qlora.yml  Normal file

@@ -0,0 +1,73 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 3
# change optimizer from paged_adamw_32bit to adamw_torch
# due to bitsandbytes issue https://github.com/TimDettmers/bitsandbytes/issues/244
# optimizer: paged_adamw_32bit
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
32  python/llm/example/GPU/LLM-Finetuning/axolotl/requirements.txt  Normal file

@@ -0,0 +1,32 @@
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# torch==2.1.0
#auto-gptq
packaging
peft==0.5.0
transformers==4.34.0
bitsandbytes>=0.41.1
accelerate==0.23.0
addict
evaluate
fire
PyYAML>=6.0
datasets
flash-attn>=2.2.1
sentencepiece
wandb
einops
#xformers
optimum
hf_transfer
colorama
numba
numpy>=1.24.4
# qlora things
bert-score==0.3.13
evaluate==0.4.0
rouge-score==0.1.2
scipy
scikit-learn==1.2.2
pynvml
art