# # Copyright 2016 The BigDL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # This file is copied from # https://github.com/OpenAccess-AI-Collective/axolotl/blob/v0.4.0/src/axolotl/cli/train.py # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from ipex_llm import llm_patch

# NOTE: the patch call is intentionally placed ahead of the imports below;
# this ordering is preserved from the original file.
llm_patch(train=True)

# The following is the original axolotl train code (without IPEX-LLM)
"""
CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Tuple

import fire
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer

from axolotl.cli import (
    check_accelerate_default_config,
    check_user_token,
    load_cfg,
    load_datasets,
    load_rl_datasets,
    print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")


def do_cli(config: Path = Path("examples/"), **kwargs):
    """Command-line entry point: load the config, parse CLI args, then train.

    Args:
        config: Path handed to ``load_cfg``; defaults to ``examples/``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``load_cfg``.

    Returns:
        Whatever ``do_train`` returns for the loaded config and parsed
        ``TrainerCliArgs``.
    """
    # pylint: disable=duplicate-code
    cfg = load_cfg(config, **kwargs)

    arg_parser = transformers.HfArgumentParser(TrainerCliArgs)
    # With return_remaining_strings=True the parser hands back the unparsed
    # leftovers as a second element; they are deliberately discarded here.
    parsed = arg_parser.parse_args_into_dataclasses(return_remaining_strings=True)
    cli_args, _unparsed = parsed

    return do_train(cfg, cli_args)


def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Run the pre-flight checks, load the datasets, and start training.

    Args:
        cfg: Loaded configuration object; its ``rl`` attribute selects which
            dataset loader is used.
        cli_args: Parsed CLI arguments passed through to the loader and to
            ``train``.

    Returns:
        The result of ``train`` for the given config, CLI args, and datasets.
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    # A truthy cfg.rl selects the RL dataset loader; otherwise the standard one.
    dataset_loader = load_rl_datasets if cfg.rl else load_datasets
    dataset_meta = dataset_loader(cfg=cfg, cli_args=cli_args)

    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


if __name__ == "__main__":
    fire.Fire(do_cli)